From 2d1d14e4bebf7d48f730e8bec473993f5d0fa516 Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Mon, 9 Oct 2023 14:13:54 +0200 Subject: [PATCH 1/3] Recursively handle subworkflow in dynamic Signed-off-by: Hongxin Liang --- .../flyte/jflyte/ExecuteDynamicWorkflow.java | 107 +++++++++++++----- 1 file changed, 76 insertions(+), 31 deletions(-) diff --git a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java index c5de2cc88..5b72b18f5 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; @@ -197,7 +198,7 @@ private void execute() { } } - static DynamicJobSpec rewrite( + private static DynamicJobSpec rewrite( Config config, ExecutionConfig executionConfig, DynamicJobSpec spec, @@ -216,38 +217,79 @@ static DynamicJobSpec rewrite( .build() .visitor(); - List rewrittenNodes = - spec.nodes().stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList()); - - Map usedSubWorkflows = - ProjectClosure.collectSubWorkflows(rewrittenNodes, workflowTemplates); - - Map usedTaskTemplates = - ProjectClosure.collectDynamicWorkflowTasks( - rewrittenNodes, taskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id)); - - // FIXME one sub-workflow can use more sub-workflows, we should recursively collect used tasks - // and workflows - - Map rewrittenUsedSubWorkflows = - mapValues(usedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate); + Map allUsedTaskTemplates = new HashMap<>(); + Map allUsedSubWorkflows = new HashMap<>(); + Map cache = new HashMap<>(); + + List nodes = + recursivelyCollect( + spec.nodes(), + allUsedTaskTemplates, + allUsedSubWorkflows, + taskTemplates, + workflowTemplates, + workflowNodeVisitor, + flyteAdminClient, + cache); return spec.toBuilder() - .nodes(rewrittenNodes) + .nodes(nodes) .subWorkflows( ImmutableMap.builder() .putAll(spec.subWorkflows()) - .putAll(rewrittenUsedSubWorkflows) + .putAll(allUsedSubWorkflows) .build()) .tasks( ImmutableMap.builder() .putAll(spec.tasks()) - .putAll(usedTaskTemplates) + .putAll(allUsedTaskTemplates) .build()) .build(); } } + private static List recursivelyCollect( + List startPoint, + Map allUsedTaskTemplates, + Map allUsedSubWorkflows, + Map allTaskTemplates, + Map allWorkflowTemplates, + WorkflowNodeVisitor workflowNodeVisitor, + FlyteAdminClient flyteAdminClient, + Map cache) { + + List rewrittenNodes = + startPoint.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList()); + + Map usedTaskTemplates = + ProjectClosure.collectDynamicWorkflowTasks( + rewrittenNodes, allTaskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id, cache)); + allUsedTaskTemplates.putAll(usedTaskTemplates); + + Map usedSubWorkflows = + ProjectClosure.collectSubWorkflows(rewrittenNodes, allWorkflowTemplates); + Map rewrittenUsedSubWorkflows = + mapValues(usedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate); + + rewrittenUsedSubWorkflows.forEach( + (key, value) -> { + if (!allUsedSubWorkflows.containsKey(key)) { + allUsedSubWorkflows.put(key, value); + recursivelyCollect( + value.nodes(), + allUsedTaskTemplates, + allUsedSubWorkflows, + allTaskTemplates, + allWorkflowTemplates, + workflowNodeVisitor, + flyteAdminClient, + cache); + } + }); + + return rewrittenNodes; + } + // note that there are cases we are making an unnecessary network call because we might have // already got the task template when resolving the latest task version, but since it is also // possible that user has provided a version for a remote task, and in that case we would not need @@ -255,18 +297,21 @@ static DynamicJobSpec rewrite( // we accept the additional cost because it should be rare to have remote tasks in a dynamic // workflow private static TaskTemplate fetchTaskTemplate( - FlyteAdminClient flyteAdminClient, TaskIdentifier id) { - LOG.info("fetching task template remotely for {}", id); - - TaskTemplate taskTemplate = - flyteAdminClient.fetchLatestTaskTemplate( - NamedEntityIdentifier.builder() - .domain(id.domain()) - .project(id.project()) - .name(id.name()) - .build()); - - return taskTemplate; + FlyteAdminClient flyteAdminClient, + TaskIdentifier id, + Map cache) { + return cache.computeIfAbsent( + id, + taskIdentifier -> { + LOG.info("fetching task template remotely for {}", id); + + return flyteAdminClient.fetchLatestTaskTemplate( + NamedEntityIdentifier.builder() + .domain(id.domain()) + .project(id.project()) + .name(id.name()) + .build()); + }); } private static DynamicWorkflowTask getDynamicWorkflowTask(String name) { From 27be6d49311843980ded9224b6792b1565f9d26a Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Mon, 9 Oct 2023 14:40:26 +0200 Subject: [PATCH 2/3] Sub in Sub IT Signed-off-by: Hongxin Liang --- .../org/flyte/examples/DynamicFibonacciWorkflowTask.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java index abc2bb121..f1d57ebeb 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java @@ -87,13 +87,19 @@ public Output run(SdkWorkflowBuilder builder, Input input) { SdkTypes.nulls(), SdkTypes.nulls()) .withUpstreamNode(hello)); + // subworkflow that contains another subworkflow + SdkNode greet = + builder.apply( + "greet", + new SubWorkflow().withUpstreamNode(world), + WelcomeWorkflow.Input.create(SdkBindingDataFactory.of("greet"))); @Var SdkBindingData prev = SdkBindingDataFactory.of(0); @Var SdkBindingData value = SdkBindingDataFactory.of(1); for (int i = 2; i <= input.n().get(); i++) { SdkBindingData next = builder .apply( - "fib-" + i, new SumTask().withUpstreamNode(world), SumInput.create(value, prev)) + "fib-" + i, new SumTask().withUpstreamNode(greet), SumInput.create(value, prev)) .getOutputs(); prev = value; value = next; From 158ed81a737da8b993f92f83d7687d2191f50c07 Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Mon, 9 Oct 2023 16:05:29 +0200 Subject: [PATCH 3/3] Fix subworkflow collecting Signed-off-by: Hongxin Liang --- .../flyte/jflyte/utils/ProjectClosure.java | 16 ++- .../flyte/jflyte/ExecuteDynamicWorkflow.java | 109 +++++++++++------- 2 files changed, 79 insertions(+), 46 deletions(-) diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index 7ef4b9de8..fba0e4154 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -343,8 +343,16 @@ && checkCycles(subWorkflowId, allWorkflows, beingVisited, visited))) { } @VisibleForTesting + static Map collectSubWorkflows( + List nodes, Map allWorkflows) { + return collectSubWorkflows(nodes, allWorkflows, Function.identity()); + } + public static Map collectSubWorkflows( - List rewrittenNodes, Map allWorkflows) { + List nodes, + Map allWorkflows, + Function, List> nodesRewriter) { + List rewrittenNodes = nodesRewriter.apply(nodes); return collectSubWorkflowIds(rewrittenNodes).stream() // all identifiers should be rewritten at this point .map( @@ -366,7 +374,7 @@ public static Map collectSubWorkflows( } Map nestedSubWorkflows = - collectSubWorkflows(subWorkflow.nodes(), allWorkflows); + collectSubWorkflows(subWorkflow.nodes(), allWorkflows, nodesRewriter); return Stream.concat( Stream.of(Maps.immutableEntry(workflowId, subWorkflow)), @@ -376,10 +384,10 @@ public static Map collectSubWorkflows( } public static Map collectDynamicWorkflowTasks( - List rewrittenNodes, + List nodes, Map allTasks, Function remoteTaskTemplateFetcher) { - return collectTaskIds(rewrittenNodes).stream() + return collectTaskIds(nodes).stream() // all identifiers should be rewritten at this point .map( taskId -> diff --git a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java index 5b72b18f5..c011611dc 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java @@ -31,6 +31,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.ForkJoinPool; +import java.util.function.Function; import java.util.stream.Collectors; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BindingData; @@ -202,8 +203,8 @@ private static DynamicJobSpec rewrite( Config config, ExecutionConfig executionConfig, DynamicJobSpec spec, - Map taskTemplates, - Map workflowTemplates) { + Map allTaskTemplates, + Map allWorkflowTemplates) { try (FlyteAdminClient flyteAdminClient = FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), null)) { @@ -216,24 +217,25 @@ private static DynamicJobSpec rewrite( .adminClient(flyteAdminClient) .build() .visitor(); + Function, List> nodesRewriter = + nodes -> nodes.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList()); - Map allUsedTaskTemplates = new HashMap<>(); - Map allUsedSubWorkflows = new HashMap<>(); - Map cache = new HashMap<>(); + Map allUsedSubWorkflows = + collectAllUsedSubWorkflows( + spec.nodes(), allWorkflowTemplates, workflowNodeVisitor, nodesRewriter); - List nodes = - recursivelyCollect( - spec.nodes(), + Map allUsedTaskTemplates = new HashMap<>(); + List rewrittenNodes = + collectAllUsedTaskTemplates( + spec, + allTaskTemplates, + nodesRewriter, allUsedTaskTemplates, - allUsedSubWorkflows, - taskTemplates, - workflowTemplates, - workflowNodeVisitor, flyteAdminClient, - cache); + allUsedSubWorkflows); return spec.toBuilder() - .nodes(nodes) + .nodes(rewrittenNodes) .subWorkflows( ImmutableMap.builder() .putAll(spec.subWorkflows()) @@ -248,45 +250,68 @@ private static DynamicJobSpec rewrite( } } - private static List recursivelyCollect( - List startPoint, - Map allUsedTaskTemplates, - Map allUsedSubWorkflows, + private static List collectAllUsedTaskTemplates( + DynamicJobSpec spec, Map allTaskTemplates, - Map allWorkflowTemplates, + Function, List> nodesRewriter, + Map allUsedTaskTemplates, + FlyteAdminClient flyteAdminClient, + Map allUsedSubWorkflows) { + + Map cache = new HashMap<>(); + + // collect directly used task templates + List rewrittenNodes = + collectTaskTemplates( + spec.nodes(), + nodesRewriter, + allUsedTaskTemplates, + allTaskTemplates, + flyteAdminClient, + cache); + + // collect task templates used by subworkflows + allUsedSubWorkflows + .values() + .forEach( + workflowTemplate -> + collectTaskTemplates( + workflowTemplate.nodes(), + nodesRewriter, + allUsedTaskTemplates, + allTaskTemplates, + flyteAdminClient, + cache)); + + return rewrittenNodes; + } + + private static Map collectAllUsedSubWorkflows( + List nodes, + Map workflowTemplates, WorkflowNodeVisitor workflowNodeVisitor, + Function, List> nodesRewriter) { + + Map allUsedSubWorkflows = + ProjectClosure.collectSubWorkflows(nodes, workflowTemplates, nodesRewriter); + return mapValues(allUsedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate); + } + + private static List collectTaskTemplates( + List nodes, + Function, List> nodesRewriter, + Map allUsedTaskTemplates, + Map allTaskTemplates, FlyteAdminClient flyteAdminClient, Map cache) { - List rewrittenNodes = - startPoint.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList()); + List rewrittenNodes = nodesRewriter.apply(nodes); Map usedTaskTemplates = ProjectClosure.collectDynamicWorkflowTasks( rewrittenNodes, allTaskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id, cache)); allUsedTaskTemplates.putAll(usedTaskTemplates); - Map usedSubWorkflows = - ProjectClosure.collectSubWorkflows(rewrittenNodes, allWorkflowTemplates); - Map rewrittenUsedSubWorkflows = - mapValues(usedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate); - - rewrittenUsedSubWorkflows.forEach( - (key, value) -> { - if (!allUsedSubWorkflows.containsKey(key)) { - allUsedSubWorkflows.put(key, value); - recursivelyCollect( - value.nodes(), - allUsedTaskTemplates, - allUsedSubWorkflows, - allTaskTemplates, - allWorkflowTemplates, - workflowNodeVisitor, - flyteAdminClient, - cache); - } - }); - return rewrittenNodes; }