Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recursively handle subworkflow in dynamic #260

Merged
merged 3 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,19 @@ public Output run(SdkWorkflowBuilder builder, Input input) {
SdkTypes.nulls(),
SdkTypes.nulls())
.withUpstreamNode(hello));
// subworkflow that contains another subworkflow
SdkNode<WelcomeWorkflow.Output> greet =
builder.apply(
"greet",
new SubWorkflow().withUpstreamNode(world),
WelcomeWorkflow.Input.create(SdkBindingDataFactory.of("greet")));
@Var SdkBindingData<Long> prev = SdkBindingDataFactory.of(0);
@Var SdkBindingData<Long> value = SdkBindingDataFactory.of(1);
for (int i = 2; i <= input.n().get(); i++) {
SdkBindingData<Long> next =
builder
.apply(
"fib-" + i, new SumTask().withUpstreamNode(world), SumInput.create(value, prev))
"fib-" + i, new SumTask().withUpstreamNode(greet), SumInput.create(value, prev))
.getOutputs();
prev = value;
value = next;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,16 @@ && checkCycles(subWorkflowId, allWorkflows, beingVisited, visited))) {
}

/**
 * Two-argument convenience overload that leaves {@code nodes} untouched before collection.
 *
 * <p>Delegates to the three-argument overload, passing {@link Function#identity()} as the nodes
 * rewriter, so the nodes are handed over exactly as received.
 *
 * @param nodes nodes to inspect
 * @param allWorkflows workflow templates forwarded unchanged to the three-argument overload
 * @return the result of the three-argument overload for the unmodified nodes
 */
@VisibleForTesting
static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
    List<Node> nodes, Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows) {
  Function<List<Node>, List<Node>> keepNodesAsIs = Function.identity();
  return collectSubWorkflows(nodes, allWorkflows, keepNodesAsIs);
}

public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
List<Node> rewrittenNodes, Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows) {
List<Node> nodes,
Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows,
Function<List<Node>, List<Node>> nodesRewriter) {
List<Node> rewrittenNodes = nodesRewriter.apply(nodes);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug fix: recursively collecting sub-workflows did not work in the dynamic task execution context.

return collectSubWorkflowIds(rewrittenNodes).stream()
// all identifiers should be rewritten at this point
.map(
Expand All @@ -366,7 +374,7 @@ public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
}

Map<WorkflowIdentifier, WorkflowTemplate> nestedSubWorkflows =
collectSubWorkflows(subWorkflow.nodes(), allWorkflows);
collectSubWorkflows(subWorkflow.nodes(), allWorkflows, nodesRewriter);

return Stream.concat(
Stream.of(Maps.immutableEntry(workflowId, subWorkflow)),
Expand All @@ -376,10 +384,10 @@ public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
}

public static Map<TaskIdentifier, TaskTemplate> collectDynamicWorkflowTasks(
List<Node> rewrittenNodes,
List<Node> nodes,
Map<TaskIdentifier, TaskTemplate> allTasks,
Function<TaskIdentifier, TaskTemplate> remoteTaskTemplateFetcher) {
return collectTaskIds(rewrittenNodes).stream()
return collectTaskIds(nodes).stream()
// all identifiers should be rewritten at this point
.map(
taskId ->
Expand Down
132 changes: 101 additions & 31 deletions jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.flyte.api.v1.Binding;
import org.flyte.api.v1.BindingData;
Expand Down Expand Up @@ -197,12 +199,12 @@ private void execute() {
}
}

static DynamicJobSpec rewrite(
private static DynamicJobSpec rewrite(
Config config,
ExecutionConfig executionConfig,
DynamicJobSpec spec,
Map<TaskIdentifier, TaskTemplate> taskTemplates,
Map<WorkflowIdentifier, WorkflowTemplate> workflowTemplates) {
Map<TaskIdentifier, TaskTemplate> allTaskTemplates,
Map<WorkflowIdentifier, WorkflowTemplate> allWorkflowTemplates) {

try (FlyteAdminClient flyteAdminClient =
FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), null)) {
Expand All @@ -215,58 +217,126 @@ static DynamicJobSpec rewrite(
.adminClient(flyteAdminClient)
.build()
.visitor();
Function<List<Node>, List<Node>> nodesRewriter =
nodes -> nodes.stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList());

List<Node> rewrittenNodes =
spec.nodes().stream().map(workflowNodeVisitor::visitNode).collect(toUnmodifiableList());

Map<WorkflowIdentifier, WorkflowTemplate> usedSubWorkflows =
ProjectClosure.collectSubWorkflows(rewrittenNodes, workflowTemplates);

Map<TaskIdentifier, TaskTemplate> usedTaskTemplates =
ProjectClosure.collectDynamicWorkflowTasks(
rewrittenNodes, taskTemplates, id -> fetchTaskTemplate(flyteAdminClient, id));

// FIXME one sub-workflow can use more sub-workflows, we should recursively collect used tasks
// and workflows
Map<WorkflowIdentifier, WorkflowTemplate> allUsedSubWorkflows =
collectAllUsedSubWorkflows(
spec.nodes(), allWorkflowTemplates, workflowNodeVisitor, nodesRewriter);

Map<WorkflowIdentifier, WorkflowTemplate> rewrittenUsedSubWorkflows =
mapValues(usedSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate);
Map<TaskIdentifier, TaskTemplate> allUsedTaskTemplates = new HashMap<>();
List<Node> rewrittenNodes =
collectAllUsedTaskTemplates(
spec,
allTaskTemplates,
nodesRewriter,
allUsedTaskTemplates,
flyteAdminClient,
allUsedSubWorkflows);

return spec.toBuilder()
.nodes(rewrittenNodes)
.subWorkflows(
ImmutableMap.<WorkflowIdentifier, WorkflowTemplate>builder()
.putAll(spec.subWorkflows())
.putAll(rewrittenUsedSubWorkflows)
.putAll(allUsedSubWorkflows)
.build())
.tasks(
ImmutableMap.<TaskIdentifier, TaskTemplate>builder()
.putAll(spec.tasks())
.putAll(usedTaskTemplates)
.putAll(allUsedTaskTemplates)
.build())
.build();
}
}

/**
 * Collects task templates for the nodes of {@code spec} and for the nodes of every workflow in
 * {@code allUsedSubWorkflows}, adding all of them to {@code allUsedTaskTemplates}.
 *
 * <p>One cache map is created here and passed to every {@code collectTaskTemplates} call, so
 * all lookups made during this invocation share it.
 *
 * @param spec dynamic job spec whose nodes are processed first
 * @param allTaskTemplates task templates forwarded to each {@code collectTaskTemplates} call
 * @param nodesRewriter rewriter forwarded to each {@code collectTaskTemplates} call
 * @param allUsedTaskTemplates output map; mutated by the {@code collectTaskTemplates} calls
 * @param flyteAdminClient client forwarded to each {@code collectTaskTemplates} call
 * @param allUsedSubWorkflows workflows whose nodes are processed after the spec's own nodes
 * @return the node list returned for {@code spec.nodes()}; the lists returned for sub-workflow
 *     nodes are discarded
 */
private static List<Node> collectAllUsedTaskTemplates(
    DynamicJobSpec spec,
    Map<TaskIdentifier, TaskTemplate> allTaskTemplates,
    Function<List<Node>, List<Node>> nodesRewriter,
    Map<TaskIdentifier, TaskTemplate> allUsedTaskTemplates,
    FlyteAdminClient flyteAdminClient,
    Map<WorkflowIdentifier, WorkflowTemplate> allUsedSubWorkflows) {

  Map<TaskIdentifier, TaskTemplate> sharedCache = new HashMap<>();

  // First pass: templates referenced directly by the spec's own nodes.
  List<Node> topLevelNodes =
      collectTaskTemplates(
          spec.nodes(),
          nodesRewriter,
          allUsedTaskTemplates,
          allTaskTemplates,
          flyteAdminClient,
          sharedCache);

  // Second pass: templates referenced by the nodes of each used sub-workflow.
  for (WorkflowTemplate subWorkflow : allUsedSubWorkflows.values()) {
    collectTaskTemplates(
        subWorkflow.nodes(),
        nodesRewriter,
        allUsedTaskTemplates,
        allTaskTemplates,
        flyteAdminClient,
        sharedCache);
  }

  return topLevelNodes;
}

/**
 * Collects the sub-workflows for {@code nodes} via {@code ProjectClosure.collectSubWorkflows}
 * and passes each resulting template through {@code workflowNodeVisitor.visitWorkflowTemplate}.
 *
 * @param nodes nodes handed to {@code ProjectClosure.collectSubWorkflows}
 * @param workflowTemplates workflow templates handed to the same call
 * @param workflowNodeVisitor visitor applied to every collected workflow template
 * @param nodesRewriter rewriter handed to {@code ProjectClosure.collectSubWorkflows}
 * @return the collected sub-workflows with their values mapped through the visitor
 */
private static Map<WorkflowIdentifier, WorkflowTemplate> collectAllUsedSubWorkflows(
    List<Node> nodes,
    Map<WorkflowIdentifier, WorkflowTemplate> workflowTemplates,
    WorkflowNodeVisitor workflowNodeVisitor,
    Function<List<Node>, List<Node>> nodesRewriter) {

  return mapValues(
      ProjectClosure.collectSubWorkflows(nodes, workflowTemplates, nodesRewriter),
      workflowNodeVisitor::visitWorkflowTemplate);
}

/**
 * Rewrites {@code nodes} with {@code nodesRewriter}, collects the task templates for the
 * rewritten nodes via {@code ProjectClosure.collectDynamicWorkflowTasks}, and adds them to
 * {@code allUsedTaskTemplates}.
 *
 * @param nodes nodes to rewrite and inspect
 * @param nodesRewriter function applied to {@code nodes} before collection
 * @param allUsedTaskTemplates output map that receives the collected templates
 * @param allTaskTemplates task templates handed to {@code collectDynamicWorkflowTasks}
 * @param flyteAdminClient client passed to {@code fetchTaskTemplate} by the fetcher callback
 * @param cache map passed to {@code fetchTaskTemplate} by the fetcher callback
 * @return the rewritten nodes
 */
private static List<Node> collectTaskTemplates(
    List<Node> nodes,
    Function<List<Node>, List<Node>> nodesRewriter,
    Map<TaskIdentifier, TaskTemplate> allUsedTaskTemplates,
    Map<TaskIdentifier, TaskTemplate> allTaskTemplates,
    FlyteAdminClient flyteAdminClient,
    Map<TaskIdentifier, TaskTemplate> cache) {

  List<Node> rewritten = nodesRewriter.apply(nodes);

  // Callback handed to ProjectClosure as its remote task template fetcher.
  Function<TaskIdentifier, TaskTemplate> remoteFetcher =
      taskId -> fetchTaskTemplate(flyteAdminClient, taskId, cache);

  allUsedTaskTemplates.putAll(
      ProjectClosure.collectDynamicWorkflowTasks(rewritten, allTaskTemplates, remoteFetcher));

  return rewritten;
}

// Note: in some cases this network call is redundant, because we may already have fetched the
// task template while resolving the latest task version. However, the user may have provided an
// explicit version for a remote task, in which case the latest version is never resolved, so we
// still need to make this call. We accept the additional cost because remote tasks should be
// rare in a dynamic workflow.
private static TaskTemplate fetchTaskTemplate(
FlyteAdminClient flyteAdminClient, TaskIdentifier id) {
LOG.info("fetching task template remotely for {}", id);

TaskTemplate taskTemplate =
flyteAdminClient.fetchLatestTaskTemplate(
NamedEntityIdentifier.builder()
.domain(id.domain())
.project(id.project())
.name(id.name())
.build());

return taskTemplate;
FlyteAdminClient flyteAdminClient,
TaskIdentifier id,
Map<TaskIdentifier, TaskTemplate> cache) {
return cache.computeIfAbsent(
id,
taskIdentifier -> {
LOG.info("fetching task template remotely for {}", id);

return flyteAdminClient.fetchLatestTaskTemplate(
NamedEntityIdentifier.builder()
.domain(id.domain())
.project(id.project())
.name(id.name())
.build());
});
}

private static DynamicWorkflowTask getDynamicWorkflowTask(String name) {
Expand Down
Loading