diff --git a/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto b/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto index 48961029e..5844ac5c3 100644 --- a/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto +++ b/flyteidl-protos/src/main/proto/flyteidl/core/tasks.proto @@ -59,6 +59,14 @@ message RuntimeMetadata { //+optional It can be used to provide extra information about the runtime (e.g. python, golang... etc.). string flavor = 3; + + //+optional It can be used to provide extra information for the plugin. + PluginMetadata plugin_metadata = 4; +} + +message PluginMetadata { + //+optional It can be used to decide use sync plugin or async plugin during runtime. + bool is_sync_plugin = 1; } // Task Metadata diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java new file mode 100644 index 000000000..fc1a7d8dc --- /dev/null +++ b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTask.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.api.v1; + +/** A task that is handled by a Flyte backend plugin instead of run as a container. */ +public interface PluginTask extends Task { + boolean isSyncPlugin(); +} diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java new file mode 100644 index 000000000..d2562e4fa --- /dev/null +++ b/flytekit-api/src/main/java/org/flyte/api/v1/PluginTaskRegistrar.java @@ -0,0 +1,20 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.api.v1; + +/** A registrar that creates {@link PluginTask} instances. */ +public abstract class PluginTaskRegistrar implements Registrar {} diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java b/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java index 5a48c8900..7452ba6fc 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/TaskTemplate.java @@ -22,6 +22,9 @@ /** * A Task structure that uniquely identifies a task in the system. Tasks are registered as a first * step in the system. + * + *

FIXME: consider offering TaskMetadata instead of having everything in TaskTemplate, see + * https://github.com/flyteorg/flyte/blob/ea72bbd12578d64087221592554fb71c368f8057/flyteidl/protos/flyteidl/core/tasks.proto#L90 */ @AutoValue public abstract class TaskTemplate { @@ -64,6 +67,9 @@ public abstract class TaskTemplate { */ public abstract boolean cacheSerializable(); + /** Indicates whether to use sync plugin or async plugin to handle this task. */ + public abstract boolean isSyncPlugin(); + public abstract Builder toBuilder(); public static Builder builder() { @@ -89,6 +95,8 @@ public abstract static class Builder { public abstract Builder cacheSerializable(boolean cacheSerializable); + public abstract Builder isSyncPlugin(boolean isSyncPlugin); + public abstract TaskTemplate build(); } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java new file mode 100644 index 000000000..eedf3ab33 --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTask.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.flyte.api.v1.PartialTaskIdentifier; + +/** A task that is handled by a Flyte backend plugin instead of run as a container. */ +public abstract class SdkPluginTask extends SdkTransform { + + private final SdkType inputType; + private final SdkType outputType; + + /** + * Called by subclasses passing the {@link SdkType}s for inputs and outputs. + * + * @param inputType type for inputs. + * @param outputType type for outputs. + */ + public SdkPluginTask(SdkType inputType, SdkType outputType) { + this.inputType = inputType; + this.outputType = outputType; + } + + public abstract String getType(); + + @Override + public SdkType getInputType() { + return inputType; + } + + @Override + public SdkType getOutputType() { + return outputType; + } + + /** Specifies custom data that can be read by the backend plugin. */ + public SdkStruct getCustom() { + return SdkStruct.empty(); + } + + /** + * Number of retries. Retries will be consumed when the task fails with a recoverable error. The + * number of retries must be less than or equals to 10. + * + * @return number of retries + */ + public int getRetries() { + return 0; + } + + /** + * Indicates whether the system should attempt to look up this task's output to avoid duplication + * of work. + */ + public boolean isCached() { + return false; + } + + /** Indicates a logical version to apply to this task for the purpose of cache. */ + public String getCacheVersion() { + return null; + } + + /** + * Indicates whether the system should attempt to execute cached instances in serial to avoid + * duplicate work. + */ + public boolean isCacheSerializable() { + return false; + } + + @Override + SdkNode apply( + SdkWorkflowBuilder builder, + String nodeId, + List upstreamNodeIds, + @Nullable SdkNodeMetadata metadata, + Map> inputs) { + PartialTaskIdentifier taskId = PartialTaskIdentifier.builder().name(getName()).build(); + List errors = + Compiler.validateApply(nodeId, inputs, getInputType().getVariableMap()); + + if (!errors.isEmpty()) { + throw new CompilerException(errors); + } + + return new SdkTaskNode<>( + builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, outputType); + } + + /** + * Signaling whether this task is supposed to be handled by a synchronous backend plugin, + * defaulting to false. + */ + public boolean isSyncPlugin() { + return false; + } +} diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java new file mode 100644 index 000000000..114aea66e --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkPluginTaskRegistrar.java @@ -0,0 +1,146 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import com.google.auto.service.AutoService; +import java.util.HashMap; +import java.util.Map; +import java.util.ServiceLoader; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.PluginTaskRegistrar; +import org.flyte.api.v1.RetryStrategy; +import org.flyte.api.v1.Struct; +import org.flyte.api.v1.TaskIdentifier; +import org.flyte.api.v1.TypedInterface; + +/** + * Default implementation of a {@link PluginTaskRegistrar} that discovers {@link SdkPluginTask}s + * implementation via {@link ServiceLoader} mechanism. Plugin tasks implementations must use + * {@code @AutoService(SdkPluginTask.class)} or manually add their fully qualifies name to the + * corresponding file. + * + * @see ServiceLoader + */ +@AutoService(PluginTaskRegistrar.class) +public class SdkPluginTaskRegistrar extends PluginTaskRegistrar { + private static final Logger LOG = Logger.getLogger(SdkPluginTaskRegistrar.class.getName()); + + static { + // enable all levels for the actual handler to pick up + LOG.setLevel(Level.ALL); + } + + private static class PluginTaskImpl implements PluginTask { + private final SdkPluginTask sdkTask; + + private PluginTaskImpl(SdkPluginTask sdkTask) { + this.sdkTask = sdkTask; + } + + @Override + public String getType() { + return sdkTask.getType(); + } + + @Override + public Struct getCustom() { + return sdkTask.getCustom().struct(); + } + + @Override + public TypedInterface getInterface() { + return TypedInterface.builder() + .inputs(sdkTask.getInputType().getVariableMap()) + .outputs(sdkTask.getOutputType().getVariableMap()) + .build(); + } + + @Override + public RetryStrategy getRetries() { + return RetryStrategy.builder().retries(sdkTask.getRetries()).build(); + } + + @Override + public boolean isCached() { + return sdkTask.isCached(); + } + + @Override + public String getCacheVersion() { + return sdkTask.getCacheVersion(); + } + + @Override + public boolean isCacheSerializable() { + return sdkTask.isCacheSerializable(); + } + + @Override + public String getName() { + return sdkTask.getName(); + } + + @Override + public boolean isSyncPlugin() { + return sdkTask.isSyncPlugin(); + } + } + + /** + * Load {@link SdkPluginTask}s using {@link ServiceLoader}. + * + * @param env env vars in a map that would be used to pick up the project, domain and version for + * the discovered tasks. + * @param classLoader class loader to use when discovering the task using {@link + * ServiceLoader#load(Class, ClassLoader)} + * @return a map of {@link SdkPluginTask}s by its task identifier. + */ + @Override + @SuppressWarnings("rawtypes") + public Map load(Map env, ClassLoader classLoader) { + ServiceLoader loader = ServiceLoader.load(SdkPluginTask.class, classLoader); + + LOG.fine("Discovering SdkPluginTask"); + + Map tasks = new HashMap<>(); + SdkConfig sdkConfig = SdkConfig.load(env); + + for (SdkPluginTask sdkTask : loader) { + String name = sdkTask.getName(); + TaskIdentifier taskId = + TaskIdentifier.builder() + .domain(sdkConfig.domain()) + .project(sdkConfig.project()) + .name(name) + .version(sdkConfig.version()) + .build(); + LOG.fine(String.format("Discovered [%s]", name)); + + PluginTask task = new PluginTaskImpl<>(sdkTask); + PluginTask previous = tasks.put(taskId, task); + + if (previous != null) { + throw new IllegalArgumentException( + String.format("Discovered a duplicate task [%s] [%s] [%s]", name, task, previous)); + } + } + + return tasks; + } +} diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java new file mode 100644 index 000000000..87723eaad --- /dev/null +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkPluginTaskRegistrarTest.java @@ -0,0 +1,171 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import static org.flyte.flytekit.SdkConfig.DOMAIN_ENV_VAR; +import static org.flyte.flytekit.SdkConfig.PROJECT_ENV_VAR; +import static org.flyte.flytekit.SdkConfig.VERSION_ENV_VAR; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasKey; +import static org.junit.jupiter.api.Assertions.assertAll; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.RetryStrategy; +import org.flyte.api.v1.TaskIdentifier; +import org.flyte.api.v1.TypedInterface; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class SdkPluginTaskRegistrarTest { + private static final String TASK_TYPE = "test-task"; + private static final Map ENV = + Map.of(PROJECT_ENV_VAR, "project", DOMAIN_ENV_VAR, "domain", VERSION_ENV_VAR, "version"); + + private SdkPluginTaskRegistrar registrar; + + @BeforeEach + void setUp() { + registrar = new SdkPluginTaskRegistrar(); + } + + @Test + void shouldLoadPluginTasksFromDiscoveredRegistries() { + // given + String testTaskName = "org.flyte.flytekit.SdkPluginTaskRegistrarTest$TestTask"; + String otherTestTaskName = "org.flyte.flytekit.SdkPluginTaskRegistrarTest$OtherTestTask"; + TaskIdentifier expectedTestTaskId = + TaskIdentifier.builder() + .project("project") + .domain("domain") + .name(testTaskName) + .version("version") + .build(); + + TypedInterface typedInterface = + TypedInterface.builder() + .inputs(SdkTypes.nulls().getVariableMap()) + .outputs(SdkTypes.nulls().getVariableMap()) + .build(); + + RetryStrategy retries = RetryStrategy.builder().retries(0).build(); + RetryStrategy otherRetries = RetryStrategy.builder().retries(1).build(); + + PluginTask expectedTask = createPluginTask(testTaskName, typedInterface, retries, false); + + TaskIdentifier expectedOtherTestTaskId = + TaskIdentifier.builder() + .project("project") + .domain("domain") + .name(otherTestTaskName) + .version("version") + .build(); + PluginTask expectedOtherTask = + createPluginTask(otherTestTaskName, typedInterface, otherRetries, true); + + // when + Map tasks = registrar.load(ENV); + + // then + assertAll( + () -> assertThat(tasks, hasKey(is(expectedTestTaskId))), + () -> assertThat(tasks, hasKey(is(expectedOtherTestTaskId)))); + assertTaskEquals(tasks.get(expectedTestTaskId), expectedTask); + assertTaskEquals(tasks.get(expectedOtherTestTaskId), expectedOtherTask); + } + + private PluginTask createPluginTask( + String taskName, TypedInterface typedInterface, RetryStrategy retries, boolean isSyncPlugin) { + return new PluginTask() { + @Override + public boolean isSyncPlugin() { + return isSyncPlugin; + } + + @Override + public String getName() { + return taskName; + } + + @Override + public String getType() { + return TASK_TYPE; + } + + @Override + public TypedInterface getInterface() { + return typedInterface; + } + + @Override + public RetryStrategy getRetries() { + return retries; + } + }; + } + + private void assertTaskEquals(PluginTask actualTask, PluginTask expectedTask) { + assertThat(actualTask.getName(), equalTo(expectedTask.getName())); + assertThat(actualTask.getType(), equalTo(expectedTask.getType())); + assertThat(actualTask.getCustom(), equalTo(expectedTask.getCustom())); + assertThat(actualTask.getInterface(), equalTo(expectedTask.getInterface())); + assertThat(actualTask.getRetries(), equalTo(expectedTask.getRetries())); + } + + @AutoService(SdkPluginTask.class) + public static class TestTask extends SdkPluginTask { + + private static final long serialVersionUID = 2751205856616541247L; + + public TestTask() { + super(SdkTypes.nulls(), SdkTypes.nulls()); + } + + @Override + public String getType() { + return TASK_TYPE; + } + } + + @AutoService(SdkPluginTask.class) + public static class OtherTestTask extends SdkPluginTask { + + private static final long serialVersionUID = -7757282344498000982L; + + public OtherTestTask() { + super(SdkTypes.nulls(), SdkTypes.nulls()); + } + + @Override + public String getType() { + return TASK_TYPE; + } + + @Override + public int getRetries() { + return 1; + } + + @Override + public boolean isSyncPlugin() { + return true; + } + } +} diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index fba0e4154..aec530da0 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -64,6 +64,8 @@ import org.flyte.api.v1.Node; import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.api.v1.PartialWorkflowIdentifier; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.PluginTaskRegistrar; import org.flyte.api.v1.Resources; import org.flyte.api.v1.Resources.ResourceName; import org.flyte.api.v1.RunnableTask; @@ -219,6 +221,10 @@ static ProjectClosure load( ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(ContainerTaskRegistrar.class, env)); + Map pluginTasks = + ClassLoaders.withClassLoader( + packageClassLoader, () -> Registrars.loadAll(PluginTaskRegistrar.class, env)); + Map workflows = ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(WorkflowTemplateRegistrar.class, env)); @@ -233,6 +239,7 @@ static ProjectClosure load( runnableTasks, dynamicWorkflowTasks, containerTasks, + pluginTasks, workflows, launchPlans); } @@ -243,10 +250,12 @@ static ProjectClosure load( Map runnableTasks, Map dynamicWorkflowTasks, Map containerTasks, + Map pluginTasks, Map workflowTemplates, Map launchPlans) { Map taskTemplates = - createTaskTemplates(config, runnableTasks, dynamicWorkflowTasks, containerTasks); + createTaskTemplates( + config, runnableTasks, dynamicWorkflowTasks, containerTasks, pluginTasks); // 2. rewrite workflows and launch plans Map rewrittenWorkflowTemplates = @@ -424,7 +433,8 @@ public static Map createTaskTemplates( ExecutionConfig config, Map runnableTasks, Map dynamicWorkflowTasks, - Map containerTasks) { + Map containerTasks, + Map pluginTasks) { Map taskTemplates = new HashMap<>(); runnableTasks.forEach( @@ -448,6 +458,13 @@ public static Map createTaskTemplates( taskTemplates.put(id, taskTemplate); }); + pluginTasks.forEach( + (id, task) -> { + TaskTemplate taskTemplate = createTaskTemplateForPluginTask(task); + + taskTemplates.put(id, taskTemplate); + }); + return taskTemplates; } @@ -473,7 +490,7 @@ static TaskTemplate createTaskTemplateForRunnableTask(RunnableTask task, String .resources(task.getResources()) .build(); - return createTaskTemplate(task, container); + return createTaskTemplateBuilder(task).container(container).build(); } @VisibleForTesting @@ -488,25 +505,30 @@ static TaskTemplate createTaskTemplateForContainerTask(ContainerTask task) { .resources(resources) .build(); - return createTaskTemplate(task, container); + return createTaskTemplateBuilder(task).container(container).build(); + } + + @VisibleForTesting + static TaskTemplate createTaskTemplateForPluginTask(PluginTask task) { + return createTaskTemplateBuilder(task).isSyncPlugin(task.isSyncPlugin()).build(); } - private static TaskTemplate createTaskTemplate(Task task, Container container) { + private static TaskTemplate.Builder createTaskTemplateBuilder(Task task) { TaskTemplate.Builder templateBuilder = TaskTemplate.builder() - .container(container) .interface_(task.getInterface()) .retries(task.getRetries()) .type(task.getType()) .custom(task.getCustom()) .discoverable(task.isCached()) - .cacheSerializable(task.isCacheSerializable()); + .cacheSerializable(task.isCacheSerializable()) + .isSyncPlugin(false); if (task.getCacheVersion() != null) { templateBuilder.discoveryVersion(task.getCacheVersion()); } - return templateBuilder.build(); + return templateBuilder; } private static Optional javaToolOptionsEnv(RunnableTask task) { @@ -559,6 +581,7 @@ private static TaskTemplate createTaskTemplateForDynamicWorkflow( // it or change this comment to explicitly say no cache for dynamic tasks .discoverable(false) .cacheSerializable(false) + .isSyncPlugin(false) .build(); } diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java index 3821db033..d7647ab70 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java @@ -45,6 +45,9 @@ import flyteidl.core.Interface; import flyteidl.core.Literals; import flyteidl.core.Tasks; +import flyteidl.core.Tasks.PluginMetadata; +import flyteidl.core.Tasks.RuntimeMetadata; +import flyteidl.core.Tasks.RuntimeMetadata.RuntimeType; import flyteidl.core.Tasks.TaskMetadata; import flyteidl.core.Types; import flyteidl.core.Types.SchemaType.SchemaColumn.SchemaColumnType; @@ -327,10 +330,12 @@ static Tasks.TaskTemplate serialize(TaskTemplate taskTemplate) { private static TaskMetadata serializeTaskMetadata(TaskTemplate taskTemplate) { Tasks.RuntimeMetadata runtime = - Tasks.RuntimeMetadata.newBuilder() - .setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK) + RuntimeMetadata.newBuilder() + .setType(RuntimeType.FLYTE_SDK) .setFlavor(RUNTIME_FLAVOR) .setVersion(RUNTIME_VERSION) + .setPluginMetadata( + PluginMetadata.newBuilder().setIsSyncPlugin(taskTemplate.isSyncPlugin()).build()) .build(); return TaskMetadata.newBuilder() @@ -354,6 +359,7 @@ static TaskTemplate deserialize(Tasks.TaskTemplate proto) { // Proto uses empty strings instead of null, we use null in TaskTemplate .discoveryVersion(emptyToNull(proto.getMetadata().getDiscoveryVersion())) .cacheSerializable(proto.getMetadata().getCacheSerializable()) + .isSyncPlugin(proto.getMetadata().getRuntime().getPluginMetadata().getIsSyncPlugin()) .build(); } diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java index 37c58a958..fcff19737 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/Fixtures.java @@ -54,6 +54,7 @@ final class Fixtures { .retries(RETRIES) .discoverable(false) .cacheSerializable(false) + .isSyncPlugin(false) .build(); private Fixtures() { diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java index 93344b321..c5abfcfce 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/FlyteAdminClientTest.java @@ -534,6 +534,10 @@ private TaskOuterClass.TaskSpec newTaskSpec() { .setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK) .setFlavor(ProtoUtil.RUNTIME_FLAVOR) .setVersion(ProtoUtil.RUNTIME_VERSION) + .setPluginMetadata( + Tasks.PluginMetadata.newBuilder() + .setIsSyncPlugin(false) + .build()) .build()) .setRetries(Literals.RetryStrategy.newBuilder().setRetries(4).build()) .build()) diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java index cd1f698e7..56d46bd7c 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProjectClosureTest.java @@ -26,6 +26,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -53,6 +54,7 @@ import org.flyte.api.v1.Operand; import org.flyte.api.v1.PartialTaskIdentifier; import org.flyte.api.v1.PartialWorkflowIdentifier; +import org.flyte.api.v1.PluginTask; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Resources; import org.flyte.api.v1.RetryStrategy; @@ -582,6 +584,28 @@ public void testCreateTaskTemplateForContainerTask() { assertThat(result.type(), equalTo("raw-container")); } + @Test + public void testCreateTaskTemplateForPluginTask() { + // given + PluginTask task = createPluginTask(); + + // when + TaskTemplate result = ProjectClosure.createTaskTemplateForPluginTask(task); + + // then + assertThat( + result.interface_(), + equalTo( + TypedInterface.builder() + .inputs(SdkTypes.nulls().getVariableMap()) + .outputs(SdkTypes.nulls().getVariableMap()) + .build())); + assertThat(result.custom(), equalTo(Struct.of(emptyMap()))); + assertThat(result.retries(), equalTo(RetryStrategy.builder().retries(0).build())); + assertThat(result.type(), equalTo("test-plugin-task")); + assertThat(result.isSyncPlugin(), is(true)); + } + @Test public void testCreateTaskTemplateForTasksWithDefaultCacheSettings() { // given @@ -829,6 +853,38 @@ public List getEnv() { }; } + private PluginTask createPluginTask() { + return new PluginTask() { + @Override + public boolean isSyncPlugin() { + return true; + } + + @Override + public String getName() { + return "foo"; + } + + @Override + public String getType() { + return "test-plugin-task"; + } + + @Override + public TypedInterface getInterface() { + return TypedInterface.builder() + .inputs(SdkTypes.nulls().getVariableMap()) + .outputs(SdkTypes.nulls().getVariableMap()) + .build(); + } + + @Override + public RetryStrategy getRetries() { + return RetryStrategy.builder().retries(0).build(); + } + }; + } + private T wrapTaskWithRetries(Class taskClass, T task, boolean cacheEnabled) { return Reflection.newProxy( taskClass, diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java index c591ae5af..0349af703 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoReaderTest.java @@ -89,6 +89,7 @@ void shouldReadTaskTemplate() throws IOException { .custom(Struct.of(emptyMap())) .discoverable(false) .cacheSerializable(false) + .isSyncPlugin(false) .build())); } diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java index 9993b7e8b..5d98bb26b 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java @@ -45,6 +45,7 @@ import flyteidl.core.Interface; import flyteidl.core.Literals; import flyteidl.core.Tasks; +import flyteidl.core.Tasks.PluginMetadata; import flyteidl.core.Tasks.TaskMetadata; import flyteidl.core.Types; import flyteidl.core.Types.SchemaType.SchemaColumn.SchemaColumnType; @@ -439,6 +440,7 @@ void shouldSerDeTaskTemplate() { .discoverable(true) .discoveryVersion("0.0.1") .cacheSerializable(true) + .isSyncPlugin(false) .build(); Tasks.TaskTemplate templateProto = @@ -458,6 +460,8 @@ void shouldSerDeTaskTemplate() { .setType(Tasks.RuntimeMetadata.RuntimeType.FLYTE_SDK) .setFlavor(ProtoUtil.RUNTIME_FLAVOR) .setVersion(ProtoUtil.RUNTIME_VERSION) + .setPluginMetadata( + PluginMetadata.newBuilder().setIsSyncPlugin(false).build()) .build()) .setRetries(Literals.RetryStrategy.newBuilder().setRetries(4).build()) .setDiscoverable(true) @@ -522,6 +526,7 @@ void shouldSerializeTaskTemplateHandlingNullStringsAsEmptyString() { .discoverable(false) .cacheSerializable(false) .discoveryVersion(null) + .isSyncPlugin(false) .build(); Tasks.TaskTemplate protoTemplate = ProtoUtil.serialize(apiTemplate); @@ -548,6 +553,7 @@ void shouldSerializeTaskTemplatePropagatingNonNullStringAsIs() { .discoverable(true) .cacheSerializable(true) .discoveryVersion("1") + .isSyncPlugin(false) .build(); Tasks.TaskTemplate protoTemplate = ProtoUtil.serialize(apiTemplate); diff --git a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java index c011611dc..b46c44895 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java @@ -44,6 +44,8 @@ import org.flyte.api.v1.Literal; import org.flyte.api.v1.NamedEntityIdentifier; import org.flyte.api.v1.Node; +import org.flyte.api.v1.PluginTask; +import org.flyte.api.v1.PluginTaskRegistrar; import org.flyte.api.v1.RunnableTask; import org.flyte.api.v1.RunnableTaskRegistrar; import org.flyte.api.v1.Struct; @@ -150,6 +152,10 @@ private void execute() { ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(ContainerTaskRegistrar.class, env)); + Map pluginTasks = + ClassLoaders.withClassLoader( + packageClassLoader, () -> Registrars.loadAll(PluginTaskRegistrar.class, env)); + // before we run anything, switch class loader, otherwise, // ServiceLoaders and other things wouldn't work, for instance, // FileSystemRegister in Apache Beam @@ -162,7 +168,11 @@ private void execute() { Map taskTemplates = mapValues( ProjectClosure.createTaskTemplates( - executionConfig, runnableTasks, dynamicWorkflowTasks, containerTasks), + executionConfig, + runnableTasks, + dynamicWorkflowTasks, + containerTasks, + pluginTasks), template -> template.toBuilder() .custom(ProjectClosure.merge(template.custom(), custom))