Skip to content

Commit

Permalink
Add workflow on failure policy support (#237)
Browse files Browse the repository at this point in the history
* Add workflow on failure policy support

Signed-off-by: Andres Gomez Ferrer <[email protected]>

* Fix unit test

Signed-off-by: Andres Gomez Ferrer <[email protected]>

* Add SdkWorkflowMetadata

Signed-off-by: Andres Gomez Ferrer <[email protected]>

* Configure FAIL_IMMEDIATELY default in the builder

Signed-off-by: Andres Gomez Ferrer <[email protected]>

---------

Signed-off-by: Andres Gomez Ferrer <[email protected]>
  • Loading branch information
andresgomezfrr authored Aug 10, 2023
1 parent 0c9ba51 commit 62ef63a
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 12 deletions.
41 changes: 41 additions & 0 deletions flytekit-api/src/main/java/org/flyte/api/v1/OnFailurePolicy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2021 Flyte Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.api.v1;

import com.google.auto.value.AutoValue;

/** Failure Handling Strategy. */
@AutoValue
public abstract class OnFailurePolicy {
public enum Kind {
FAIL_IMMEDIATELY,
FAIL_AFTER_EXECUTABLE_NODES_COMPLETE
}

public abstract Kind getKind();

public static OnFailurePolicy.Builder builder() {
return new AutoValue_OnFailurePolicy.Builder();
}

@AutoValue.Builder
public abstract static class Builder {
public abstract Builder kind(Kind kind);

public abstract OnFailurePolicy build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,24 @@
package org.flyte.api.v1;

import com.google.auto.value.AutoValue;
import org.flyte.api.v1.OnFailurePolicy.Kind;

/** Metadata for the entire workflow. */
@AutoValue
public class WorkflowMetadata {
public abstract class WorkflowMetadata {

public abstract OnFailurePolicy onFailure();

public static Builder builder() {
return new AutoValue_WorkflowMetadata.Builder();
return new AutoValue_WorkflowMetadata.Builder()
.onFailure(OnFailurePolicy.builder().kind(Kind.FAIL_IMMEDIATELY).build());
}

@AutoValue.Builder
public abstract static class Builder {

public abstract Builder onFailure(OnFailurePolicy onFailure);

public abstract WorkflowMetadata build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
import org.flyte.api.v1.OnFailurePolicy;
import org.flyte.api.v1.OnFailurePolicy.Kind;
import org.flyte.api.v1.WorkflowTemplate;

/** Builder used during {@link SdkWorkflow#expand(SdkWorkflowBuilder)}. */
Expand All @@ -37,6 +39,11 @@ public class SdkWorkflowBuilder {
private final Map<String, String> outputDescriptions;
private final SdkNodeNamePolicy sdkNodeNamePolicy;

private SdkWorkflowMetadata workflowMetadata =
SdkWorkflowMetadata.builder()
.onFailure(OnFailurePolicy.builder().kind(Kind.FAIL_IMMEDIATELY).build())
.build();

/** Creates a new builder. */
public SdkWorkflowBuilder() {
this(new SdkNodeNamePolicy());
Expand All @@ -54,6 +61,15 @@ public SdkWorkflowBuilder() {

this.sdkNodeNamePolicy = sdkNodeNamePolicy;
}

public void setWorkflowMetadata(SdkWorkflowMetadata workflowMetadata) {
this.workflowMetadata = workflowMetadata;
}

public SdkWorkflowMetadata getWorkflowMetadata() {
return this.workflowMetadata;
}

/**
* Applies the given transformation and returns a new node with a given node id.
*
Expand Down Expand Up @@ -99,7 +115,7 @@ public <OutputT> SdkNode<OutputT> applyWithInputMap(
* @return the new {@link SdkNode}
*/
public <OutputT> SdkNode<OutputT> apply(SdkTransform<Void, OutputT> transformWithoutInputs) {
return apply(/*nodeId=*/ (String) null, transformWithoutInputs);
return apply(/* nodeId= */ (String) null, transformWithoutInputs);
}

/**
Expand All @@ -112,7 +128,7 @@ public <OutputT> SdkNode<OutputT> apply(SdkTransform<Void, OutputT> transformWit
*/
public <InputT, OutputT> SdkNode<OutputT> apply(
SdkTransform<InputT, OutputT> transform, InputT inputs) {
return apply(/*nodeId=*/ null, transform, inputs);
return apply(/* nodeId= */ null, transform, inputs);
}

/**
Expand All @@ -125,7 +141,7 @@ public <InputT, OutputT> SdkNode<OutputT> apply(
*/
public <OutputT> SdkNode<OutputT> applyWithInputMap(
SdkTransform<?, OutputT> transform, Map<String, SdkBindingData<?>> inputs) {
return applyWithInputMap(/*nodeId=*/ null, transform, inputs);
return applyWithInputMap(/* nodeId= */ null, transform, inputs);
}

protected <InputT, OutputT> SdkNode<OutputT> applyInternal(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2021 Flyte Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit;

import com.google.auto.value.AutoValue;
import org.flyte.api.v1.OnFailurePolicy;

/** Metadata for the entire workflow. */
@AutoValue
public abstract class SdkWorkflowMetadata {

public abstract OnFailurePolicy onFailure();

public static Builder builder() {
return new AutoValue_SdkWorkflowMetadata.Builder()
.onFailure(OnFailurePolicy.builder().kind(OnFailurePolicy.Kind.FAIL_IMMEDIATELY).build());
}

@AutoValue.Builder
public abstract static class Builder {

public abstract Builder onFailure(OnFailurePolicy onFailure);

public abstract SdkWorkflowMetadata build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,14 @@
class WorkflowTemplateIdl {

static WorkflowTemplate ofBuilder(SdkWorkflowBuilder builder) {
WorkflowMetadata metadata = WorkflowMetadata.builder().build();

List<Node> nodes =
builder.getNodes().values().stream().map(SdkNode::toIdl).collect(toUnmodifiableList());

List<Binding> outputs = getOutputBindings(builder);

return WorkflowTemplate.builder()
.metadata(metadata)
.metadata(
WorkflowMetadata.builder().onFailure(builder.getWorkflowMetadata().onFailure()).build())
.interface_(
TypedInterface.builder()
.inputs(getInputVariableMap(builder))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
*/
package org.flyte.flytekitscala

import org.flyte.api.v1.WorkflowTemplate
import org.flyte.api.v1.{WorkflowMetadata, WorkflowTemplate}
import org.flyte.flytekit.{
SdkBindingData => SdkJavaBindingData,
SdkNode,
SdkTransform,
SdkType,
SdkWorkflow,
SdkWorkflowBuilder
SdkWorkflowBuilder,
SdkWorkflowMetadata,
SdkBindingData => SdkJavaBindingData
}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -66,6 +67,10 @@ abstract class SdkScalaWorkflow[InputT, OutputT](

class SdkScalaWorkflowBuilder(builder: SdkWorkflowBuilder) {

def setWorkflowMetadata(workflowMetadata: SdkWorkflowMetadata): Unit =
builder.setWorkflowMetadata(workflowMetadata)
def getWorkflowMetadata(): SdkWorkflowMetadata = builder.getWorkflowMetadata

/** Get the nodes applied on the DAG.
* @return
* The workflows node by name.
Expand Down
5 changes: 4 additions & 1 deletion jflyte/src/main/java/org/flyte/jflyte/ProtoUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import flyteidl.core.Types;
import flyteidl.core.Types.SchemaType.SchemaColumn.SchemaColumnType;
import flyteidl.core.Workflow;
import flyteidl.core.Workflow.WorkflowMetadata.OnFailurePolicy;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.time.Duration;
Expand Down Expand Up @@ -741,7 +742,9 @@ static Workflow.WorkflowTemplate serialize(WorkflowTemplate template) {

private static Workflow.WorkflowMetadata serialize(
@SuppressWarnings("UnusedVariable") WorkflowMetadata metadata) {
return Workflow.WorkflowMetadata.newBuilder().build();
return Workflow.WorkflowMetadata.newBuilder()
.setOnFailure(OnFailurePolicy.valueOf(metadata.onFailure().getKind().name()))
.build();
}

@VisibleForTesting
Expand Down

0 comments on commit 62ef63a

Please sign in to comment.