diff --git a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java index abc2bb121..f1d57ebeb 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java @@ -87,13 +87,19 @@ public Output run(SdkWorkflowBuilder builder, Input input) { SdkTypes.nulls(), SdkTypes.nulls()) .withUpstreamNode(hello)); + // subworkflow that contains another subworkflow + SdkNode greet = + builder.apply( + "greet", + new SubWorkflow().withUpstreamNode(world), + WelcomeWorkflow.Input.create(SdkBindingDataFactory.of("greet"))); @Var SdkBindingData prev = SdkBindingDataFactory.of(0); @Var SdkBindingData value = SdkBindingDataFactory.of(1); for (int i = 2; i <= input.n().get(); i++) { SdkBindingData next = builder .apply( - "fib-" + i, new SumTask().withUpstreamNode(world), SumInput.create(value, prev)) + "fib-" + i, new SumTask().withUpstreamNode(greet), SumInput.create(value, prev)) .getOutputs(); prev = value; value = next; diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index 7ef4b9de8..fba0e4154 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -343,8 +343,16 @@ && checkCycles(subWorkflowId, allWorkflows, beingVisited, visited))) { } @VisibleForTesting + static Map collectSubWorkflows( + List nodes, Map allWorkflows) { + return collectSubWorkflows(nodes, allWorkflows, Function.identity()); + } + public static Map collectSubWorkflows( - List rewrittenNodes, Map allWorkflows) { + List nodes, + Map allWorkflows, + Function, List> nodesRewriter) { + List rewrittenNodes = nodesRewriter.apply(nodes); return collectSubWorkflowIds(rewrittenNodes).stream() // all identifiers should be rewritten at this point .map( @@ -366,7 +374,7 @@ public static Map collectSubWorkflows( } Map nestedSubWorkflows = - collectSubWorkflows(subWorkflow.nodes(), allWorkflows); + collectSubWorkflows(subWorkflow.nodes(), allWorkflows, nodesRewriter); return Stream.concat( Stream.of(Maps.immutableEntry(workflowId, subWorkflow)), @@ -376,10 +384,10 @@ public static Map collectSubWorkflows( } public static Map collectDynamicWorkflowTasks( - List rewrittenNodes, + List nodes, Map allTasks, Function remoteTaskTemplateFetcher) { - return collectTaskIds(rewrittenNodes).stream() + return collectTaskIds(nodes).stream() // all identifiers should be rewritten at this point .map( taskId -> diff --git a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java index c5de2cc88..c011611dc 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java @@ -25,11 +25,13 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.ForkJoinPool; +import java.util.function.Function; import java.util.stream.Collectors; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BindingData; @@ -197,12 +199,12 @@ private void execute() { } } - static DynamicJobSpec rewrite( + private static DynamicJobSpec rewrite( Config config, ExecutionConfig executionConfig, DynamicJobSpec spec, - Map taskTemplates, - Map workflowTemplates) { + Map allTaskTemplates, + Map allWorkflowTemplates) { try (FlyteAdminClient flyteAdminClient = FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), null)) { @@ -215,39 +217,104 @@ static DynamicJobSpec rewrite( .adminClient(flyteAdminClient) .build() .visitor(); + Function, List> nodesRewriter = + nodes -> nodes.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList()); - List rewrittenNodes = - spec.nodes().stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList()); - - Map usedSubWorkflows = - ProjectClosure.collectSubWorkflows(rewrittenNodes, workflowTemplates); - - Map usedTaskTemplates = - ProjectClosure.collectDynamicWorkflowTasks( - rewrittenNodes, taskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id)); - - // FIXME one sub-workflow can use more sub-workflows, we should recursively collect used tasks - // and workflows + Map allUsedSubWorkflows = + collectAllUsedSubWorkflows( + spec.nodes(), allWorkflowTemplates, workflowNodeVisitor, nodesRewriter); - Map rewrittenUsedSubWorkflows = - mapValues(usedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate); + Map allUsedTaskTemplates = new HashMap<>(); + List rewrittenNodes = + collectAllUsedTaskTemplates( + spec, + allTaskTemplates, + nodesRewriter, + allUsedTaskTemplates, + flyteAdminClient, + allUsedSubWorkflows); return spec.toBuilder() .nodes(rewrittenNodes) .subWorkflows( ImmutableMap.builder() .putAll(spec.subWorkflows()) - .putAll(rewrittenUsedSubWorkflows) + .putAll(allUsedSubWorkflows) .build()) .tasks( ImmutableMap.builder() .putAll(spec.tasks()) - .putAll(usedTaskTemplates) + .putAll(allUsedTaskTemplates) .build()) .build(); } } + private static List collectAllUsedTaskTemplates( + DynamicJobSpec spec, + Map allTaskTemplates, + Function, List> nodesRewriter, + Map allUsedTaskTemplates, + FlyteAdminClient flyteAdminClient, + Map allUsedSubWorkflows) { + + Map cache = new HashMap<>(); + + // collect directly used task templates + List rewrittenNodes = + collectTaskTemplates( + spec.nodes(), + nodesRewriter, + allUsedTaskTemplates, + allTaskTemplates, + flyteAdminClient, + cache); + + // collect task templates used by subworkflows + allUsedSubWorkflows + .values() + .forEach( + workflowTemplate -> + collectTaskTemplates( + workflowTemplate.nodes(), + nodesRewriter, + allUsedTaskTemplates, + allTaskTemplates, + flyteAdminClient, + cache)); + + return rewrittenNodes; + } + + private static Map collectAllUsedSubWorkflows( + List nodes, + Map workflowTemplates, + WorkflowNodeVisitor workflowNodeVisitor, + Function, List> nodesRewriter) { + + Map allUsedSubWorkflows = + ProjectClosure.collectSubWorkflows(nodes, workflowTemplates, nodesRewriter); + return mapValues(allUsedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate); + } + + private static List collectTaskTemplates( + List nodes, + Function, List> nodesRewriter, + Map allUsedTaskTemplates, + Map allTaskTemplates, + FlyteAdminClient flyteAdminClient, + Map cache) { + + List rewrittenNodes = nodesRewriter.apply(nodes); + + Map usedTaskTemplates = + ProjectClosure.collectDynamicWorkflowTasks( + rewrittenNodes, allTaskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id, cache)); + allUsedTaskTemplates.putAll(usedTaskTemplates); + + return rewrittenNodes; + } + // note that there are cases we are making an unnecessary network call because we might have // already got the task template when resolving the latest task version, but since it is also // possible that user has provided a version for a remote task, and in that case we would not need @@ -255,18 +322,21 @@ static DynamicJobSpec rewrite( // we accept the additional cost because it should be rare to have remote tasks in a dynamic // workflow private static TaskTemplate fetchTaskTemplate( - FlyteAdminClient flyteAdminClient, TaskIdentifier id) { - LOG.info("fetching task template remotely for {}", id); - - TaskTemplate taskTemplate = - flyteAdminClient.fetchLatestTaskTemplate( - NamedEntityIdentifier.builder() - .domain(id.domain()) - .project(id.project()) - .name(id.name()) - .build()); - - return taskTemplate; + FlyteAdminClient flyteAdminClient, + TaskIdentifier id, + Map cache) { + return cache.computeIfAbsent( + id, + taskIdentifier -> { + LOG.info("fetching task template remotely for {}", id); + + return flyteAdminClient.fetchLatestTaskTemplate( + NamedEntityIdentifier.builder() + .domain(id.domain()) + .project(id.project()) + .name(id.name()) + .build()); + }); } private static DynamicWorkflowTask getDynamicWorkflowTask(String name) {