From 1b0153f0b53d9448f1b77fca0ef6b3b245d52841 Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Sat, 7 Oct 2023 19:01:12 +0200 Subject: [PATCH] Support struct Signed-off-by: Hongxin Liang --- .scalafmt.conf | 2 +- .../org/flyte/examples/AllInputsTask.java | 23 ++++- .../org/flyte/examples/AllInputsWorkflow.java | 10 ++- .../jackson/JacksonSdkLiteralType.java | 4 +- .../jackson/SdkLiteralTypeModule.java | 6 +- .../flyte/flytekit/jackson/SdkTypeModule.java | 2 - .../flytekit/jackson/VariableMapVisitor.java | 9 +- .../LiteralMapDeserializers.java | 41 --------- .../SdkBindingDataDeserializer.java | 74 ++++++++++++--- ...erializer.java => StructDeserializer.java} | 48 ++++++---- .../serializers/GenericSerializer.java | 89 +------------------ .../flytekit/jackson/JacksonSdkTypeTest.java | 28 +++++- .../flyte/flytekit/SdkBindingDataFactory.java | 11 +++ .../flytekitscala/SdkScalaTypeTest.scala | 13 ++- .../flytekitscala/SdkBindingDataFactory.scala | 12 +++ .../integrationtests/structs/BQReference.java | 15 +++- .../structs/BuildBqReference.java | 17 ++-- .../structs/MockLookupBqTask.java | 16 ++-- .../structs/MockPipelineWorkflow.java | 6 +- .../src/test/java/org/flyte/AdditionalIT.java | 2 - 20 files changed, 238 insertions(+), 190 deletions(-) delete mode 100644 flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java rename flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/{LiteralStructDeserializer.java => StructDeserializer.java} (70%) diff --git a/.scalafmt.conf b/.scalafmt.conf index 6d6fd4e2c..971a38a84 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,3 +1,3 @@ -version=2.5.2 +version=3.7.14 runner.dialect=scala212source3 diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java index a37cd3045..5075ae8cb 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsTask.java @@ -35,8 +35,20 @@ public AllInputsTask() { JacksonSdkType.of(AutoAllInputsInput.class), JacksonSdkType.of(AutoAllInputsOutput.class)); } + @AutoValue + public abstract static class Nested { + public abstract String hello(); + + public abstract String world(); + + public static Nested create(String hello, String world) { + return new AutoValue_AllInputsTask_Nested(hello, world); + } + } + @AutoValue public abstract static class AutoAllInputsInput { + public abstract SdkBindingData i(); public abstract SdkBindingData f(); @@ -51,6 +63,8 @@ public abstract static class AutoAllInputsInput { public abstract SdkBindingData blob(); + public abstract SdkBindingData generic(); + public abstract SdkBindingData> l(); public abstract SdkBindingData> m(); @@ -67,12 +81,13 @@ public static AutoAllInputsInput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsTask_AutoAllInputsInput( - i, f, s, b, t, d, blob, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap); } } @@ -93,6 +108,8 @@ public abstract static class AutoAllInputsOutput { public abstract SdkBindingData blob(); + public abstract SdkBindingData generic(); + public abstract SdkBindingData> l(); public abstract SdkBindingData> m(); @@ -109,12 +126,13 @@ public static AutoAllInputsOutput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsTask_AutoAllInputsOutput( - i, f, s, b, t, d, blob, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap); } } @@ -128,6 +146,7 @@ public AutoAllInputsOutput run(AutoAllInputsInput input) { input.t(), input.d(), input.blob(), + input.generic(), input.l(), input.m(), input.emptyList(), diff --git a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java index 63394657c..8bd9acc31 100644 --- a/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java +++ b/flytekit-examples/src/main/java/org/flyte/examples/AllInputsWorkflow.java @@ -29,12 +29,14 @@ import org.flyte.api.v1.BlobType; import org.flyte.api.v1.BlobType.BlobDimensionality; import org.flyte.examples.AllInputsTask.AutoAllInputsOutput; +import org.flyte.examples.AllInputsTask.Nested; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkNode; import org.flyte.flytekit.SdkTypes; import org.flyte.flytekit.SdkWorkflow; import org.flyte.flytekit.SdkWorkflowBuilder; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; import org.flyte.flytekit.jackson.JacksonSdkType; @AutoService(SdkWorkflow.class) @@ -73,6 +75,8 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) .build()) .build()) .build()), + SdkBindingDataFactory.of( + JacksonSdkLiteralType.of(Nested.class), Nested.create("hello", "world")), SdkBindingDataFactory.ofStringCollection(Arrays.asList("foo", "bar")), SdkBindingDataFactory.ofStringMap(Map.of("test", "test")), SdkBindingDataFactory.ofStringCollection(Collections.emptyList()), @@ -88,6 +92,7 @@ public AllInputsWorkflowOutput expand(SdkWorkflowBuilder builder, Void noInput) outputs.t(), outputs.d(), outputs.blob(), + outputs.generic(), outputs.l(), outputs.m(), outputs.emptyList(), @@ -111,6 +116,8 @@ public abstract static class AllInputsWorkflowOutput { public abstract SdkBindingData blob(); + public abstract SdkBindingData generic(); + public abstract SdkBindingData> l(); public abstract SdkBindingData> m(); @@ -127,12 +134,13 @@ public static AllInputsWorkflow.AllInputsWorkflowOutput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData> emptyList, SdkBindingData> emptyMap) { return new AutoValue_AllInputsWorkflow_AllInputsWorkflowOutput( - i, f, s, b, t, d, blob, l, m, emptyList, emptyMap); + i, f, s, b, t, d, blob, generic, l, m, emptyList, emptyMap); } } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java index 0be5ba34f..969fa64bd 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java @@ -33,6 +33,7 @@ import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkLiteralType; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; /** * Implementation of {@link org.flyte.flytekit.SdkLiteralType} for {@link @@ -102,7 +103,8 @@ public Literal toLiteral(T value) { var tree = OBJECT_MAPPER.valueToTree(value); try { - return OBJECT_MAPPER.treeToValue(tree, Literal.class); + return Literal.ofScalar( + Scalar.ofGeneric(OBJECT_MAPPER.treeToValue(tree, StructWrapper.class).unwrap())); } catch (IOException e) { throw new UncheckedIOException("toLiteral failed for [" + clazz.getName() + "]: " + value, e); } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java index 861a1c640..4ec2d158d 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java @@ -20,8 +20,8 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.module.SimpleDeserializers; import com.fasterxml.jackson.databind.module.SimpleSerializers; -import org.flyte.api.v1.Literal; -import org.flyte.flytekit.jackson.deserializers.LiteralStructDeserializer; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; import org.flyte.flytekit.jackson.serializers.StructSerializer; class SdkLiteralTypeModule extends Module { @@ -43,7 +43,7 @@ public void setupModule(SetupContext context) { context.addSerializers(serializers); var deserializers = new SimpleDeserializers(); - deserializers.addDeserializer(Literal.class, new LiteralStructDeserializer()); + deserializers.addDeserializer(StructWrapper.class, new StructDeserializer()); context.addDeserializers(deserializers); // append with the lowest priority to use as fallback, if builtin annotations aren't present diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java index 17f71c25a..aa25ff45e 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java @@ -20,7 +20,6 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.deser.Deserializers; import com.fasterxml.jackson.databind.module.SimpleSerializers; -import org.flyte.flytekit.jackson.deserializers.LiteralMapDeserializers; import org.flyte.flytekit.jackson.deserializers.SdkBindingDataDeserializers; import org.flyte.flytekit.jackson.serializers.BindingMapSerializers; import org.flyte.flytekit.jackson.serializers.LiteralMapSerializers; @@ -60,7 +59,6 @@ public void setupModule(SetupContext context) { context.addSerializers(serializers); context.addSerializers(new LiteralMapSerializers()); - context.addDeserializers(new LiteralMapDeserializers()); context.addSerializers(new BindingMapSerializers()); context.addDeserializers(sdkbindingDeserializers); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index 0b9c5acaf..c565898be 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -175,9 +175,12 @@ private SdkLiteralType toLiteralType( return SdkLiteralTypes.blobs( BlobType.builder().format("").dimensionality(BlobDimensionality.SINGLE).build()); } - // TODO: Support structs - throw new UnsupportedOperationException( - String.format("Unsupported type: [%s]", type.getName())); + try { + return JacksonSdkLiteralType.of(type); + } catch (Exception e) { + throw new UnsupportedOperationException( + String.format("Unsupported type: [%s]", type.getName()), e); + } } private static boolean isPrimitiveAssignableFrom(Class fromClass, Class toClass) { diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java deleted file mode 100644 index f3015c3da..000000000 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2020-2023 Flyte Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.flyte.flytekit.jackson.deserializers; - -import com.fasterxml.jackson.databind.BeanDescription; -import com.fasterxml.jackson.databind.DeserializationConfig; -import com.fasterxml.jackson.databind.JavaType; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.deser.Deserializers; -import java.util.Map; -import org.flyte.api.v1.LiteralType; -import org.flyte.flytekit.jackson.JacksonLiteralMap; - -public class LiteralMapDeserializers extends Deserializers.Base { - - @Override - public JsonDeserializer findBeanDeserializer( - JavaType type, DeserializationConfig config, BeanDescription beanDesc) { - if (type.getRawClass().equals(JacksonLiteralMap.class)) { - Map literalTypeMap = type.getValueHandler(); - - return new LiteralMapDeserializer(literalTypeMap); - } - - return null; - } -} diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java index 7df993b8f..8b5982a6d 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java @@ -24,8 +24,12 @@ import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.BeanProperty; import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.deser.ContextualDeserializer; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; import java.io.IOException; import java.time.Duration; @@ -46,36 +50,48 @@ import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; +import org.flyte.api.v1.Scalar.Kind; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkLiteralType; import org.flyte.flytekit.SdkLiteralTypes; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; -class SdkBindingDataDeserializer extends StdDeserializer> { +class SdkBindingDataDeserializer extends StdDeserializer> + implements ContextualDeserializer { private static final long serialVersionUID = 0L; + private final JavaType type; + public SdkBindingDataDeserializer() { + this(null); + } + + private SdkBindingDataDeserializer(JavaType type) { super(SdkBindingData.class); + + this.type = type; } @Override public SdkBindingData deserialize( JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { JsonNode tree = jsonParser.readValueAsTree(); - return transform(tree); + return transform(tree, deserializationContext); } - private SdkBindingData transform(JsonNode tree) { + private SdkBindingData transform( + JsonNode tree, DeserializationContext deserializationContext) { Literal.Kind literalKind = Literal.Kind.valueOf(tree.get(LITERAL).asText()); switch (literalKind) { case SCALAR: - return transformScalar(tree); + return transformScalar(tree, deserializationContext); case COLLECTION: - return transformCollection(tree); + return transformCollection(tree, deserializationContext); case MAP: - return transformMap(tree); + return transformMap(tree, deserializationContext); default: throw new UnsupportedOperationException( @@ -83,7 +99,8 @@ private SdkBindingData transform(JsonNode tree) { } } - private static SdkBindingData transformScalar(JsonNode tree) { + private SdkBindingData transformScalar( + JsonNode tree, DeserializationContext deserializationContext) { Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText()); switch (scalarKind) { case PRIMITIVE: @@ -109,6 +126,8 @@ private static SdkBindingData transformScalar(JsonNode tree) { return transformBlob(tree); case GENERIC: + return transformGeneric(tree, deserializationContext, scalarKind); + default: throw new UnsupportedOperationException( "Type contains an unsupported scalar: " + scalarKind); @@ -132,8 +151,28 @@ private static SdkBindingData transformBlob(JsonNode tree) { .build()); } + private SdkBindingData transformGeneric( + JsonNode tree, DeserializationContext deserializationContext, Kind scalarKind) { + JsonParser jsonParser = tree.get(VALUE).traverse(); + try { + jsonParser.nextToken(); + Object object = + deserializationContext + .findNonContextualValueDeserializer(type) + .deserialize(jsonParser, deserializationContext); + @SuppressWarnings("unchecked") + SdkLiteralType jacksonSdkLiteralType = + (SdkLiteralType) JacksonSdkLiteralType.of(type.getRawClass()); + return SdkBindingData.literal(jacksonSdkLiteralType, object); + } catch (IOException e) { + throw new UnsupportedOperationException( + "Type contains an unsupported generic: " + scalarKind, e); + } + } + @SuppressWarnings("unchecked") - private SdkBindingData> transformCollection(JsonNode tree) { + private SdkBindingData> transformCollection( + JsonNode tree, DeserializationContext deserializationContext) { SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); Iterator elements = tree.get(VALUE).elements(); @@ -144,7 +183,10 @@ private SdkBindingData> transformCollection(JsonNode tree) { case BLOB_TYPE: List collection = (List) - streamOf(elements).map(this::transform).map(SdkBindingData::get).collect(toList()); + streamOf(elements) + .map((JsonNode tree1) -> transform(tree1, deserializationContext)) + .map(SdkBindingData::get) + .collect(toList()); return SdkBindingDataFactory.of(literalType, collection); case SCHEMA_TYPE: @@ -155,7 +197,8 @@ private SdkBindingData> transformCollection(JsonNode tree) { } @SuppressWarnings("unchecked") - private SdkBindingData> transformMap(JsonNode tree) { + private SdkBindingData> transformMap( + JsonNode tree, DeserializationContext deserializationContext) { SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); JsonNode valueNode = tree.get(VALUE); List> entries = @@ -169,7 +212,11 @@ private SdkBindingData> transformMap(JsonNode tree) { case BLOB_TYPE: Map bindingDataMap = entries.stream() - .map(entry -> Map.entry(entry.getKey(), (T) transform(entry.getValue()).get())) + .map( + entry -> + Map.entry( + entry.getKey(), + (T) transform(entry.getValue(), deserializationContext).get())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); return SdkBindingDataFactory.of(literalType, bindingDataMap); @@ -225,4 +272,9 @@ private Stream streamOf(Iterator nodes) { return StreamSupport.stream( Spliterators.spliteratorUnknownSize(nodes, Spliterator.ORDERED), false); } + + @Override + public JsonDeserializer createContextual(DeserializationContext ctxt, BeanProperty property) { + return new SdkBindingDataDeserializer(property.getType().containedType(0)); + } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java similarity index 70% rename from flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java rename to flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java index 0c17f55d5..88f673f80 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java @@ -29,23 +29,35 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import org.flyte.api.v1.Literal; -import org.flyte.api.v1.Scalar; import org.flyte.api.v1.Struct; import org.flyte.api.v1.Struct.Value; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; -public class LiteralStructDeserializer extends StdDeserializer { +public class StructDeserializer extends StdDeserializer { private static final long serialVersionUID = -6835948754469626304L; - public LiteralStructDeserializer() { - super(Literal.class); + // we cannot use Struct directly because it is an auto-value class so this deserializer will not + // be used by Jackson + public static class StructWrapper { + + private final Struct struct; + + public StructWrapper(Struct struct) { + this.struct = struct; + } + + public Struct unwrap() { + return struct; + } } - @Override - public Literal deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + public StructDeserializer() { + super(StructWrapper.class); + } - Struct generic = readValueAsStruct(p); - return Literal.ofScalar(Scalar.ofGeneric(generic)); + @Override + public StructWrapper deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + return new StructWrapper(readValueAsStruct(p)); } private static Struct readValueAsStruct(JsonParser p) throws IOException { @@ -67,7 +79,7 @@ private static Struct readValueAsStruct(JsonParser p) throws IOException { return Struct.of(unmodifiableMap(fields)); } - private static Struct.Value readValueAsStructValue(JsonParser p) throws IOException { + private static Value readValueAsStructValue(JsonParser p) throws IOException { switch (p.currentToken()) { case START_ARRAY: p.nextToken(); @@ -75,38 +87,38 @@ private static Struct.Value readValueAsStructValue(JsonParser p) throws IOExcept List valuesList = new ArrayList<>(); while (p.currentToken() != JsonToken.END_ARRAY) { - Struct.Value value = readValueAsStructValue(p); + Value value = readValueAsStructValue(p); p.nextToken(); valuesList.add(value); } - return Struct.Value.ofListValue(unmodifiableList(valuesList)); + return Value.ofListValue(unmodifiableList(valuesList)); case START_OBJECT: Struct struct = readValueAsStruct(p); - return Struct.Value.ofStructValue(struct); + return Value.ofStructValue(struct); case VALUE_STRING: String stringValue = p.readValueAs(String.class); - return Struct.Value.ofStringValue(stringValue); + return Value.ofStringValue(stringValue); case VALUE_NUMBER_FLOAT: case VALUE_NUMBER_INT: Double doubleValue = p.readValueAs(Double.class); - return Struct.Value.ofNumberValue(doubleValue); + return Value.ofNumberValue(doubleValue); case VALUE_NULL: - return Struct.Value.ofNullValue(); + return Value.ofNullValue(); case VALUE_FALSE: - return Struct.Value.ofBoolValue(false); + return Value.ofBoolValue(false); case VALUE_TRUE: - return Struct.Value.ofBoolValue(true); + return Value.ofBoolValue(true); case FIELD_NAME: case NOT_AVAILABLE: diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java index 12ec69e18..5c73535c7 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java @@ -16,20 +16,15 @@ */ package org.flyte.flytekit.jackson.serializers; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.LITERAL; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.STRUCT_TYPE; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.STRUCT_VALUE; +import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.SerializerProvider; import java.io.IOException; -import java.util.Map; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; -import org.flyte.api.v1.Struct; public class GenericSerializer extends ScalarSerializer { public GenericSerializer( @@ -48,85 +43,7 @@ public GenericSerializer( @Override public void serializeScalar() throws IOException { gen.writeObject(Scalar.Kind.GENERIC); - for (Map.Entry entry : value.scalar().generic().fields().entrySet()) { - gen.writeFieldName(entry.getKey()); - serializeStructValue(entry.getValue()); - } - } - - private void serializeStructValue(Struct.Value value) throws IOException { - if (!value.kind().equals(Struct.Value.Kind.LIST_VALUE) - && !value.kind().equals(Struct.Value.Kind.NULL_VALUE)) { - gen.writeStartObject(); - gen.writeFieldName(LITERAL); - gen.writeObject(Literal.Kind.SCALAR); - gen.writeFieldName(SCALAR); - gen.writeObject(Scalar.Kind.GENERIC); - } - - if (isSimpleType(value.kind())) { - gen.writeFieldName(STRUCT_TYPE); - } - switch (value.kind()) { - case BOOL_VALUE: - writeSimpleType( - Struct.Value.Kind.BOOL_VALUE, - value, - (generator, v) -> generator.writeBoolean(v.boolValue())); - return; - - case LIST_VALUE: - throw new RuntimeException("not supported list inside the struct"); - - case NUMBER_VALUE: - writeSimpleType( - Struct.Value.Kind.NUMBER_VALUE, - value, - (generator, v) -> generator.writeNumber(v.numberValue())); - return; - - case STRING_VALUE: - writeSimpleType( - Struct.Value.Kind.STRING_VALUE, - value, - (generator, v) -> generator.writeString(v.stringValue())); - return; - - case STRUCT_VALUE: - value.structValue().fields().forEach((k, v) -> writeStructValue(gen, k, v)); - gen.writeEndObject(); - return; - - case NULL_VALUE: - gen.writeNull(); - } - } - - private void writeStructValue(JsonGenerator gen, String k, Struct.Value v) { - try { - gen.writeFieldName(k); - serializeStructValue(v); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private boolean isSimpleType(Struct.Value.Kind kind) { - return kind.equals(Struct.Value.Kind.BOOL_VALUE) - || kind.equals(Struct.Value.Kind.NUMBER_VALUE) - || kind.equals(Struct.Value.Kind.STRING_VALUE); - } - - private void writeSimpleType( - Struct.Value.Kind kind, Struct.Value structValue, WriteGenericFunction writeTypeFunction) - throws IOException { - gen.writeObject(kind); - gen.writeFieldName(STRUCT_VALUE); - writeTypeFunction.write(gen, structValue); - gen.writeEndObject(); - } - - interface WriteGenericFunction { - void write(JsonGenerator gen, Struct.Value value) throws IOException; + gen.writeFieldName(VALUE); + new StructSerializer().serialize(value.scalar().generic(), gen, serializerProvider); } } diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index 6b87b748d..66206b26d 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -46,6 +46,8 @@ import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; +import org.flyte.api.v1.Struct; +import org.flyte.api.v1.Struct.Value; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; @@ -72,6 +74,7 @@ public static AutoValueInput createAutoValueInput( Instant t, Duration d, Blob blob, + Nested generic, List l, List lb, Map m, @@ -88,6 +91,7 @@ public static AutoValueInput createAutoValueInput( SdkBindingDataFactory.of(t), SdkBindingDataFactory.of(d), SdkBindingDataFactory.of(blob), + SdkBindingDataFactory.of(JacksonSdkLiteralType.of(Nested.class), generic), SdkBindingDataFactory.ofStringCollection(l), SdkBindingDataFactory.of(SdkLiteralTypes.blobs(BLOB_TYPE), lb), SdkBindingDataFactory.ofStringMap(m), @@ -111,6 +115,7 @@ public void testVariableMap() { hasEntry("t", createVar(SimpleType.DATETIME)), hasEntry("d", createVar(SimpleType.DURATION)), hasEntry("blob", createVar(LiteralType.ofBlobType(BLOB_TYPE))), + hasEntry("generic", createVar(LiteralType.ofSimpleType(SimpleType.STRUCT))), hasEntry( "l", createVar(LiteralType.ofCollectionType(ofSimpleType(SimpleType.STRING)))), hasEntry( @@ -139,6 +144,7 @@ void testFromLiteralMap() { literalMap.put("t", literalOf(Primitive.ofDatetime(datetime))); literalMap.put("d", literalOf(Primitive.ofDuration(duration))); literalMap.put("blob", literalOf(BLOB)); + literalMap.put("generic", literalOf(Struct.of(Map.of("hello", Value.ofStringValue("hello"))))); literalMap.put("l", Literal.ofCollection(List.of(literalOf(Primitive.ofStringValue("123"))))); literalMap.put("lb", Literal.ofCollection(List.of(literalOf(BLOB)))); literalMap.put("m", Literal.ofMap(Map.of("marco", literalOf(Primitive.ofStringValue("polo"))))); @@ -185,6 +191,7 @@ void testFromLiteralMap() { /* t= */ datetime, /* d= */ duration, /* blob= */ BLOB, + Nested.create("hello"), /* l= */ List.of("123"), /* lb= */ List.of(BLOB), /* m= */ Map.of("marco", "polo"), @@ -216,6 +223,7 @@ void testToLiteralMap() { /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), /* blob= */ BLOB, + Nested.create("hello"), /* l= */ List.of("foo"), /* lb= */ List.of(BLOB), /* m= */ Map.of("marco", "polo"), @@ -294,6 +302,7 @@ public void testToSdkBindingDataMap() { /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), /* blob= */ BLOB, + Nested.create("hello"), /* l= */ List.of("foo"), /* lb= */ List.of(BLOB), /* m= */ Map.of("marco", "polo"), @@ -315,6 +324,7 @@ public void testToSdkBindingDataMap() { expected.put("t", input.t()); expected.put("d", input.d()); expected.put("blob", input.blob()); + expected.put("generic", input.generic()); expected.put("l", input.l()); expected.put("lb", input.lb()); expected.put("m", input.m()); @@ -532,6 +542,15 @@ public static AutoValueDeprecatedInput create(long i) { } } + @AutoValue + public abstract static class Nested { + public abstract String hello(); + + public static AutoValue_JacksonSdkTypeTest_Nested create(String hello) { + return new AutoValue_JacksonSdkTypeTest_Nested(hello); + } + } + @AutoValue public abstract static class AutoValueInput { @@ -550,6 +569,8 @@ public abstract static class AutoValueInput { public abstract SdkBindingData blob(); + public abstract SdkBindingData generic(); + public abstract SdkBindingData> l(); public abstract SdkBindingData> lb(); @@ -574,6 +595,7 @@ public static AutoValueInput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> lb, SdkBindingData> m, @@ -583,7 +605,7 @@ public static AutoValueInput create( SdkBindingData>> ml, SdkBindingData>> mm) { return new AutoValue_JacksonSdkTypeTest_AutoValueInput( - i, f, s, b, t, d, blob, l, lb, m, mb, ll, lm, ml, mm); + i, f, s, b, t, d, blob, generic, l, lb, m, mb, ll, lm, ml, mm); } } @@ -722,4 +744,8 @@ private static Literal literalOf(Primitive primitive) { private static Literal literalOf(Blob blob) { return Literal.ofScalar(Scalar.ofBlob(blob)); } + + private static Literal literalOf(Struct generic) { + return Literal.ofScalar(Scalar.ofGeneric(generic)); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java index ac183ad8d..aa7217ec8 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java @@ -124,6 +124,17 @@ public static SdkBindingData> of(SdkLiteralType elementType, List return SdkBindingData.literal(collections(elementType), collection); } + /** + * Creates a {@code SdkBindingData} for a flyte type with the given value. + * + * @param type the flyte type + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(SdkLiteralType type, T value) { + return SdkBindingData.literal(type, value); + } + /** * Creates a {@code SdkBindingData} for a flyte Blob with the given value. * diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index 1616a78a4..a3cbaf863 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -38,7 +38,8 @@ import org.flyte.flytekit.{ import org.flyte.flytekitscala.SdkBindingDataFactory import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} import org.junit.jupiter.api.Test -import org.flyte.examples.AllInputsTask.AutoAllInputsInput +import org.flyte.examples.AllInputsTask.{AutoAllInputsInput, Nested} +import org.flyte.flytekit.jackson.JacksonSdkLiteralType import org.flyte.flytekitscala.SdkLiteralTypes.{collections, maps, strings} class SdkScalaTypeTest { @@ -409,6 +410,10 @@ class SdkScalaTypeTest { SdkJavaBindingDataFactory.of(Instant.parse("2023-01-01T00:00:00Z")), SdkJavaBindingDataFactory.of(Duration.ZERO), SdkJavaBindingDataFactory.of(blob), + SdkJavaBindingDataFactory.of( + JacksonSdkLiteralType.of(classOf[Nested]), + Nested.create("hello", "world") + ), SdkJavaBindingDataFactory.ofStringCollection(List("1", "2", "3").asJava), SdkJavaBindingDataFactory.ofStringMap(Map("a" -> "2", "b" -> "3").asJava), SdkJavaBindingDataFactory.ofStringCollection(List.empty[String].asJava), @@ -425,6 +430,7 @@ class SdkScalaTypeTest { instant: SdkBindingData[Instant], duration: SdkBindingData[Duration], blob: SdkBindingData[Blob], + generic: SdkBindingData[Nested], list: SdkBindingData[List[String]], map: SdkBindingData[Map[String, String]], emptyList: SdkBindingData[List[String]], @@ -439,6 +445,7 @@ class SdkScalaTypeTest { input.t(), input.d(), input.blob(), + input.generic(), toScalaList(input.l()), toScalaMap(input.m()), toScalaList(input.emptyList()), @@ -453,6 +460,10 @@ class SdkScalaTypeTest { SdkBindingDataFactory.of(Instant.parse("2023-01-01T00:00:00Z")), SdkBindingDataFactory.of(Duration.ZERO), SdkBindingDataFactory.of(blob), + SdkBindingDataFactory.of( + JacksonSdkLiteralType.of(classOf[Nested]), + Nested.create("hello", "world") + ), SdkBindingDataFactory.of(List("1", "2", "3")), SdkBindingDataFactory.of(Map("a" -> "2", "b" -> "3")), SdkBindingDataFactory.ofStringCollection(List.empty[String]), diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala index ad134a296..857238ee4 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala @@ -145,6 +145,18 @@ object SdkBindingDataFactory { def of(value: Blob): SdkBindingData[Blob] = SdkBindingData.literal(SdkLiteralTypes.blobs(value.metadata.`type`), value) + /** Creates a [[SdkBindingData]] for a flyte type with the given value. + * + * @param type + * the flyte type + * @param value + * the simple value for this data + * @return + * the new [[SdkBindingData]] + */ + def of[T](`type`: SdkLiteralType[T], value: T): SdkBindingData[T] = + SdkBindingData.literal(`type`, value) + /** Creates a [[SdkBindingDataFactory]] for a flyte string collection given a * scala [[List]]. * diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BQReference.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BQReference.java index 0da75a2c3..f5fac552f 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BQReference.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BQReference.java @@ -20,13 +20,26 @@ @AutoValue public abstract class BQReference { + @AutoValue + public abstract static class Nested { + public abstract String project(); + + public abstract String dataset(); + + public abstract String tableName(); + } + public abstract String project(); public abstract String dataset(); public abstract String tableName(); + // this is only to test nested nested auto-value + public abstract Nested nested(); + public static BQReference create(String project, String dataset, String tableName) { - return new AutoValue_BQReference(project, dataset, tableName); + return new AutoValue_BQReference( + project, dataset, tableName, new AutoValue_BQReference_Nested(project, dataset, tableName)); } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java index 0e78c9774..23fee27ad 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/BuildBqReference.java @@ -16,12 +16,15 @@ */ package org.flyte.integrationtests.structs; +import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; import org.flyte.flytekit.jackson.JacksonSdkType; -// @AutoService(SdkRunnableTask.class) +@AutoService(SdkRunnableTask.class) public class BuildBqReference extends SdkRunnableTask { private static final long serialVersionUID = -489898361071672070L; @@ -35,7 +38,10 @@ public BuildBqReference() { @Override public Output run(Input input) { return Output.create( - BQReference.create(input.project().get(), input.dataset().get(), input.tableName().get())); + SdkBindingDataFactory.of( + JacksonSdkLiteralType.of(BQReference.class), + BQReference.create( + input.project().get(), input.dataset().get(), input.tableName().get()))); } @AutoValue @@ -58,11 +64,8 @@ public static Input create( public abstract static class Output { abstract SdkBindingData ref(); - public static Output create(BQReference ref) { - // TODO We need a way to generate SdkBindings of generic autovalues like BQReference - // that would be mapped to sdkStructs. JacksonSdkType of nested autovalues are mapped as - // structs - return null; + public static Output create(SdkBindingData ref) { + return new AutoValue_BuildBqReference_Output(ref); } } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java index 9e82df7ca..3b3fb8a8b 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockLookupBqTask.java @@ -16,13 +16,15 @@ */ package org.flyte.integrationtests.structs; +import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkRunnableTask; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; import org.flyte.flytekit.jackson.JacksonSdkType; -// @AutoService(SdkRunnableTask.class) +@AutoService(SdkRunnableTask.class) public class MockLookupBqTask extends SdkRunnableTask { private static final long serialVersionUID = 604843235716487166L; @@ -39,21 +41,25 @@ public abstract static class Input { public static Input create( SdkBindingData ref, SdkBindingData checkIfExists) { - return null; // TODO + return new AutoValue_MockLookupBqTask_Input(ref, checkIfExists); } } @AutoValue public abstract static class Output { + public abstract SdkBindingData ref(); + public abstract SdkBindingData exists(); - public static Output create(boolean exists) { - return new AutoValue_MockLookupBqTask_Output(SdkBindingDataFactory.of(exists)); + public static Output create(BQReference ref, boolean exists) { + return new AutoValue_MockLookupBqTask_Output( + SdkBindingDataFactory.of(JacksonSdkLiteralType.of(BQReference.class), ref), + SdkBindingDataFactory.of(exists)); } } @Override public Output run(Input input) { - return Output.create(input.ref().get().tableName().contains("table-exists")); + return Output.create(input.ref().get(), input.ref().get().tableName().contains("table-exists")); } } diff --git a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java index d1b565e77..5a5c6ccca 100644 --- a/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java +++ b/integration-tests/src/main/java/org/flyte/integrationtests/structs/MockPipelineWorkflow.java @@ -16,6 +16,7 @@ */ package org.flyte.integrationtests.structs; +import com.google.auto.service.AutoService; import com.google.auto.value.AutoValue; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; @@ -23,10 +24,7 @@ import org.flyte.flytekit.SdkWorkflowBuilder; import org.flyte.flytekit.jackson.JacksonSdkType; -// This workflow relays on SdkBinding that should be serialized -// as Struct. By going to typed inputs and outputs, we have de-scoped the support -// of structs. -// @AutoService(SdkWorkflow.class) +@AutoService(SdkWorkflow.class) public class MockPipelineWorkflow extends SdkWorkflow { public MockPipelineWorkflow() { diff --git a/integration-tests/src/test/java/org/flyte/AdditionalIT.java b/integration-tests/src/test/java/org/flyte/AdditionalIT.java index 3c9914312..00e50c27a 100644 --- a/integration-tests/src/test/java/org/flyte/AdditionalIT.java +++ b/integration-tests/src/test/java/org/flyte/AdditionalIT.java @@ -23,7 +23,6 @@ import flyteidl.core.Literals; import org.flyte.utils.Literal; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; @@ -67,7 +66,6 @@ public void testBranchNodeWorkflow(long a, long b, long c, long d, String expect "table-exists,true", "non-existent,false", }) - @Disabled("Not supporting struct with the strongly typed implementation.") public void testStructs(String name, boolean expected) { Literals.LiteralMap output = CLIENT.createExecution(