Skip to content

Commit

Permalink
Fix only apply last subworkflow mock (#280)
Browse files Browse the repository at this point in the history
* Fix only apply last subworkflow mock

Signed-off-by: Andres Gomez Ferrer <[email protected]>

* Fix spotbugs

Signed-off-by: Andres Gomez Ferrer <[email protected]>

---------

Signed-off-by: Andres Gomez Ferrer <[email protected]>
Co-authored-by: Andres Gomez Ferrer <[email protected]>
  • Loading branch information
andresgomezfrr and andresgomezfrr authored Jan 22, 2024
1 parent e37bcad commit d33cd10
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,17 @@ public <InputT, OutputT> SdkTestingExecutor withWorkflowOutput(

// fixed tasks
TestingRunnableTask<InputT, OutputT> fixedTask =
getFixedTaskOrDefault(workflow.getName(), inputType, outputType);
getFixedTaskOrDefault(workflow.getName(), inputType, outputType)
.withFixedOutput(input, output);

// replace workflow
SdkWorkflow<InputT, OutputT> mockWorkflow =
new TestingWorkflow<>(inputType, outputType, output);
new TestingWorkflow<>(inputType, outputType, fixedTask.fixedOutputs);

return toBuilder()
.putWorkflowTemplate(workflow.getName(), mockWorkflow.toIdlTemplate())
.putFixedTask(workflow.getName(), fixedTask.withFixedOutput(input, output))
.putFixedTask(workflow.getName(), fixedTask)
.putFixedTask(TestingWorkflow.TestingSdkRunnableTask.class.getName(), fixedTask)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,45 @@
*/
package org.flyte.flytekit.testing;

import java.util.Map;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkType;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;

/** {@link SdkWorkflow} that can fix output for specific input. */
class TestingWorkflow<InputT, OutputT> extends SdkWorkflow<InputT, OutputT> {

private final OutputT output;
private final Map<InputT, OutputT> outputs;

TestingWorkflow(SdkType<InputT> inputType, SdkType<OutputT> outputType, OutputT output) {
TestingWorkflow(
SdkType<InputT> inputType, SdkType<OutputT> outputType, Map<InputT, OutputT> outputs) {
super(inputType, outputType);
this.output = output;
this.outputs = outputs;
}

@Override
public OutputT expand(SdkWorkflowBuilder builder, InputT input) {
return output;
return builder
.apply(new TestingSdkRunnableTask<>(getInputType(), getOutputType(), outputs), input)
.getOutputs();
}

public static class TestingSdkRunnableTask<InputT, OutputT>
extends SdkRunnableTask<InputT, OutputT> {
private static final long serialVersionUID = 6106269076155338045L;

private final Map<InputT, OutputT> outputs;

public TestingSdkRunnableTask(
SdkType<InputT> inputType, SdkType<OutputT> outputType, Map<InputT, OutputT> outputs) {
super(inputType, outputType);
this.outputs = outputs;
}

@Override
public OutputT run(InputT input) {
return outputs.get(input);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2024 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit.testing;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;

import com.google.auto.value.AutoValue;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkBindingDataFactory;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.flytekit.SdkWorkflowBuilder;
import org.flyte.flytekit.jackson.JacksonSdkType;
import org.junit.jupiter.api.Test;

public class MockSubWorkflowsTest {
@Test
public void test() {
SdkTestingExecutor.Result result =
SdkTestingExecutor.of(new Workflow())
.withFixedInputs(
JacksonSdkType.of(SubWorkflowInputs.class),
SubWorkflowInputs.create(SdkBindingDataFactory.of(1)))
.withWorkflowOutput(
new SubWorkflow(),
JacksonSdkType.of(SubWorkflowInputs.class),
SubWorkflowInputs.create(SdkBindingDataFactory.of(1)),
JacksonSdkType.of(SubWorkflowOutputs.class),
SubWorkflowOutputs.create(SdkBindingDataFactory.of(10)))
.withWorkflowOutput(
new SubWorkflow(),
JacksonSdkType.of(SubWorkflowInputs.class),
SubWorkflowInputs.create(SdkBindingDataFactory.of(2)),
JacksonSdkType.of(SubWorkflowOutputs.class),
SubWorkflowOutputs.create(SdkBindingDataFactory.of(20)))
.execute();

assertThat(result.getIntegerOutput("o"), equalTo(10L));
}

public static class Workflow extends SdkWorkflow<SubWorkflowInputs, SubWorkflowOutputs> {
public Workflow() {
super(
JacksonSdkType.of(SubWorkflowInputs.class), JacksonSdkType.of(SubWorkflowOutputs.class));
}

@Override
public SubWorkflowOutputs expand(SdkWorkflowBuilder builder, SubWorkflowInputs inputs) {

var subOut1 =
builder
.apply(
"sub", new SubWorkflow(), SubWorkflowInputs.create(SdkBindingDataFactory.of(1)))
.getOutputs();
builder
.apply("sub1", new SubWorkflow(), SubWorkflowInputs.create(SdkBindingDataFactory.of(2)))
.getOutputs();

return SubWorkflowOutputs.create(subOut1.o());
}
}

public static class SubWorkflow extends SdkWorkflow<SubWorkflowInputs, SubWorkflowOutputs> {
public SubWorkflow() {
super(
JacksonSdkType.of(SubWorkflowInputs.class), JacksonSdkType.of(SubWorkflowOutputs.class));
}

@Override
public SubWorkflowOutputs expand(SdkWorkflowBuilder builder, SubWorkflowInputs inputs) {

return SubWorkflowOutputs.create(inputs.a());
}
}

@AutoValue
public abstract static class SubWorkflowInputs {
public abstract SdkBindingData<Long> a();

public static MockSubWorkflowsTest.SubWorkflowInputs create(SdkBindingData<Long> a) {
return new AutoValue_MockSubWorkflowsTest_SubWorkflowInputs(a);
}
}

@AutoValue
public abstract static class SubWorkflowOutputs {
public abstract SdkBindingData<Long> o();

public static MockSubWorkflowsTest.SubWorkflowOutputs create(SdkBindingData<Long> o) {
return new AutoValue_MockSubWorkflowsTest_SubWorkflowOutputs(o);
}
}
}

0 comments on commit d33cd10

Please sign in to comment.