From 1a85668be1a21aff0aaaa64b705a97832169905f Mon Sep 17 00:00:00 2001 From: Hongxin Liang Date: Sun, 8 Oct 2023 18:12:27 +0200 Subject: [PATCH] Support struct Signed-off-by: Hongxin Liang --- .../jackson/JacksonSdkLiteralType.java | 4 +- .../jackson/SdkLiteralTypeModule.java | 6 +- .../flyte/flytekit/jackson/SdkTypeModule.java | 2 - .../flytekit/jackson/VariableMapVisitor.java | 9 +- .../LiteralMapDeserializers.java | 41 --------- .../SdkBindingDataDeserializer.java | 74 ++++++++++++--- ...erializer.java => StructDeserializer.java} | 48 ++++++---- .../serializers/GenericSerializer.java | 89 +------------------ .../flytekit/jackson/JacksonSdkTypeTest.java | 28 +++++- .../flyte/flytekit/SdkBindingDataFactory.java | 6 ++ 10 files changed, 141 insertions(+), 166 deletions(-) delete mode 100644 flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java rename flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/{LiteralStructDeserializer.java => StructDeserializer.java} (70%) diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java index 0be5ba34f..969fa64bd 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/JacksonSdkLiteralType.java @@ -33,6 +33,7 @@ import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkLiteralType; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; /** * Implementation of {@link org.flyte.flytekit.SdkLiteralType} for {@link @@ -102,7 +103,8 @@ public Literal toLiteral(T value) { var tree = OBJECT_MAPPER.valueToTree(value); try { - return OBJECT_MAPPER.treeToValue(tree, Literal.class); + return Literal.ofScalar( + Scalar.ofGeneric(OBJECT_MAPPER.treeToValue(tree, StructWrapper.class).unwrap())); } catch (IOException e) { throw new UncheckedIOException("toLiteral failed for [" + clazz.getName() + "]: " + value, e); } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java index 861a1c640..4ec2d158d 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkLiteralTypeModule.java @@ -20,8 +20,8 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.module.SimpleDeserializers; import com.fasterxml.jackson.databind.module.SimpleSerializers; -import org.flyte.api.v1.Literal; -import org.flyte.flytekit.jackson.deserializers.LiteralStructDeserializer; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; import org.flyte.flytekit.jackson.serializers.StructSerializer; class SdkLiteralTypeModule extends Module { @@ -43,7 +43,7 @@ public void setupModule(SetupContext context) { context.addSerializers(serializers); var deserializers = new SimpleDeserializers(); - deserializers.addDeserializer(Literal.class, new LiteralStructDeserializer()); + deserializers.addDeserializer(StructWrapper.class, new StructDeserializer()); context.addDeserializers(deserializers); // append with the lowest priority to use as fallback, if builtin annotations aren't present diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java index 17f71c25a..aa25ff45e 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/SdkTypeModule.java @@ -20,7 +20,6 @@ import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.deser.Deserializers; import com.fasterxml.jackson.databind.module.SimpleSerializers; -import org.flyte.flytekit.jackson.deserializers.LiteralMapDeserializers; import org.flyte.flytekit.jackson.deserializers.SdkBindingDataDeserializers; import org.flyte.flytekit.jackson.serializers.BindingMapSerializers; import org.flyte.flytekit.jackson.serializers.LiteralMapSerializers; @@ -60,7 +59,6 @@ public void setupModule(SetupContext context) { context.addSerializers(serializers); context.addSerializers(new LiteralMapSerializers()); - context.addDeserializers(new LiteralMapDeserializers()); context.addSerializers(new BindingMapSerializers()); context.addDeserializers(sdkbindingDeserializers); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index c100b81c3..b97366859 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -181,9 +181,12 @@ private SdkLiteralType toLiteralType( .dimensionality(annotation.dimensionality()) .build()); } - // TODO: Support structs - throw new UnsupportedOperationException( - String.format("Unsupported type: [%s]", type.getName())); + try { + return JacksonSdkLiteralType.of(type); + } catch (Exception e) { + throw new UnsupportedOperationException( + String.format("Unsupported type: [%s]", type.getName()), e); + } } private static boolean isPrimitiveAssignableFrom(Class fromClass, Class toClass) { diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java deleted file mode 100644 index f3015c3da..000000000 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializers.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2020-2023 Flyte Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.flyte.flytekit.jackson.deserializers; - -import com.fasterxml.jackson.databind.BeanDescription; -import com.fasterxml.jackson.databind.DeserializationConfig; -import com.fasterxml.jackson.databind.JavaType; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.deser.Deserializers; -import java.util.Map; -import org.flyte.api.v1.LiteralType; -import org.flyte.flytekit.jackson.JacksonLiteralMap; - -public class LiteralMapDeserializers extends Deserializers.Base { - - @Override - public JsonDeserializer findBeanDeserializer( - JavaType type, DeserializationConfig config, BeanDescription beanDesc) { - if (type.getRawClass().equals(JacksonLiteralMap.class)) { - Map literalTypeMap = type.getValueHandler(); - - return new LiteralMapDeserializer(literalTypeMap); - } - - return null; - } -} diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java index 39f7e1034..34311eb8e 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java @@ -24,8 +24,12 @@ import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.BeanProperty; import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.deser.ContextualDeserializer; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; import java.io.IOException; import java.time.Duration; @@ -46,36 +50,47 @@ import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; +import org.flyte.api.v1.Scalar.Kind; import org.flyte.api.v1.SimpleType; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; import org.flyte.flytekit.SdkLiteralType; import org.flyte.flytekit.SdkLiteralTypes; +import org.flyte.flytekit.jackson.JacksonSdkLiteralType; -class SdkBindingDataDeserializer extends StdDeserializer> { +class SdkBindingDataDeserializer extends StdDeserializer> + implements ContextualDeserializer { private static final long serialVersionUID = 0L; + private JavaType type; + public SdkBindingDataDeserializer() { + this(null); + } + + private SdkBindingDataDeserializer(JavaType type) { super(SdkBindingData.class); + this.type = type; } @Override public SdkBindingData deserialize( JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { JsonNode tree = jsonParser.readValueAsTree(); - return transform(tree); + return transform(tree, deserializationContext); } - private SdkBindingData transform(JsonNode tree) { + private SdkBindingData transform( + JsonNode tree, DeserializationContext deserializationContext) { Literal.Kind literalKind = Literal.Kind.valueOf(tree.get(LITERAL).asText()); switch (literalKind) { case SCALAR: - return transformScalar(tree); + return transformScalar(tree, deserializationContext); case COLLECTION: - return transformCollection(tree); + return transformCollection(tree, deserializationContext); case MAP: - return transformMap(tree); + return transformMap(tree, deserializationContext); default: throw new UnsupportedOperationException( @@ -83,7 +98,8 @@ private SdkBindingData transform(JsonNode tree) { } } - private static SdkBindingData transformScalar(JsonNode tree) { + private SdkBindingData transformScalar( + JsonNode tree, DeserializationContext deserializationContext) { Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText()); switch (scalarKind) { case PRIMITIVE: @@ -109,6 +125,8 @@ private static SdkBindingData transformScalar(JsonNode tree) { return transformBlob(tree); case GENERIC: + return transformGeneric(tree, deserializationContext, scalarKind); + default: throw new UnsupportedOperationException( "Type contains an unsupported scalar: " + scalarKind); @@ -132,8 +150,28 @@ private static SdkBindingData transformBlob(JsonNode tree) { .build()); } + private SdkBindingData transformGeneric( + JsonNode tree, DeserializationContext deserializationContext, Kind scalarKind) { + JsonParser jsonParser = tree.get(VALUE).traverse(); + try { + jsonParser.nextToken(); + Object object = + deserializationContext + .findNonContextualValueDeserializer(type) + .deserialize(jsonParser, deserializationContext); + @SuppressWarnings("unchecked") + SdkLiteralType jacksonSdkLiteralType = + (SdkLiteralType) JacksonSdkLiteralType.of(type.getRawClass()); + return SdkBindingData.literal(jacksonSdkLiteralType, object); + } catch (IOException e) { + throw new UnsupportedOperationException( + "Type contains an unsupported generic: " + scalarKind, e); + } + } + @SuppressWarnings("unchecked") - private SdkBindingData> transformCollection(JsonNode tree) { + private SdkBindingData> transformCollection( + JsonNode tree, DeserializationContext deserializationContext) { SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); Iterator elements = tree.get(VALUE).elements(); @@ -143,7 +181,10 @@ private SdkBindingData> transformCollection(JsonNode tree) { case COLLECTION_TYPE: List collection = (List) - streamOf(elements).map(this::transform).map(SdkBindingData::get).collect(toList()); + streamOf(elements) + .map((JsonNode tree1) -> transform(tree1, deserializationContext)) + .map(SdkBindingData::get) + .collect(toList()); return SdkBindingDataFactory.of(literalType, collection); case SCHEMA_TYPE: @@ -155,7 +196,8 @@ private SdkBindingData> transformCollection(JsonNode tree) { } @SuppressWarnings("unchecked") - private SdkBindingData> transformMap(JsonNode tree) { + private SdkBindingData> transformMap( + JsonNode tree, DeserializationContext deserializationContext) { SdkLiteralType literalType = (SdkLiteralType) readLiteralType(tree.get(TYPE)); JsonNode valueNode = tree.get(VALUE); List> entries = @@ -168,7 +210,11 @@ private SdkBindingData> transformMap(JsonNode tree) { case COLLECTION_TYPE: Map bindingDataMap = entries.stream() - .map(entry -> Map.entry(entry.getKey(), (T) transform(entry.getValue()).get())) + .map( + entry -> + Map.entry( + entry.getKey(), + (T) transform(entry.getValue(), deserializationContext).get())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); return SdkBindingDataFactory.of(literalType, bindingDataMap); @@ -220,4 +266,10 @@ private Stream streamOf(Iterator nodes) { return StreamSupport.stream( Spliterators.spliteratorUnknownSize(nodes, Spliterator.ORDERED), false); } + + @Override + public JsonDeserializer createContextual(DeserializationContext ctxt, BeanProperty property) { + type = property.getType().containedType(0); + return new SdkBindingDataDeserializer(type); + } } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java similarity index 70% rename from flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java rename to flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java index 0c17f55d5..88f673f80 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralStructDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/StructDeserializer.java @@ -29,23 +29,35 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import org.flyte.api.v1.Literal; -import org.flyte.api.v1.Scalar; import org.flyte.api.v1.Struct; import org.flyte.api.v1.Struct.Value; +import org.flyte.flytekit.jackson.deserializers.StructDeserializer.StructWrapper; -public class LiteralStructDeserializer extends StdDeserializer { +public class StructDeserializer extends StdDeserializer { private static final long serialVersionUID = -6835948754469626304L; - public LiteralStructDeserializer() { - super(Literal.class); + // we cannot use Struct directly because it is an auto-value class so this deserializer will not + // be used by Jackson + public static class StructWrapper { + + private final Struct struct; + + public StructWrapper(Struct struct) { + this.struct = struct; + } + + public Struct unwrap() { + return struct; + } } - @Override - public Literal deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + public StructDeserializer() { + super(StructWrapper.class); + } - Struct generic = readValueAsStruct(p); - return Literal.ofScalar(Scalar.ofGeneric(generic)); + @Override + public StructWrapper deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + return new StructWrapper(readValueAsStruct(p)); } private static Struct readValueAsStruct(JsonParser p) throws IOException { @@ -67,7 +79,7 @@ private static Struct readValueAsStruct(JsonParser p) throws IOException { return Struct.of(unmodifiableMap(fields)); } - private static Struct.Value readValueAsStructValue(JsonParser p) throws IOException { + private static Value readValueAsStructValue(JsonParser p) throws IOException { switch (p.currentToken()) { case START_ARRAY: p.nextToken(); @@ -75,38 +87,38 @@ private static Struct.Value readValueAsStructValue(JsonParser p) throws IOExcept List valuesList = new ArrayList<>(); while (p.currentToken() != JsonToken.END_ARRAY) { - Struct.Value value = readValueAsStructValue(p); + Value value = readValueAsStructValue(p); p.nextToken(); valuesList.add(value); } - return Struct.Value.ofListValue(unmodifiableList(valuesList)); + return Value.ofListValue(unmodifiableList(valuesList)); case START_OBJECT: Struct struct = readValueAsStruct(p); - return Struct.Value.ofStructValue(struct); + return Value.ofStructValue(struct); case VALUE_STRING: String stringValue = p.readValueAs(String.class); - return Struct.Value.ofStringValue(stringValue); + return Value.ofStringValue(stringValue); case VALUE_NUMBER_FLOAT: case VALUE_NUMBER_INT: Double doubleValue = p.readValueAs(Double.class); - return Struct.Value.ofNumberValue(doubleValue); + return Value.ofNumberValue(doubleValue); case VALUE_NULL: - return Struct.Value.ofNullValue(); + return Value.ofNullValue(); case VALUE_FALSE: - return Struct.Value.ofBoolValue(false); + return Value.ofBoolValue(false); case VALUE_TRUE: - return Struct.Value.ofBoolValue(true); + return Value.ofBoolValue(true); case FIELD_NAME: case NOT_AVAILABLE: diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java index 12ec69e18..5c73535c7 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/GenericSerializer.java @@ -16,20 +16,15 @@ */ package org.flyte.flytekit.jackson.serializers; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.LITERAL; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.STRUCT_TYPE; -import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.STRUCT_VALUE; +import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.SerializerProvider; import java.io.IOException; -import java.util.Map; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; -import org.flyte.api.v1.Struct; public class GenericSerializer extends ScalarSerializer { public GenericSerializer( @@ -48,85 +43,7 @@ public GenericSerializer( @Override public void serializeScalar() throws IOException { gen.writeObject(Scalar.Kind.GENERIC); - for (Map.Entry entry : value.scalar().generic().fields().entrySet()) { - gen.writeFieldName(entry.getKey()); - serializeStructValue(entry.getValue()); - } - } - - private void serializeStructValue(Struct.Value value) throws IOException { - if (!value.kind().equals(Struct.Value.Kind.LIST_VALUE) - && !value.kind().equals(Struct.Value.Kind.NULL_VALUE)) { - gen.writeStartObject(); - gen.writeFieldName(LITERAL); - gen.writeObject(Literal.Kind.SCALAR); - gen.writeFieldName(SCALAR); - gen.writeObject(Scalar.Kind.GENERIC); - } - - if (isSimpleType(value.kind())) { - gen.writeFieldName(STRUCT_TYPE); - } - switch (value.kind()) { - case BOOL_VALUE: - writeSimpleType( - Struct.Value.Kind.BOOL_VALUE, - value, - (generator, v) -> generator.writeBoolean(v.boolValue())); - return; - - case LIST_VALUE: - throw new RuntimeException("not supported list inside the struct"); - - case NUMBER_VALUE: - writeSimpleType( - Struct.Value.Kind.NUMBER_VALUE, - value, - (generator, v) -> generator.writeNumber(v.numberValue())); - return; - - case STRING_VALUE: - writeSimpleType( - Struct.Value.Kind.STRING_VALUE, - value, - (generator, v) -> generator.writeString(v.stringValue())); - return; - - case STRUCT_VALUE: - value.structValue().fields().forEach((k, v) -> writeStructValue(gen, k, v)); - gen.writeEndObject(); - return; - - case NULL_VALUE: - gen.writeNull(); - } - } - - private void writeStructValue(JsonGenerator gen, String k, Struct.Value v) { - try { - gen.writeFieldName(k); - serializeStructValue(v); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - private boolean isSimpleType(Struct.Value.Kind kind) { - return kind.equals(Struct.Value.Kind.BOOL_VALUE) - || kind.equals(Struct.Value.Kind.NUMBER_VALUE) - || kind.equals(Struct.Value.Kind.STRING_VALUE); - } - - private void writeSimpleType( - Struct.Value.Kind kind, Struct.Value structValue, WriteGenericFunction writeTypeFunction) - throws IOException { - gen.writeObject(kind); - gen.writeFieldName(STRUCT_VALUE); - writeTypeFunction.write(gen, structValue); - gen.writeEndObject(); - } - - interface WriteGenericFunction { - void write(JsonGenerator gen, Struct.Value value) throws IOException; + gen.writeFieldName(VALUE); + new StructSerializer().serialize(value.scalar().generic(), gen, serializerProvider); } } diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index 38fd83639..91eace169 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -47,6 +47,8 @@ import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.SimpleType; +import org.flyte.api.v1.Struct; +import org.flyte.api.v1.Struct.Value; import org.flyte.api.v1.Variable; import org.flyte.flytekit.SdkBindingData; import org.flyte.flytekit.SdkBindingDataFactory; @@ -68,6 +70,7 @@ public static AutoValueInput createAutoValueInput( Instant t, Duration d, Blob blob, + Nested generic, List l, Map m, List> ll, @@ -82,6 +85,7 @@ public static AutoValueInput createAutoValueInput( SdkBindingDataFactory.of(t), SdkBindingDataFactory.of(d), SdkBindingDataFactory.of(blob), + SdkBindingDataFactory.of(JacksonSdkLiteralType.of(Nested.class), generic), SdkBindingDataFactory.ofStringCollection(l), SdkBindingDataFactory.ofStringMap(m), SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ll), @@ -103,6 +107,7 @@ public void testVariableMap() { hasEntry("t", createVar(SimpleType.DATETIME)), hasEntry("d", createVar(SimpleType.DURATION)), hasEntry("blob", createVar(LiteralType.ofBlobType(BLOB_TYPE))), + hasEntry("generic", createVar(LiteralType.ofSimpleType(SimpleType.STRUCT))), hasEntry( "l", createVar(LiteralType.ofCollectionType(ofSimpleType(SimpleType.STRING)))), hasEntry( @@ -136,6 +141,7 @@ void testFromLiteralMap() { literalMap.put("t", literalOf(Primitive.ofDatetime(datetime))); literalMap.put("d", literalOf(Primitive.ofDuration(duration))); literalMap.put("blob", literalOf(blob)); + literalMap.put("generic", literalOf(Struct.of(Map.of("hello", Value.ofStringValue("hello"))))); literalMap.put("l", Literal.ofCollection(List.of(literalOf(Primitive.ofStringValue("123"))))); literalMap.put("m", Literal.ofMap(Map.of("marco", literalOf(Primitive.ofStringValue("polo"))))); literalMap.put( @@ -180,6 +186,7 @@ void testFromLiteralMap() { /* t= */ datetime, /* d= */ duration, /* blob= */ blob, + Nested.create("hello"), /* l= */ List.of("123"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -214,6 +221,7 @@ void testToLiteralMap() { /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), /* blob= */ blob, + Nested.create("hello"), /* l= */ List.of("foo"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -295,6 +303,7 @@ public void testToSdkBindingDataMap() { /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), /* blob= */ blob, + Nested.create("hello"), /* l= */ List.of("foo"), /* m= */ Map.of("marco", "polo"), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), @@ -314,6 +323,7 @@ public void testToSdkBindingDataMap() { expected.put("t", input.t()); expected.put("d", input.d()); expected.put("blob", input.blob()); + expected.put("generic", input.generic()); expected.put("l", input.l()); expected.put("m", input.m()); expected.put("ll", input.ll()); @@ -529,6 +539,15 @@ public static AutoValueDeprecatedInput create(long i) { } } + @AutoValue + public abstract static class Nested { + public abstract String hello(); + + public static AutoValue_JacksonSdkTypeTest_Nested create(String hello) { + return new AutoValue_JacksonSdkTypeTest_Nested(hello); + } + } + @AutoValue public abstract static class AutoValueInput { @@ -548,6 +567,8 @@ public abstract static class AutoValueInput { @BlobTypeDescription(format = "", dimensionality = BlobDimensionality.SINGLE) public abstract SdkBindingData blob(); + public abstract SdkBindingData generic(); + public abstract SdkBindingData> l(); public abstract SdkBindingData> m(); @@ -568,6 +589,7 @@ public static AutoValueInput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData generic, SdkBindingData> l, SdkBindingData> m, SdkBindingData>> ll, @@ -575,7 +597,7 @@ public static AutoValueInput create( SdkBindingData>> ml, SdkBindingData>> mm) { return new AutoValue_JacksonSdkTypeTest_AutoValueInput( - i, f, s, b, t, d, blob, l, m, ll, lm, ml, mm); + i, f, s, b, t, d, blob, generic, l, m, ll, lm, ml, mm); } } @@ -714,4 +736,8 @@ private static Literal literalOf(Primitive primitive) { private static Literal literalOf(Blob blob) { return Literal.ofScalar(Scalar.ofBlob(blob)); } + + private static Literal literalOf(Struct generic) { + return Literal.ofScalar(Scalar.ofGeneric(generic)); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java index f8c9d8d47..c736006c8 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java @@ -124,6 +124,12 @@ public static SdkBindingData> of(SdkLiteralType elementType, List return SdkBindingData.literal(collections(elementType), collection); } + /** + * Creates a {@code SdkBindingData} for a flyte struct with the given value. + * + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ public static SdkBindingData of(SdkLiteralType type, T value) { return SdkBindingData.literal(type, value); }