From 9417e8b8d52342c2fb5caf3b94de7d379bebe4f3 Mon Sep 17 00:00:00 2001 From: Honnix Date: Mon, 11 Dec 2023 10:27:23 +0100 Subject: [PATCH 1/5] Revert "Revert "Make plugin task first-class citizen (#268)" (#271)" This reverts commit 3d9ab4e48f58b1f8199c4a36daf15310b282708f. Signed-off-by: Hongxin Liang --- .../src/main/proto/flyteidl/core/tasks.proto | 8 + .../java/org/flyte/api/v1/PluginTask.java | 22 +++ .../org/flyte/api/v1/PluginTaskRegistrar.java | 20 ++ .../java/org/flyte/api/v1/TaskTemplate.java | 8 + .../org/flyte/flytekit/SdkPluginTask.java | 115 ++++++++++++ .../flytekit/SdkPluginTaskRegistrar.java | 146 +++++++++++++++ .../flytekit/SdkPluginTaskRegistrarTest.java | 171 ++++++++++++++++++ .../flyte/jflyte/utils/ProjectClosure.java | 39 +++- .../org/flyte/jflyte/utils/ProtoUtil.java | 10 +- .../java/org/flyte/jflyte/utils/Fixtures.java | 1 + .../jflyte/utils/FlyteAdminClientTest.java | 4 + .../jflyte/utils/ProjectClosureTest.java | 56 ++++++ .../flyte/jflyte/utils/ProtoReaderTest.java | 1 + .../org/flyte/jflyte/utils/ProtoUtilTest.java | 6 + .../flyte/jflyte/ExecuteDynamicWorkflow.java | 12 +- 15 files changed, 608 insertions(+), 11 deletions(-) create mode 100644 flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java create mode 100644 flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java create mode 100644 flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java create mode 100644 flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java create mode 100644 flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java diff --git a/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto b/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto index 48961029e..5844ac5c3 100644 --- a/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto +++ b/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto @@ -59,6 +59,14 @@ message RuntimeMetadata { //+optional It can be used to provide extra information about the runtime (e.g. python, golang... etc.). string flavor = 3; + + //+optional It can be used to provide extra information for the plugin. + PluginMetadata plugin_metadata = 4; +} + +message PluginMetadata { + //+optional It can be used to decide use sync plugin or async plugin during runtime. + bool is_sync_plugin = 1; } // Task Metadata diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java new file mode 100644 index 000000000..fc1a7d8dc --- /dev/null +++ b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.api.v1; + +/** A task that is handled by a Flyte backend plugin instead of run as a container. */ +public interface PluginTask extends Task { + boolean isSyncPlugin(); +} diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java new file mode 100644 index 000000000..d2562e4fa --- /dev/null +++ b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java @@ -0,0 +1,20 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.api.v1; + +/** A registrar that creates {@link PluginTask} instances. */ +public abstract class PluginTaskRegistrar implements Registrar {} diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java b/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java index 5a48c8900..7452ba6fc 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java @@ -22,6 +22,9 @@ /** * A Task structure that uniquely identifies a task in the system. Tasks are registered as a first * step in the system. + * + *

FIXME: consider offering TaskMetadata instead of having everything in TaskTemplate, see + * https://github.com/flyteorg/flyte/blob/ea72bbd12578d64087221592554fb71c368f8057/flyteidl/protos/flyteidl/core/tasks.proto#L90 */ @AutoValue public abstract class TaskTemplate { @@ -64,6 +67,9 @@ public abstract class TaskTemplate { */ public abstract boolean cacheSerializable(); + /** Indicates whether to use sync plugin or async plugin to handle this task. */ + public abstract boolean isSyncPlugin(); + public abstract Builder toBuilder(); public static Builder builder() { @@ -89,6 +95,8 @@ public abstract static class Builder { public abstract Builder cacheSerializable(boolean cacheSerializable); + public abstract Builder isSyncPlugin(boolean isSyncPlugin); + public abstract TaskTemplate build(); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java new file mode 100644 index 000000000..eedf3ab33 --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.flyte.api.v1.PartialTaskIdentifier; + +/** A task that is handled by a Flyte backend plugin instead of run as a container. */ +public abstract class SdkPluginTask extends SdkTransform { + + private final SdkType inputType; + private final SdkType outputType; + + /** + * Called by subclasses passing the {@link SdkType}s for inputs and outputs. + * + * @param inputType type for inputs. + * @param outputType type for outputs. + */ + public SdkPluginTask(SdkType inputType, SdkType outputType) { + this.inputType = inputType; + this.outputType = outputType; + } + + public abstract String getType(); + + @Override + public SdkType getInputType() { + return inputType; + } + + @Override + public SdkType getOutputType() { + return outputType; + } + + /** Specifies custom data that can be read by the backend plugin. */ + public SdkStruct getCustom() { + return SdkStruct.empty(); + } + + /** + * Number of retries. Retries will be consumed when the task fails with a recoverable error. The + * number of retries must be less than or equals to 10. + * + * @return number of retries + */ + public int getRetries() { + return 0; + } + + /** + * Indicates whether the system should attempt to look up this task's output to avoid duplication + * of work. + */ + public boolean isCached() { + return false; + } + + /** Indicates a logical version to apply to this task for the purpose of cache. */ + public String getCacheVersion() { + return null; + } + + /** + * Indicates whether the system should attempt to execute cached instances in serial to avoid + * duplicate work. + */ + public boolean isCacheSerializable() { + return false; + } + + @Override + SdkNode apply( + SdkWorkflowBuilder builder, + String nodeId, + List upstreamNodeIds, + @Nullable SdkNodeMetadata metadata, + Map> inputs) { + PartialTaskIdentifier taskId = PartialTaskIdentifier.builder().name(getName()).build(); + List errors = + Compiler.validateApply(nodeId, inputs, getInputType().getVariableMap()); + + if (!errors.isEmpty()) { + throw new CompilerException(errors); + } + + return new SdkTaskNode<>( + builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType); + } + + /** + * Signaling whether this task is supposed to be handled by a synchronous backend plugin, + * defaulting to false. + */ + public boolean isSyncPlugin() { + return false; + } +} diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java new file mode 100644 index 000000000..114aea66e --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java @@ -0,0 +1,146 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import com.google.auto.service.AutoService; +import java.util.HashMap; +import java.util.Map; +import java.util.ServiceLoader; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.PluginTaskRegistrar; +import org.flyte.api.v1.RetryStrategy; +import org.flyte.api.v1.Struct; +import org.flyte.api.v1.TaskIdentifier; +import org.flyte.api.v1.TypedInterface; + +/** + * Default implementation of a {@link PluginTaskRegistrar} that discovers {@link SdkPluginTask}s + * implementation via {@link ServiceLoader} mechanism. Plugin tasks implementations must use + * {@code @AutoService(SdkPluginTask.class)} or manually add their fully qualifies name to the + * corresponding file. + * + * @see ServiceLoader + */ +@AutoService(PluginTaskRegistrar.class) +public class SdkPluginTaskRegistrar extends PluginTaskRegistrar { + private static final Logger LOG = Logger.getLogger(SdkPluginTaskRegistrar.class.getName()); + + static { + // enable all levels for the actual handler to pick up + LOG.setLevel(Level.ALL); + } + + private static class PluginTaskImpl implements PluginTask { + private final SdkPluginTask sdkTask; + + private PluginTaskImpl(SdkPluginTask sdkTask) { + this.sdkTask = sdkTask; + } + + @Override + public String getType() { + return sdkTask.getType(); + } + + @Override + public Struct getCustom() { + return sdkTask.getCustom().struct(); + } + + @Override + public TypedInterface getInterface() { + return TypedInterface.builder() + .inputs(sdkTask.getInputType().getVariableMap()) + .outputs(sdkTask.getOutputType().getVariableMap()) + .build(); + } + + @Override + public RetryStrategy getRetries() { + return RetryStrategy.builder().retries(sdkTask.getRetries()).build(); + } + + @Override + public boolean isCached() { + return sdkTask.isCached(); + } + + @Override + public String getCacheVersion() { + return sdkTask.getCacheVersion(); + } + + @Override + public boolean isCacheSerializable() { + return sdkTask.isCacheSerializable(); + } + + @Override + public String getName() { + return sdkTask.getName(); + } + + @Override + public boolean isSyncPlugin() { + return sdkTask.isSyncPlugin(); + } + } + + /** + * Load {@link SdkPluginTask}s using {@link ServiceLoader}. + * + * @param env env vars in a map that would be used to pick up the project, domain and version for + * the discovered tasks. + * @param classLoader class loader to use when discovering the task using {@link + * ServiceLoader#load(Class, ClassLoader)} + * @return a map of {@link SdkPluginTask}s by its task identifier. + */ + @Override + @SuppressWarnings("rawtypes") + public Map load(Map env, ClassLoader classLoader) { + ServiceLoader loader = ServiceLoader.load(SdkPluginTask.class, classLoader); + + LOG.fine("Discovering SdkPluginTask"); + + Map tasks = new HashMap<>(); + SdkConfig sdkConfig = SdkConfig.load(env); + + for (SdkPluginTask sdkTask : loader) { + String name = sdkTask.getName(); + TaskIdentifier taskId = + TaskIdentifier.builder() + .domain(sdkConfig.domain()) + .project(sdkConfig.project()) + .name(name) + .version(sdkConfig.version()) + .build(); + LOG.fine(String.format("Discovered [%s]", name)); + + PluginTask task = new PluginTaskImpl<>(sdkTask); + PluginTask previous = tasks.put(taskId, task); + + if (previous != null) { + throw new IllegalArgumentException( + String.format("Discovered a duplicate task [%s] [%s] [%s]", name, task, previous)); + } + } + + return tasks; + } +} diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java new file mode 100644 index 000000000..87723eaad --- /dev/null +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java @@ -0,0 +1,171 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import static org.flyte.flytekit.SdkConfig.DOMAIN_ENV_VAR; +import static org.flyte.flytekit.SdkConfig.PROJECT_ENV_VAR; +import static org.flyte.flytekit.SdkConfig.VERSION_ENV_VAR; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasKey; +import static org.junit.jupiter.api.Assertions.assertAll; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.RetryStrategy; +import org.flyte.api.v1.TaskIdentifier; +import org.flyte.api.v1.TypedInterface; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class SdkPluginTaskRegistrarTest { + private static final String TASK_TYPE = "test-task"; + private static final Map ENV = + Map.of(PROJECT_ENV_VAR, "project", DOMAIN_ENV_VAR, "domain", VERSION_ENV_VAR, "version"); + + private SdkPluginTaskRegistrar registrar; + + @BeforeEach + void setUp() { + registrar = new SdkPluginTaskRegistrar(); + } + + @Test + void shouldLoadPluginTasksFromDiscoveredRegistries() { + // given + String testTaskName = "org.flyte.flytekit.SdkPluginTaskRegistrarTest$TestTask"; + String otherTestTaskName = "org.flyte.flytekit.SdkPluginTaskRegistrarTest$OtherTestTask"; + TaskIdentifier expectedTestTaskId = + TaskIdentifier.builder() + .project("project") + .domain("domain") + .name(testTaskName) + .version("version") + .build(); + + TypedInterface typedInterface = + TypedInterface.builder() + .inputs(SdkTypes.nulls().getVariableMap()) + .outputs(SdkTypes.nulls().getVariableMap()) + .build(); + + RetryStrategy retries = RetryStrategy.builder().retries(0).build(); + RetryStrategy otherRetries = RetryStrategy.builder().retries(1).build(); + + PluginTask expectedTask = createPluginTask(testTaskName, typedInterface, retries, false); + + TaskIdentifier expectedOtherTestTaskId = + TaskIdentifier.builder() + .project("project") + .domain("domain") + .name(otherTestTaskName) + .version("version") + .build(); + PluginTask expectedOtherTask = + createPluginTask(otherTestTaskName, typedInterface, otherRetries, true); + + // when + Map tasks = registrar.load(ENV); + + // then + assertAll( + () -> assertThat(tasks, hasKey(is(expectedTestTaskId))), + () -> assertThat(tasks, hasKey(is(expectedOtherTestTaskId)))); + assertTaskEquals(tasks.get(expectedTestTaskId), expectedTask); + assertTaskEquals(tasks.get(expectedOtherTestTaskId), expectedOtherTask); + } + + private PluginTask createPluginTask( + String taskName, TypedInterface typedInterface, RetryStrategy retries, boolean isSyncPlugin) { + return new PluginTask() { + @Override + public boolean isSyncPlugin() { + return isSyncPlugin; + } + + @Override + public String getName() { + return taskName; + } + + @Override + public String getType() { + return TASK_TYPE; + } + + @Override + public TypedInterface getInterface() { + return typedInterface; + } + + @Override + public RetryStrategy getRetries() { + return retries; + } + }; + } + + private void assertTaskEquals(PluginTask actualTask, PluginTask expectedTask) { + assertThat(actualTask.getName(), equalTo(expectedTask.getName())); + assertThat(actualTask.getType(), equalTo(expectedTask.getType())); + assertThat(actualTask.getCustom(), equalTo(expectedTask.getCustom())); + assertThat(actualTask.getInterface(), equalTo(expectedTask.getInterface())); + assertThat(actualTask.getRetries(), equalTo(expectedTask.getRetries())); + } + + @AutoService(SdkPluginTask.class) + public static class TestTask extends SdkPluginTask { + + private static final long serialVersionUID = 2751205856616541247L; + + public TestTask() { + super(SdkTypes.nulls(), SdkTypes.nulls()); + } + + @Override + public String getType() { + return TASK_TYPE; + } + } + + @AutoService(SdkPluginTask.class) + public static class OtherTestTask extends SdkPluginTask { + + private static final long serialVersionUID = -7757282344498000982L; + + public OtherTestTask() { + super(SdkTypes.nulls(), SdkTypes.nulls()); + } + + @Override + public String getType() { + return TASK_TYPE; + } + + @Override + public int getRetries() { + return 1; + } + + @Override + public boolean isSyncPlugin() { + return true; + } + } +} diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index fba0e4154..aec530da0 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -64,6 +64,8 @@ import org.flyte.api.v1.Node; import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.api.v1.PartialWorkflowIdentifier; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.PluginTaskRegistrar; import org.flyte.api.v1.Resources; import org.flyte.api.v1.Resources.ResourceName; import org.flyte.api.v1.RunnableTask; @@ -219,6 +221,10 @@ static ProjectClosure load( ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(ContainerTaskRegistrar.class, env)); + Map pluginTasks = + ClassLoaders.withClassLoader( + packageClassLoader, () -> Registrars.loadAll(PluginTaskRegistrar.class, env)); + Map workflows = ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(WorkflowTemplateRegistrar.class, env)); @@ -233,6 +239,7 @@ static ProjectClosure load( runnableTasks, dynamicWorkflowTasks, containerTasks, + pluginTasks, workflows, launchPlans); } @@ -243,10 +250,12 @@ static ProjectClosure load( Map runnableTasks, Map dynamicWorkflowTasks, Map containerTasks, + Map pluginTasks, Map workflowTemplates, Map launchPlans) { Map taskTemplates = - createTaskTemplates(config, runnableTasks, dynamicWorkflowTasks, containerTasks); + createTaskTemplates( + config, runnableTasks, dynamicWorkflowTasks, containerTasks, pluginTasks); // 2. rewrite workflows and launch plans Map rewrittenWorkflowTemplates = @@ -424,7 +433,8 @@ public static Map createTaskTemplates( ExecutionConfig config, Map runnableTasks, Map dynamicWorkflowTasks, - Map containerTasks) { + Map containerTasks, + Map pluginTasks) { Map taskTemplates = new HashMap<>(); runnableTasks.forEach( @@ -448,6 +458,13 @@ public static Map createTaskTemplates( taskTemplates.put(id, taskTemplate); }); + pluginTasks.forEach( + (id, task) -> { + TaskTemplate taskTemplate = createTaskTemplateForPluginTask(task); + + taskTemplates.put(id, taskTemplate); + }); + return taskTemplates; } @@ -473,7 +490,7 @@ static TaskTemplate createTaskTemplateForRunnableTask(RunnableTask task, String .resources(task.getResources()) .build(); - return createTaskTemplate(task, container); + return createTaskTemplateBuilder(task).container(container).build(); } @VisibleForTesting @@ -488,25 +505,30 @@ static TaskTemplate createTaskTemplateForContainerTask(ContainerTask task) { .resources(resources) .build(); - return createTaskTemplate(task, container); + return createTaskTemplateBuilder(task).container(container).build(); + } + + @VisibleForTesting + static TaskTemplate createTaskTemplateForPluginTask(PluginTask task) { + return createTaskTemplateBuilder(task).isSyncPlugin(task.isSyncPlugin()).build(); } - private static TaskTemplate createTaskTemplate(Task task, Container container) { + private static TaskTemplate.Builder createTaskTemplateBuilder(Task task) { TaskTemplate.Builder templateBuilder = TaskTemplate.builder() - .container(container) .interface_(task.getInterface()) .retries(task.getRetries()) .type(task.getType()) .custom(task.getCustom()) .discoverable(task.isCached()) - .cacheSerializable(task.isCacheSerializable()); + .cacheSerializable(task.isCacheSerializable()) + .isSyncPlugin(false); if (task.getCacheVersion() != null) { templateBuilder.discoveryVersion(task.getCacheVersion()); } - return templateBuilder.build(); + return templateBuilder; } private static Optional javaToolOptionsEnv(RunnableTask task) { @@ -559,6 +581,7 @@ private static TaskTemplate createTaskTemplateForDynamicWorkflow( // it or change this comment to explicitly say no cache for dynamic tasks .discoverable(false) .cacheSerializable(false) + .isSyncPlugin(false) .build(); } diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java index 3821db033..d7647ab70 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java @@ -45,6 +45,9 @@ import flyteidl.core.Interface; import flyteidl.core.Literals; import flyteidl.core.Tasks; +import flyteidl.core.Tasks.PluginMetadata; +import flyteidl.core.Tasks.RuntimeMetadata; +import flyteidl.core.Tasks.RuntimeMetadata.RuntimeType; import flyteidl.core.Tasks.TaskMetadata; import flyteidl.core.Types; import flyteidl.core.Types.SchemaType.SchemaColumn.SchemaColumnType; @@ -327,10 +330,12 @@ static Tasks.TaskTemplate serialize(TaskTemplate taskTemplate) { private static TaskMetadata serializeTaskMetadata(TaskTemplate taskTemplate) { Tasks.RuntimeMetadata runtime = - Tasks.RuntimeMetadata.newBuilder() - .setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK) + RuntimeMetadata.newBuilder() + .setType(RuntimeType.FLYTE_SDK) .setFlavor(RUNTIME_FLAVOR) .setVersion(RUNTIME_VERSION) + .setPluginMetadata( + PluginMetadata.newBuilder().setIsSyncPlugin(taskTemplate.isSyncPlugin()).build()) .build(); return TaskMetadata.newBuilder() @@ -354,6 +359,7 @@ static TaskTemplate deserialize(Tasks.TaskTemplate proto) { // Proto uses empty strings instead of null, we use null in TaskTemplate .discoveryVersion(emptyToNull(proto.getMetadata().getDiscoveryVersion())) .cacheSerializable(proto.getMetadata().getCacheSerializable()) + .isSyncPlugin(proto.getMetadata().getRuntime().getPluginMetadata().getIsSyncPlugin()) .build(); } diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java index 37c58a958..fcff19737 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java @@ -54,6 +54,7 @@ final class Fixtures { .retries(RETRIES) .discoverable(false) .cacheSerializable(false) + .isSyncPlugin(false) .build(); private Fixtures() { diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java index 93344b321..c5abfcfce 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java @@ -534,6 +534,10 @@ private TaskOuterClass.TaskSpec newTaskSpec() { .setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK) .setFlavor(ProtoUtil.RUNTIME_FLAVOR) .setVersion(ProtoUtil.RUNTIME_VERSION) + .setPluginMetadata( + Tasks.PluginMetadata.newBuilder() + .setIsSyncPlugin(false) + .build()) .build()) .setRetries(Literals.RetryStrategy.newBuilder().setRetries(4).build()) .build()) diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java index cd1f698e7..56d46bd7c 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java @@ -26,6 +26,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -53,6 +54,7 @@ import org.flyte.api.v1.Operand; import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.api.v1.PartialWorkflowIdentifier; +import org.flyte.api.v1.PluginTask; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Resources; import org.flyte.api.v1.RetryStrategy; @@ -582,6 +584,28 @@ public void testCreateTaskTemplateForContainerTask() { assertThat(result.type(), equalTo("raw-container")); } + @Test + public void testCreateTaskTemplateForPluginTask() { + // given + PluginTask task = createPluginTask(); + + // when + TaskTemplate result = ProjectClosure.createTaskTemplateForPluginTask(task); + + // then + assertThat( + result.interface_(), + equalTo( + TypedInterface.builder() + .inputs(SdkTypes.nulls().getVariableMap()) + .outputs(SdkTypes.nulls().getVariableMap()) + .build())); + assertThat(result.custom(), equalTo(Struct.of(emptyMap()))); + assertThat(result.retries(), equalTo(RetryStrategy.builder().retries(0).build())); + assertThat(result.type(), equalTo("test-plugin-task")); + assertThat(result.isSyncPlugin(), is(true)); + } + @Test public void testCreateTaskTemplateForTasksWithDefaultCacheSettings() { // given @@ -829,6 +853,38 @@ public List getEnv() { }; } + private PluginTask createPluginTask() { + return new PluginTask() { + @Override + public boolean isSyncPlugin() { + return true; + } + + @Override + public String getName() { + return "foo"; + } + + @Override + public String getType() { + return "test-plugin-task"; + } + + @Override + public TypedInterface getInterface() { + return TypedInterface.builder() + .inputs(SdkTypes.nulls().getVariableMap()) + .outputs(SdkTypes.nulls().getVariableMap()) + .build(); + } + + @Override + public RetryStrategy getRetries() { + return RetryStrategy.builder().retries(0).build(); + } + }; + } + private T wrapTaskWithRetries(Class taskClass, T task, boolean cacheEnabled) { return Reflection.newProxy( taskClass, diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java index c591ae5af..0349af703 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java @@ -89,6 +89,7 @@ void shouldReadTaskTemplate() throws IOException { .custom(Struct.of(emptyMap())) .discoverable(false) .cacheSerializable(false) + .isSyncPlugin(false) .build())); } diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java index 9993b7e8b..5d98bb26b 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java @@ -45,6 +45,7 @@ import flyteidl.core.Interface; import flyteidl.core.Literals; import flyteidl.core.Tasks; +import flyteidl.core.Tasks.PluginMetadata; import flyteidl.core.Tasks.TaskMetadata; import flyteidl.core.Types; import flyteidl.core.Types.SchemaType.SchemaColumn.SchemaColumnType; @@ -439,6 +440,7 @@ void shouldSerDeTaskTemplate() { .discoverable(true) .discoveryVersion("0.0.1") .cacheSerializable(true) + .isSyncPlugin(false) .build(); Tasks.TaskTemplate templateProto = @@ -458,6 +460,8 @@ void shouldSerDeTaskTemplate() { .setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK) .setFlavor(ProtoUtil.RUNTIME_FLAVOR) .setVersion(ProtoUtil.RUNTIME_VERSION) + .setPluginMetadata( + PluginMetadata.newBuilder().setIsSyncPlugin(false).build()) .build()) .setRetries(Literals.RetryStrategy.newBuilder().setRetries(4).build()) .setDiscoverable(true) @@ -522,6 +526,7 @@ void shouldSerializeTaskTemplateHandlingNullStringsAsEmptyString() { .discoverable(false) .cacheSerializable(false) .discoveryVersion(null) + .isSyncPlugin(false) .build(); Tasks.TaskTemplate protoTemplate = ProtoUtil.serialize(apiTemplate); @@ -548,6 +553,7 @@ void shouldSerializeTaskTemplatePropagatingNonNullStringAsIs() { .discoverable(true) .cacheSerializable(true) .discoveryVersion("1") + .isSyncPlugin(false) .build(); Tasks.TaskTemplate protoTemplate = ProtoUtil.serialize(apiTemplate); diff --git a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java index c011611dc..b46c44895 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java @@ -44,6 +44,8 @@ import org.flyte.api.v1.Literal; import org.flyte.api.v1.NamedEntityIdentifier; import org.flyte.api.v1.Node; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.PluginTaskRegistrar; import org.flyte.api.v1.RunnableTask; import org.flyte.api.v1.RunnableTaskRegistrar; import org.flyte.api.v1.Struct; @@ -150,6 +152,10 @@ private void execute() { ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(ContainerTaskRegistrar.class, env)); + Map pluginTasks = + ClassLoaders.withClassLoader( + packageClassLoader, () -> Registrars.loadAll(PluginTaskRegistrar.class, env)); + // before we run anything, switch class loader, otherwise, // ServiceLoaders and other things wouldn't work, for instance, // FileSystemRegister in Apache Beam @@ -162,7 +168,11 @@ private void execute() { Map taskTemplates = mapValues( ProjectClosure.createTaskTemplates( - executionConfig, runnableTasks, dynamicWorkflowTasks, containerTasks), + executionConfig, + runnableTasks, + dynamicWorkflowTasks, + containerTasks, + pluginTasks), template -> template.toBuilder() .custom(ProjectClosure.merge(template.custom(), custom)) From 2f0c8acae61f15b340d52e71897dec0a58df33ed Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Tue, 12 Dec 2023 22:03:45 +0100 Subject: [PATCH 2/5] Drop isSyncPlugin for now Signed-off-by: Hongxin Liang --- .../src/main/proto/flyteidl/core/tasks.proto | 8 -------- .../main/java/org/flyte/api/v1/PluginTask.java | 4 +--- .../main/java/org/flyte/api/v1/TaskTemplate.java | 5 ----- .../java/org/flyte/flytekit/SdkPluginTask.java | 8 -------- .../flyte/flytekit/SdkPluginTaskRegistrar.java | 5 ----- .../flytekit/SdkPluginTaskRegistrarTest.java | 16 +++------------- .../org/flyte/jflyte/utils/ProjectClosure.java | 6 ++---- .../java/org/flyte/jflyte/utils/ProtoUtil.java | 4 ---- .../java/org/flyte/jflyte/utils/Fixtures.java | 1 - .../flyte/jflyte/utils/FlyteAdminClientTest.java | 4 ---- .../flyte/jflyte/utils/ProjectClosureTest.java | 7 ------- .../org/flyte/jflyte/utils/ProtoReaderTest.java | 1 - .../org/flyte/jflyte/utils/ProtoUtilTest.java | 6 ------ 13 files changed, 6 insertions(+), 69 deletions(-) diff --git a/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto b/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto index 5844ac5c3..48961029e 100644 --- a/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto +++ b/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto @@ -59,14 +59,6 @@ message RuntimeMetadata { //+optional It can be used to provide extra information about the runtime (e.g. python, golang... etc.). string flavor = 3; - - //+optional It can be used to provide extra information for the plugin. - PluginMetadata plugin_metadata = 4; -} - -message PluginMetadata { - //+optional It can be used to decide use sync plugin or async plugin during runtime. - bool is_sync_plugin = 1; } // Task Metadata diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java index fc1a7d8dc..af22b3ff4 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java @@ -17,6 +17,4 @@ package org.flyte.api.v1; /** A task that is handled by a Flyte backend plugin instead of run as a container. */ -public interface PluginTask extends Task { - boolean isSyncPlugin(); -} +public interface PluginTask extends Task {} diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java b/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java index 7452ba6fc..43a7d60cc 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java @@ -67,9 +67,6 @@ public abstract class TaskTemplate { */ public abstract boolean cacheSerializable(); - /** Indicates whether to use sync plugin or async plugin to handle this task. */ - public abstract boolean isSyncPlugin(); - public abstract Builder toBuilder(); public static Builder builder() { @@ -95,8 +92,6 @@ public abstract static class Builder { public abstract Builder cacheSerializable(boolean cacheSerializable); - public abstract Builder isSyncPlugin(boolean isSyncPlugin); - public abstract TaskTemplate build(); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java index eedf3ab33..ab31ddec6 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java @@ -104,12 +104,4 @@ SdkNode apply( return new SdkTaskNode<>( builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType); } - - /** - * Signaling whether this task is supposed to be handled by a synchronous backend plugin, - * defaulting to false. - */ - public boolean isSyncPlugin() { - return false; - } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java index 114aea66e..939ef7a0e 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java @@ -95,11 +95,6 @@ public boolean isCacheSerializable() { public String getName() { return sdkTask.getName(); } - - @Override - public boolean isSyncPlugin() { - return sdkTask.isSyncPlugin(); - } } /** diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java index 87723eaad..66a481650 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java @@ -68,7 +68,7 @@ void shouldLoadPluginTasksFromDiscoveredRegistries() { RetryStrategy retries = RetryStrategy.builder().retries(0).build(); RetryStrategy otherRetries = RetryStrategy.builder().retries(1).build(); - PluginTask expectedTask = createPluginTask(testTaskName, typedInterface, retries, false); + PluginTask expectedTask = createPluginTask(testTaskName, typedInterface, retries); TaskIdentifier expectedOtherTestTaskId = TaskIdentifier.builder() @@ -78,7 +78,7 @@ void shouldLoadPluginTasksFromDiscoveredRegistries() { .version("version") .build(); PluginTask expectedOtherTask = - createPluginTask(otherTestTaskName, typedInterface, otherRetries, true); + createPluginTask(otherTestTaskName, typedInterface, otherRetries); // when Map tasks = registrar.load(ENV); @@ -92,13 +92,8 @@ void shouldLoadPluginTasksFromDiscoveredRegistries() { } private PluginTask createPluginTask( - String taskName, TypedInterface typedInterface, RetryStrategy retries, boolean isSyncPlugin) { + String taskName, TypedInterface typedInterface, RetryStrategy retries) { return new PluginTask() { - @Override - public boolean isSyncPlugin() { - return isSyncPlugin; - } - @Override public String getName() { return taskName; @@ -162,10 +157,5 @@ public String getType() { public int getRetries() { return 1; } - - @Override - public boolean isSyncPlugin() { - return true; - } } } diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index aec530da0..275e3ab95 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -510,7 +510,7 @@ static TaskTemplate createTaskTemplateForContainerTask(ContainerTask task) { @VisibleForTesting static TaskTemplate createTaskTemplateForPluginTask(PluginTask task) { - return createTaskTemplateBuilder(task).isSyncPlugin(task.isSyncPlugin()).build(); + return createTaskTemplateBuilder(task).build(); } private static TaskTemplate.Builder createTaskTemplateBuilder(Task task) { @@ -521,8 +521,7 @@ private static TaskTemplate.Builder createTaskTemplateBuilder(Task task) { .type(task.getType()) .custom(task.getCustom()) .discoverable(task.isCached()) - .cacheSerializable(task.isCacheSerializable()) - .isSyncPlugin(false); + .cacheSerializable(task.isCacheSerializable()); if (task.getCacheVersion() != null) { templateBuilder.discoveryVersion(task.getCacheVersion()); @@ -581,7 +580,6 @@ private static TaskTemplate createTaskTemplateForDynamicWorkflow( // it or change this comment to explicitly say no cache for dynamic tasks .discoverable(false) .cacheSerializable(false) - .isSyncPlugin(false) .build(); } diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java index d7647ab70..c39a6f85d 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java @@ -45,7 +45,6 @@ import flyteidl.core.Interface; import flyteidl.core.Literals; import flyteidl.core.Tasks; -import flyteidl.core.Tasks.PluginMetadata; import flyteidl.core.Tasks.RuntimeMetadata; import flyteidl.core.Tasks.RuntimeMetadata.RuntimeType; import flyteidl.core.Tasks.TaskMetadata; @@ -334,8 +333,6 @@ private static TaskMetadata serializeTaskMetadata(TaskTemplate taskTemplate) { .setType(RuntimeType.FLYTE_SDK) .setFlavor(RUNTIME_FLAVOR) .setVersion(RUNTIME_VERSION) - .setPluginMetadata( - PluginMetadata.newBuilder().setIsSyncPlugin(taskTemplate.isSyncPlugin()).build()) .build(); return TaskMetadata.newBuilder() @@ -359,7 +356,6 @@ static TaskTemplate deserialize(Tasks.TaskTemplate proto) { // Proto uses empty strings instead of null, we use null in TaskTemplate .discoveryVersion(emptyToNull(proto.getMetadata().getDiscoveryVersion())) .cacheSerializable(proto.getMetadata().getCacheSerializable()) - .isSyncPlugin(proto.getMetadata().getRuntime().getPluginMetadata().getIsSyncPlugin()) .build(); } diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java index fcff19737..37c58a958 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java @@ -54,7 +54,6 @@ final class Fixtures { .retries(RETRIES) .discoverable(false) .cacheSerializable(false) - .isSyncPlugin(false) .build(); private Fixtures() { diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java index c5abfcfce..93344b321 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java @@ -534,10 +534,6 @@ private TaskOuterClass.TaskSpec newTaskSpec() { .setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK) .setFlavor(ProtoUtil.RUNTIME_FLAVOR) .setVersion(ProtoUtil.RUNTIME_VERSION) - .setPluginMetadata( - Tasks.PluginMetadata.newBuilder() - .setIsSyncPlugin(false) - .build()) .build()) .setRetries(Literals.RetryStrategy.newBuilder().setRetries(4).build()) .build()) diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java index 56d46bd7c..d7b270da1 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java @@ -26,7 +26,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -603,7 +602,6 @@ public void testCreateTaskTemplateForPluginTask() { assertThat(result.custom(), equalTo(Struct.of(emptyMap()))); assertThat(result.retries(), equalTo(RetryStrategy.builder().retries(0).build())); assertThat(result.type(), equalTo("test-plugin-task")); - assertThat(result.isSyncPlugin(), is(true)); } @Test @@ -855,11 +853,6 @@ public List getEnv() { private PluginTask createPluginTask() { return new PluginTask() { - @Override - public boolean isSyncPlugin() { - return true; - } - @Override public String getName() { return "foo"; diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java index 0349af703..c591ae5af 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java @@ -89,7 +89,6 @@ void shouldReadTaskTemplate() throws IOException { .custom(Struct.of(emptyMap())) .discoverable(false) .cacheSerializable(false) - .isSyncPlugin(false) .build())); } diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java index 5d98bb26b..9993b7e8b 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java @@ -45,7 +45,6 @@ import flyteidl.core.Interface; import flyteidl.core.Literals; import flyteidl.core.Tasks; -import flyteidl.core.Tasks.PluginMetadata; import flyteidl.core.Tasks.TaskMetadata; import flyteidl.core.Types; import flyteidl.core.Types.SchemaType.SchemaColumn.SchemaColumnType; @@ -440,7 +439,6 @@ void shouldSerDeTaskTemplate() { .discoverable(true) .discoveryVersion("0.0.1") .cacheSerializable(true) - .isSyncPlugin(false) .build(); Tasks.TaskTemplate templateProto = @@ -460,8 +458,6 @@ void shouldSerDeTaskTemplate() { .setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK) .setFlavor(ProtoUtil.RUNTIME_FLAVOR) .setVersion(ProtoUtil.RUNTIME_VERSION) - .setPluginMetadata( - PluginMetadata.newBuilder().setIsSyncPlugin(false).build()) .build()) .setRetries(Literals.RetryStrategy.newBuilder().setRetries(4).build()) .setDiscoverable(true) @@ -526,7 +522,6 @@ void shouldSerializeTaskTemplateHandlingNullStringsAsEmptyString() { .discoverable(false) .cacheSerializable(false) .discoveryVersion(null) - .isSyncPlugin(false) .build(); Tasks.TaskTemplate protoTemplate = ProtoUtil.serialize(apiTemplate); @@ -553,7 +548,6 @@ void shouldSerializeTaskTemplatePropagatingNonNullStringAsIs() { .discoverable(true) .cacheSerializable(true) .discoveryVersion("1") - .isSyncPlugin(false) .build(); Tasks.TaskTemplate protoTemplate = ProtoUtil.serialize(apiTemplate); From 70ad715229866614c226e7d98dad2bea3125f647 Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Tue, 12 Dec 2023 23:09:47 +0100 Subject: [PATCH 3/5] Skip staging and merging custom where applicable Signed-off-by: Hongxin Liang --- .../org/flyte/examples/NoopPluginTask.java | 34 +++++++++++++++++++ .../flyte/jflyte/utils/ProjectClosure.java | 25 ++++++++------ 2 files changed, 49 insertions(+), 10 deletions(-) create mode 100644 flytekit-examples/src/main/java/org/flyte/examples/NoopPluginTask.java diff --git a/flytekit-examples/src/main/java/org/flyte/examples/NoopPluginTask.java b/flytekit-examples/src/main/java/org/flyte/examples/NoopPluginTask.java new file mode 100644 index 000000000..2be34420b --- /dev/null +++ b/flytekit-examples/src/main/java/org/flyte/examples/NoopPluginTask.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.examples; + +import com.google.auto.service.AutoService; +import org.flyte.flytekit.SdkPluginTask; +import org.flyte.flytekit.SdkTypes; + +@AutoService(SdkPluginTask.class) +public class NoopPluginTask extends SdkPluginTask { + + public NoopPluginTask() { + super(SdkTypes.nulls(), SdkTypes.nulls()); + } + + @Override + public String getType() { + return "noop"; + } +} diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index 275e3ab95..21ad821c7 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -137,6 +137,10 @@ public void serialize(BiConsumer output) { } private static TaskSpec applyCustom(TaskSpec taskSpec, JFlyteCustom custom) { + if (taskSpec.taskTemplate().container() == null) { + return taskSpec; + } + Struct rewrittenCustom = merge(custom.serializeToStruct(), taskSpec.taskTemplate().custom()); TaskTemplate rewrittenTaskTemplate = taskSpec.taskTemplate().toBuilder().custom(rewrittenCustom).build(); @@ -163,25 +167,26 @@ public static ProjectClosure loadAndStage( ProjectClosure closure = ProjectClosure.load(config, rewrite, packageClassLoader); - List artifacts; if (isStagingRequired(closure)) { - artifacts = stagePackageFiles(stagerSupplier.get(), packageDir); - } else { - artifacts = emptyList(); - LOG.info( - "Skipping artifact staging because there are no runnable tasks or dynamic workflow tasks"); + List artifacts = stagePackageFiles(stagerSupplier.get(), packageDir); + JFlyteCustom custom = JFlyteCustom.builder().artifacts(artifacts).build(); + return closure.applyCustom(custom); } - JFlyteCustom custom = JFlyteCustom.builder().artifacts(artifacts).build(); + LOG.info( + "Skipping artifact staging because there are no runnable tasks or dynamic workflow tasks"); - return closure.applyCustom(custom); + return closure; } private static boolean isStagingRequired(ProjectClosure closure) { return closure.taskSpecs().values().stream() .map(TaskSpec::taskTemplate) - .map(TaskTemplate::type) - .anyMatch(type -> !type.equals("raw-container")); + .anyMatch(ProjectClosure::isRunnableOrDynamicWorkflowTask); + } + + private static boolean isRunnableOrDynamicWorkflowTask(TaskTemplate taskTemplate) { + return taskTemplate.container() != null && !taskTemplate.type().equals("raw-container"); } private static List stagePackageFiles(ArtifactStager stager, String packageDir) { From bd0d7054b479327f63e49096065db7be32392ffb Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Tue, 12 Dec 2023 23:45:06 +0100 Subject: [PATCH 4/5] Skip container Signed-off-by: Hongxin Liang --- .../org/flyte/jflyte/utils/ProtoUtil.java | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java index c39a6f85d..1d542b0b9 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java @@ -19,7 +19,6 @@ import static com.google.common.base.Strings.emptyToNull; import static com.google.common.base.Strings.nullToEmpty; import static java.time.format.DateTimeFormatter.ISO_DATE_TIME; -import static java.util.Objects.requireNonNull; import static org.flyte.jflyte.utils.MoreCollectors.mapValues; import static org.flyte.jflyte.utils.MoreCollectors.toUnmodifiableList; import static org.flyte.jflyte.utils.MoreCollectors.toUnmodifiableMap; @@ -312,19 +311,18 @@ static TaskOuterClass.TaskSpec serialize(TaskSpec spec) { } static Tasks.TaskTemplate serialize(TaskTemplate taskTemplate) { - Container container = - requireNonNull( - taskTemplate.container(), "Only container based task templates are supported"); - TaskMetadata metadata = serializeTaskMetadata(taskTemplate); - return Tasks.TaskTemplate.newBuilder() - .setContainer(serialize(container)) - .setMetadata(metadata) - .setInterface(serialize(taskTemplate.interface_())) - .setType(taskTemplate.type()) - .setCustom(serializeStruct(taskTemplate.custom())) - .build(); + Tasks.TaskTemplate.Builder builder = + Tasks.TaskTemplate.newBuilder() + .setMetadata(metadata) + .setInterface(serialize(taskTemplate.interface_())) + .setType(taskTemplate.type()) + .setCustom(serializeStruct(taskTemplate.custom())); + if (taskTemplate.container() == null) { + return builder.build(); + } + return builder.setContainer(serialize(taskTemplate.container())).build(); } private static TaskMetadata serializeTaskMetadata(TaskTemplate taskTemplate) { From 80ba88517b63435914907b9627fe049d4ca0c412 Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Wed, 13 Dec 2023 09:10:04 +0100 Subject: [PATCH 5/5] More description of SdkPluginTask Signed-off-by: Hongxin Liang --- .../src/main/java/org/flyte/flytekit/SdkPluginTask.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java index ab31ddec6..88db8ff44 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java @@ -21,7 +21,12 @@ import javax.annotation.Nullable; import org.flyte.api.v1.PartialTaskIdentifier; -/** A task that is handled by a Flyte backend plugin instead of run as a container. */ +/** + * A task that is handled by a Flyte backend plugin instead of run as a container. Note that a + * plugin task template does not have a container defined, neither all the jars captured in + * classpath, so if this is a requirement, one should use SdkRunnableTask overriding run method to + * simply return null. + */ public abstract class SdkPluginTask extends SdkTransform { private final SdkType inputType;