diff --git a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java index 17e5f8be2..981949152 100644 --- a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java +++ b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java @@ -401,15 +401,17 @@ public SdkTestingExecutor withWorkflowOutput( // fixed tasks TestingRunnableTask fixedTask = - getFixedTaskOrDefault(workflow.getName(), inputType, outputType); + getFixedTaskOrDefault(workflow.getName(), inputType, outputType) + .withFixedOutput(input, output); // replace workflow SdkWorkflow mockWorkflow = - new TestingWorkflow<>(inputType, outputType, output); + new TestingWorkflow<>(inputType, outputType, fixedTask.fixedOutputs); return toBuilder() .putWorkflowTemplate(workflow.getName(), mockWorkflow.toIdlTemplate()) - .putFixedTask(workflow.getName(), fixedTask.withFixedOutput(input, output)) + .putFixedTask(workflow.getName(), fixedTask) + .putFixedTask(TestingWorkflow.TestingSdkRunnableTask.class.getName(), fixedTask) .build(); } diff --git a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java index 0339fb180..f5b2bd5fc 100644 --- a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java +++ b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java @@ -16,6 +16,8 @@ */ package org.flyte.flytekit.testing; +import java.util.Map; +import org.flyte.flytekit.SdkRunnableTask; import org.flyte.flytekit.SdkType; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; @@ -23,15 +25,36 @@ /** {@link SdkWorkflow} that can fix output for specific input. */ class TestingWorkflow extends SdkWorkflow { - private final OutputT output; + private final Map outputs; - TestingWorkflow(SdkType inputType, SdkType outputType, OutputT output) { + TestingWorkflow( + SdkType inputType, SdkType outputType, Map outputs) { super(inputType, outputType); - this.output = output; + this.outputs = outputs; } @Override public OutputT expand(SdkWorkflowBuilder builder, InputT input) { - return output; + return builder + .apply(new TestingSdkRunnableTask<>(getInputType(), getOutputType(), outputs), input) + .getOutputs(); + } + + public static class TestingSdkRunnableTask + extends SdkRunnableTask { + private static final long serialVersionUID = 6106269076155338045L; + + private final Map outputs; + + public TestingSdkRunnableTask( + SdkType inputType, SdkType outputType, Map outputs) { + super(inputType, outputType); + this.outputs = outputs; + } + + @Override + public OutputT run(InputT input) { + return outputs.get(input); + } } } diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/MockSubWorkflowsTest.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/MockSubWorkflowsTest.java new file mode 100644 index 000000000..6e96ea2a7 --- /dev/null +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/MockSubWorkflowsTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2024 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit.testing; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +import com.google.auto.value.AutoValue; +import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; +import org.flyte.flytekit.SdkWorkflow; +import org.flyte.flytekit.SdkWorkflowBuilder; +import org.flyte.flytekit.jackson.JacksonSdkType; +import org.junit.jupiter.api.Test; + +public class MockSubWorkflowsTest { + @Test + public void test() { + SdkTestingExecutor.Result result = + SdkTestingExecutor.of(new Workflow()) + .withFixedInputs( + JacksonSdkType.of(SubWorkflowInputs.class), + SubWorkflowInputs.create(SdkBindingDataFactory.of(1))) + .withWorkflowOutput( + new SubWorkflow(), + JacksonSdkType.of(SubWorkflowInputs.class), + SubWorkflowInputs.create(SdkBindingDataFactory.of(1)), + JacksonSdkType.of(SubWorkflowOutputs.class), + SubWorkflowOutputs.create(SdkBindingDataFactory.of(10))) + .withWorkflowOutput( + new SubWorkflow(), + JacksonSdkType.of(SubWorkflowInputs.class), + SubWorkflowInputs.create(SdkBindingDataFactory.of(2)), + JacksonSdkType.of(SubWorkflowOutputs.class), + SubWorkflowOutputs.create(SdkBindingDataFactory.of(20))) + .execute(); + + assertThat(result.getIntegerOutput("o"), equalTo(10L)); + } + + public static class Workflow extends SdkWorkflow { + public Workflow() { + super( + JacksonSdkType.of(SubWorkflowInputs.class), JacksonSdkType.of(SubWorkflowOutputs.class)); + } + + @Override + public SubWorkflowOutputs expand(SdkWorkflowBuilder builder, SubWorkflowInputs inputs) { + + var subOut1 = + builder + .apply( + "sub", new SubWorkflow(), SubWorkflowInputs.create(SdkBindingDataFactory.of(1))) + .getOutputs(); + builder + .apply("sub1", new SubWorkflow(), SubWorkflowInputs.create(SdkBindingDataFactory.of(2))) + .getOutputs(); + + return SubWorkflowOutputs.create(subOut1.o()); + } + } + + public static class SubWorkflow extends SdkWorkflow { + public SubWorkflow() { + super( + JacksonSdkType.of(SubWorkflowInputs.class), JacksonSdkType.of(SubWorkflowOutputs.class)); + } + + @Override + public SubWorkflowOutputs expand(SdkWorkflowBuilder builder, SubWorkflowInputs inputs) { + + return SubWorkflowOutputs.create(inputs.a()); + } + } + + @AutoValue + public abstract static class SubWorkflowInputs { + public abstract SdkBindingData a(); + + public static MockSubWorkflowsTest.SubWorkflowInputs create(SdkBindingData a) { + return new AutoValue_MockSubWorkflowsTest_SubWorkflowInputs(a); + } + } + + @AutoValue + public abstract static class SubWorkflowOutputs { + public abstract SdkBindingData o(); + + public static MockSubWorkflowsTest.SubWorkflowOutputs create(SdkBindingData o) { + return new AutoValue_MockSubWorkflowsTest_SubWorkflowOutputs(o); + } + } +}