From 6c2a1a9e3d7dff80923c877c5ea6dc38582afb44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20G=C3=B3mez?= Date: Thu, 18 Apr 2024 09:24:26 +0200 Subject: [PATCH] Add support to Binary input/output type (#291) Signed-off-by: Andres Gomez Ferrer Co-authored-by: Andres Gomez Ferrer --- .../main/java/org/flyte/api/v1/Binary.java | 47 +++++++++++++++++ .../main/java/org/flyte/api/v1/Scalar.java | 9 +++- .../java/org/flyte/api/v1/SimpleType.java | 3 +- .../flytekit/jackson/VariableMapVisitor.java | 3 ++ .../deserializers/LiteralMapDeserializer.java | 47 +++++++++++++++++ .../SdkBindingDataDeserializer.java | 16 ++++++ .../jackson/serializers/BinarySerializer.java | 51 +++++++++++++++++++ .../serializers/LiteralSerializerFactory.java | 2 + .../flytekit/jackson/JacksonSdkTypeTest.java | 48 ++++++++++++++++- .../flyte/flytekit/SdkBindingDataFactory.java | 11 ++++ .../org/flyte/flytekit/SdkLiteralTypes.java | 35 +++++++++++++ .../flytekitscala/SdkLiteralTypesTest.scala | 4 +- .../flytekitscala/SdkScalaTypeTest.scala | 2 - .../SdkBindingDataConverters.scala | 5 ++ .../flytekitscala/SdkBindingDataFactory.scala | 12 ++++- .../flyte/flytekitscala/SdkLiteralTypes.scala | 10 ++++ .../flyte/flytekitscala/SdkScalaType.scala | 5 ++ .../org/flyte/jflyte/utils/ProtoUtil.java | 26 +++++++++- .../org/flyte/jflyte/utils/ProtoUtilTest.java | 32 ++++++++++++ 19 files changed, 358 insertions(+), 10 deletions(-) create mode 100644 flytekit-api/src/main/java/org/flyte/api/v1/Binary.java create mode 100644 flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BinarySerializer.java diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/Binary.java b/flytekit-api/src/main/java/org/flyte/api/v1/Binary.java new file mode 100644 index 000000000..aa37d51c4 --- /dev/null +++ b/flytekit-api/src/main/java/org/flyte/api/v1/Binary.java @@ -0,0 +1,47 @@ +/* + * Copyright 2020-2021 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.api.v1; + +import com.google.auto.value.AutoValue; + +/** + * A simple byte array with a tag to help different parts of the system communicate about what is in + * the byte array. It's strongly advisable that consumers of this type define a unique tag and + * validate the tag before parsing the data. + */ +@AutoValue +public abstract class Binary { + public static final String TAG_FIELD = "tag"; + public static final String VALUE_FIELD = "value"; + + public abstract byte[] value(); + + public abstract String tag(); + + public static Builder builder() { + return new AutoValue_Binary.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder value(byte[] value); + + public abstract Builder tag(String tag); + + public abstract Binary build(); + } +} diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/Scalar.java b/flytekit-api/src/main/java/org/flyte/api/v1/Scalar.java index 99df63658..2eb3a757d 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/Scalar.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/Scalar.java @@ -25,7 +25,8 @@ public abstract class Scalar { public enum Kind { PRIMITIVE, GENERIC, - BLOB + BLOB, + BINARY } public abstract Kind kind(); @@ -36,6 +37,8 @@ public enum Kind { public abstract Blob blob(); + public abstract Binary binary(); + // TODO: add the rest of the cases from src/main/proto/flyteidl/core/literals.proto public static Scalar ofPrimitive(Primitive primitive) { @@ -49,4 +52,8 @@ public static Scalar ofGeneric(Struct generic) { public static Scalar ofBlob(Blob blob) { return AutoOneOf_Scalar.blob(blob); } + + public static Scalar ofBinary(Binary binary) { + return AutoOneOf_Scalar.binary(binary); + } } diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/SimpleType.java b/flytekit-api/src/main/java/org/flyte/api/v1/SimpleType.java index 6b0f551b9..84f8db618 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/SimpleType.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/SimpleType.java @@ -24,5 +24,6 @@ public enum SimpleType { BOOLEAN, DATETIME, DURATION, - STRUCT + STRUCT, + BINARY } diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java index 6b78841cc..535e65bfd 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/VariableMapVisitor.java @@ -30,6 +30,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Binary; import org.flyte.api.v1.Blob; import org.flyte.api.v1.BlobType; import org.flyte.api.v1.Variable; @@ -172,6 +173,8 @@ private SdkLiteralType toLiteralType( // feature // https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.BlobType.html#flytekit-blobtype return SdkLiteralTypes.blobs(BlobType.DEFAULT); + } else if (Binary.class.isAssignableFrom(type)) { + return SdkLiteralTypes.binary(); } try { return JacksonSdkLiteralType.of(type); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializer.java index 0e855c188..e4d9613ac 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/LiteralMapDeserializer.java @@ -25,6 +25,7 @@ import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.NotSerializableException; import java.io.ObjectInputStream; @@ -34,6 +35,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Binary; import org.flyte.api.v1.Blob; import org.flyte.api.v1.Literal; import org.flyte.api.v1.LiteralType; @@ -167,11 +169,56 @@ private static Literal deserialize(JsonParser p, SimpleType simpleType) throws I Struct generic = readValueAsStruct(p); return Literal.ofScalar(Scalar.ofGeneric(generic)); + + case BINARY: + Binary binary = readValueAsBinary(p); + + return Literal.ofScalar(Scalar.ofBinary(binary)); } throw new AssertionError(String.format("Unexpected SimpleType: [%s]", simpleType)); } + private static Binary readValueAsBinary(JsonParser p) throws IOException { + verifyToken(p, JsonToken.START_OBJECT); + p.nextToken(); + + Binary.Builder binaryBuilder = Binary.builder(); + + while (p.currentToken() != JsonToken.END_OBJECT) { + verifyToken(p, JsonToken.FIELD_NAME); + String fieldName = p.currentName(); + p.nextToken(); + + switch (fieldName) { + case Binary.TAG_FIELD: + binaryBuilder.tag(p.readValueAs(String.class)); + break; + case Binary.VALUE_FIELD: + ByteArrayOutputStream value = new ByteArrayOutputStream(); + p.readBinaryValue(value); + binaryBuilder.value(value.toByteArray()); + break; + default: + throw new IllegalStateException("Unexpected field [" + fieldName + "]"); + } + + p.nextToken(); + } + + Binary binary = binaryBuilder.build(); + + if (binary.tag() == null) { + throw new IllegalStateException("Missing field [" + Binary.TAG_FIELD + "]"); + } + + if (binary.value().length == 0) { + throw new IllegalStateException("Missing field [" + Binary.VALUE_FIELD + "]"); + } + + return binary; + } + private static Struct readValueAsStruct(JsonParser p) throws IOException { verifyToken(p, JsonToken.START_OBJECT); p.nextToken(); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java index b860ef053..c0ff7b306 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/deserializers/SdkBindingDataDeserializer.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.time.Duration; import java.time.Instant; +import java.util.Base64; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -42,6 +43,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import java.util.stream.StreamSupport; +import org.flyte.api.v1.Binary; import org.flyte.api.v1.Blob; import org.flyte.api.v1.BlobMetadata; import org.flyte.api.v1.BlobType; @@ -128,12 +130,24 @@ private SdkBindingData transformScalar( case GENERIC: return transformGeneric(tree, deserializationContext, scalarKind, type); + case BINARY: + return transformBinary(tree); + default: throw new UnsupportedOperationException( "Type contains an unsupported scalar: " + scalarKind); } } + private static SdkBindingData transformBinary(JsonNode tree) { + JsonNode value = tree.get(VALUE); + String tag = value.get(Binary.TAG_FIELD).asText(); + String base64Value = value.get(Binary.VALUE_FIELD).asText(); + + return SdkBindingDataFactory.of( + Binary.builder().tag(tag).value(Base64.getDecoder().decode(base64Value)).build()); + } + private static SdkBindingData transformBlob(JsonNode tree) { JsonNode value = tree.get(VALUE); String uri = value.get("uri").asText(); @@ -256,6 +270,8 @@ private SdkLiteralType readLiteralType(JsonNode typeNode) { return SdkLiteralTypes.durations(); case STRUCT: return JacksonSdkLiteralType.of(type.getContentType().getRawClass()); + case BINARY: + return SdkLiteralTypes.binary(); } throw new UnsupportedOperationException( "Type contains a collection/map of an supported literal type: " + kind); diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BinarySerializer.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BinarySerializer.java new file mode 100644 index 000000000..70ba7270d --- /dev/null +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/BinarySerializer.java @@ -0,0 +1,51 @@ +/* + * Copyright 2020-2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit.jackson.serializers; + +import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.databind.SerializerProvider; +import java.io.IOException; +import java.util.Base64; +import org.flyte.api.v1.Binary; +import org.flyte.api.v1.Literal; +import org.flyte.api.v1.LiteralType; +import org.flyte.api.v1.Scalar.Kind; + +public class BinarySerializer extends ScalarSerializer { + public BinarySerializer( + JsonGenerator gen, + String key, + Literal value, + SerializerProvider serializerProvider, + LiteralType literalType) { + super(gen, key, value, serializerProvider, literalType); + } + + @Override + void serializeScalar() throws IOException { + gen.writeObject(Kind.BINARY); + gen.writeFieldName(VALUE); + gen.writeStartObject(); + gen.writeFieldName(Binary.TAG_FIELD); + gen.writeString(value.scalar().binary().tag()); + gen.writeFieldName(Binary.VALUE_FIELD); + gen.writeString(Base64.getEncoder().encodeToString(value.scalar().binary().value())); + gen.writeEndObject(); + } +} diff --git a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/LiteralSerializerFactory.java b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/LiteralSerializerFactory.java index c5f8710b2..286cfa0a7 100644 --- a/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/LiteralSerializerFactory.java +++ b/flytekit-jackson/src/main/java/org/flyte/flytekit/jackson/serializers/LiteralSerializerFactory.java @@ -52,6 +52,8 @@ private static ScalarSerializer createScalarSerializer( return new GenericSerializer(gen, key, value, serializerProvider, literalType); case BLOB: return new BlobSerializer(gen, key, value, serializerProvider, literalType); + case BINARY: + return new BinarySerializer(gen, key, value, serializerProvider, literalType); } throw new AssertionError("Unexpected Literal.Kind: [" + value.scalar().kind() + "]"); } diff --git a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java index 4ae10036e..9578fb879 100644 --- a/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java +++ b/flytekit-jackson/src/test/java/org/flyte/flytekit/jackson/JacksonSdkTypeTest.java @@ -30,6 +30,7 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.util.StdConverter; import com.google.auto.value.AutoValue; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.util.HashMap; @@ -37,6 +38,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import org.flyte.api.v1.Binary; import org.flyte.api.v1.Blob; import org.flyte.api.v1.BlobMetadata; import org.flyte.api.v1.BlobType; @@ -64,6 +66,12 @@ public class JacksonSdkTypeTest { .uri("file://test") .build(); + private static final Binary BINARY = + Binary.builder() + .tag("this-is-a-custom-tag") + .value("file://test".getBytes(StandardCharsets.UTF_8)) + .build(); + public static AutoValueInput createAutoValueInput( long i, double f, @@ -72,12 +80,15 @@ public static AutoValueInput createAutoValueInput( Instant t, Duration d, Blob blob, + Binary binary, Nested generic, List l, List lb, + List lbinary, List lg, Map m, Map mb, + Map mbinary, Map mg, List> ll, List> lm, @@ -91,12 +102,15 @@ public static AutoValueInput createAutoValueInput( SdkBindingDataFactory.of(t), SdkBindingDataFactory.of(d), SdkBindingDataFactory.of(blob), + SdkBindingDataFactory.of(binary), SdkBindingDataFactory.of(JacksonSdkLiteralType.of(Nested.class), generic), SdkBindingDataFactory.ofStringCollection(l), SdkBindingDataFactory.of(SdkLiteralTypes.blobs(BLOB_TYPE), lb), + SdkBindingDataFactory.of(SdkLiteralTypes.binary(), lbinary), SdkBindingDataFactory.of(JacksonSdkLiteralType.of(Nested.class), lg), SdkBindingDataFactory.ofStringMap(m), SdkBindingDataFactory.of(SdkLiteralTypes.blobs(BLOB_TYPE), mb), + SdkBindingDataFactory.of(SdkLiteralTypes.binary(), mbinary), SdkBindingDataFactory.of(JacksonSdkLiteralType.of(Nested.class), mg), SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ll), SdkBindingDataFactory.of(SdkLiteralTypes.maps(SdkLiteralTypes.strings()), lm), @@ -146,6 +160,7 @@ void testFromLiteralMap() { literalMap.put("t", literalOf(Primitive.ofDatetime(datetime))); literalMap.put("d", literalOf(Primitive.ofDuration(duration))); literalMap.put("blob", literalOf(BLOB)); + literalMap.put("binary", literalOf(BINARY)); literalMap.put( "generic", literalOf( @@ -157,12 +172,14 @@ void testFromLiteralMap() { Value.ofStringValue("world"))))); literalMap.put("l", Literal.ofCollection(List.of(literalOf(Primitive.ofStringValue("123"))))); literalMap.put("lb", Literal.ofCollection(List.of(literalOf(BLOB)))); + literalMap.put("lbinary", Literal.ofCollection(List.of(literalOf(BINARY)))); literalMap.put( "lg", Literal.ofCollection( List.of(literalOf(Struct.of(Map.of("hello", Value.ofStringValue("hello"))))))); literalMap.put("m", Literal.ofMap(Map.of("marco", literalOf(Primitive.ofStringValue("polo"))))); literalMap.put("mb", Literal.ofMap(Map.of("blob", literalOf(BLOB)))); + literalMap.put("mbinary", Literal.ofMap(Map.of("binary", literalOf(BINARY)))); literalMap.put( "mg", Literal.ofMap( @@ -210,12 +227,15 @@ void testFromLiteralMap() { /* t= */ datetime, /* d= */ duration, /* blob= */ BLOB, + /* binary= */ BINARY, /* generic= */ Nested.create("hello", "world"), /* l= */ List.of("123"), /* lb= */ List.of(BLOB), + /* lbinary= */ List.of(BINARY), /* lg= */ List.of(Nested.create("hello")), /* m= */ Map.of("marco", "polo"), /* mb= */ Map.of("blob", BLOB), + /* mbinary= */ Map.of("binary", BINARY), /* mg= */ Map.of("generic", Nested.create("hello")), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), /* lm= */ List.of(Map.of("A", "a", "B", "b"), Map.of("a", "A", "b", "B")), @@ -244,12 +264,15 @@ void testToLiteralMap() { /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), /* blob= */ BLOB, + /* binary= */ BINARY, /* generic= */ Nested.create("hello"), /* l= */ List.of("foo"), /* lb= */ List.of(BLOB), + /* lbinary= */ List.of(BINARY), /* lg= */ List.of(Nested.create("hello")), /* m= */ Map.of("marco", "polo"), /* mb= */ Map.of("blob", BLOB), + /* mbinary= */ Map.of("binary", BINARY), /* mg= */ Map.of("generic", Nested.create("hello")), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), /* lm= */ List.of(Map.of("A", "a", "B", "b"), Map.of("a", "A", "b", "B")), @@ -311,7 +334,8 @@ void testToLiteralMap() { "pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))), "pokemon", Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))), - hasEntry("blob", literalOf(BLOB))))); + hasEntry("blob", literalOf(BLOB)), + hasEntry("binary", literalOf(BINARY))))); } @Test @@ -325,12 +349,15 @@ public void testToSdkBindingDataMap() { /* t= */ Instant.ofEpochSecond(42, 1), /* d= */ Duration.ofSeconds(1, 42), /* blob= */ BLOB, + /* binary= */ BINARY, /* generic= */ Nested.create("hello"), /* l= */ List.of("foo"), /* lb= */ List.of(BLOB), + /* lbinary= */ List.of(BINARY), /* lg= */ List.of(Nested.create("hello")), /* m= */ Map.of("marco", "polo"), /* mb= */ Map.of("blob", BLOB), + /* mbinary= */ Map.of("binary", BINARY), /* mg= */ Map.of("generic", Nested.create("hello")), /* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")), /* lm= */ List.of(Map.of("A", "a", "B", "b"), Map.of("a", "A", "b", "B")), @@ -349,12 +376,15 @@ public void testToSdkBindingDataMap() { expected.put("t", input.t()); expected.put("d", input.d()); expected.put("blob", input.blob()); + expected.put("binary", input.binary()); expected.put("generic", input.generic()); expected.put("l", input.l()); expected.put("lb", input.lb()); + expected.put("lbinary", input.lbinary()); expected.put("lg", input.lg()); expected.put("m", input.m()); expected.put("mb", input.mb()); + expected.put("mbinary", input.mbinary()); expected.put("mg", input.mg()); expected.put("ll", input.ll()); expected.put("lm", input.lm()); @@ -578,18 +608,24 @@ public abstract static class AutoValueInput { public abstract SdkBindingData blob(); + public abstract SdkBindingData binary(); + public abstract SdkBindingData generic(); public abstract SdkBindingData> l(); public abstract SdkBindingData> lb(); + public abstract SdkBindingData> lbinary(); + public abstract SdkBindingData> lg(); public abstract SdkBindingData> m(); public abstract SdkBindingData> mb(); + public abstract SdkBindingData> mbinary(); + public abstract SdkBindingData> mg(); public abstract SdkBindingData>> ll(); @@ -608,19 +644,23 @@ public static AutoValueInput create( SdkBindingData t, SdkBindingData d, SdkBindingData blob, + SdkBindingData binary, SdkBindingData generic, SdkBindingData> l, SdkBindingData> lb, + SdkBindingData> lbinary, SdkBindingData> lg, SdkBindingData> m, SdkBindingData> mb, + SdkBindingData> mbinary, SdkBindingData> mg, SdkBindingData>> ll, SdkBindingData>> lm, SdkBindingData>> ml, SdkBindingData>> mm) { return new AutoValue_JacksonSdkTypeTest_AutoValueInput( - i, f, s, b, t, d, blob, generic, l, lb, lg, m, mb, mg, ll, lm, ml, mm); + i, f, s, b, t, d, blob, binary, generic, l, lb, lbinary, lg, m, mb, mbinary, mg, ll, lm, + ml, mm); } } @@ -727,6 +767,10 @@ private static Literal literalOf(Blob blob) { return Literal.ofScalar(Scalar.ofBlob(blob)); } + private static Literal literalOf(Binary binary) { + return Literal.ofScalar(Scalar.ofBinary(binary)); + } + private static Literal literalOf(Struct generic) { return Literal.ofScalar(Scalar.ofGeneric(generic)); } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java index aa7217ec8..a55475591 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java @@ -26,6 +26,7 @@ import java.time.ZoneOffset; import java.util.List; import java.util.Map; +import org.flyte.api.v1.Binary; import org.flyte.api.v1.Blob; /** A utility class for creating {@link SdkBindingData} objects for different types. */ @@ -135,6 +136,16 @@ public static SdkBindingData of(SdkLiteralType type, T value) { return SdkBindingData.literal(type, value); } + /** + * Creates a {@code SdkBindingData} for a flyte Binary with the given value. + * + * @param value the simple value for this data + * @return the new {@code SdkBindingData} + */ + public static SdkBindingData of(Binary value) { + return SdkBindingData.literal(SdkLiteralTypes.binary(), value); + } + /** * Creates a {@code SdkBindingData} for a flyte Blob with the given value. * diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java index af1b77292..016702a5b 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkLiteralTypes.java @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import org.flyte.api.v1.Binary; import org.flyte.api.v1.BindingData; import org.flyte.api.v1.Blob; import org.flyte.api.v1.BlobType; @@ -31,6 +32,7 @@ import org.flyte.api.v1.LiteralType; import org.flyte.api.v1.Primitive; import org.flyte.api.v1.Scalar; +import org.flyte.api.v1.SimpleType; /** A utility class for creating {@link SdkLiteralType} objects for different types. */ public class SdkLiteralTypes { @@ -183,6 +185,15 @@ public static SdkLiteralType> maps(SdkLiteralType mapValue return new MapSdkLiteralType<>(mapValueType); } + /** + * Returns a {@link SdkLiteralType} for binary. + * + * @return the {@link SdkLiteralType} + */ + public static SdkLiteralType binary() { + return BinarySdkLiteralType.INSTANCE; + } + /** * Returns a {@link SdkLiteralType} for blobs. * @@ -216,6 +227,30 @@ public String toString() { } } + private static class BinarySdkLiteralType extends SdkLiteralType { + private static final BinarySdkLiteralType INSTANCE = new BinarySdkLiteralType(); + + @Override + public LiteralType getLiteralType() { + return LiteralType.ofSimpleType(SimpleType.BINARY); + } + + @Override + public final Literal toLiteral(Binary value) { + return Literal.ofScalar(Scalar.ofBinary(value)); + } + + @Override + public final Binary fromLiteral(Literal literal) { + return literal.scalar().binary(); + } + + @Override + public final BindingData toBindingData(Binary value) { + return BindingData.ofScalar(Scalar.ofBinary(value)); + } + } + private static class BlobSdkLiteralType extends SdkLiteralType { private final BlobType blobType; diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala index 5d169ad54..2b70765c7 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkLiteralTypesTest.scala @@ -16,8 +16,7 @@ */ package org.flyte.flytekitscala -import org.flyte.api.v1.{Blob, BlobType} -import org.flyte.api.v1.BlobType.BlobDimensionality +import org.flyte.api.v1.{Binary, Blob, BlobType} import org.flyte.flytekit.SdkLiteralType import org.flyte.flytekitscala.SdkLiteralTypes.{of, _} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} @@ -67,6 +66,7 @@ class TestOfReturnsProperTypeProvider extends ArgumentsProvider { Arguments.of(datetimes(), of[Instant]()), Arguments.of(durations(), of[Duration]()), Arguments.of(blobs(BlobType.DEFAULT), of[Blob]()), + Arguments.of(binary(), of[Binary]()), Arguments.of(generics(), of[ScalarNested]()), Arguments.of(collections(integers()), of[List[Long]]()), Arguments.of(collections(floats()), of[List[Double]]()), diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index 109f07fc0..a36a6a3c2 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -36,14 +36,12 @@ import org.flyte.flytekit.{ SdkBindingData, SdkBindingDataFactory => SdkJavaBindingDataFactory } -import org.flyte.flytekitscala.SdkBindingDataFactory import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} import org.junit.jupiter.api.Test import org.flyte.examples.AllInputsTask.{AutoAllInputsInput, Nested} import org.flyte.flytekit.jackson.JacksonSdkLiteralType import org.flyte.flytekitscala.SdkLiteralTypes.{ __TYPE, - blobs, collections, maps, strings diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala index 368d09ba4..b21c9a08c 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataConverters.scala @@ -174,6 +174,11 @@ object SdkBindingDataConverters { SdkScalaLiteralTypes.durations(), jf.Function.identity() ) + case SimpleType.BINARY => + TypeCastingResult( + SdkScalaLiteralTypes.binary(), + jf.Function.identity() + ) } case LiteralType.Kind.BLOB_TYPE => TypeCastingResult( diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala index c75ef8c65..2441eaae6 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkBindingDataFactory.scala @@ -16,7 +16,7 @@ */ package org.flyte.flytekitscala -import org.flyte.api.v1.Blob +import org.flyte.api.v1.{Binary, Blob} import org.flyte.flytekit.{ BindingCollection, BindingMap, @@ -145,6 +145,16 @@ object SdkBindingDataFactory { def of(value: Blob): SdkBindingData[Blob] = SdkBindingData.literal(SdkLiteralTypes.blobs(value.metadata.`type`), value) + /** Creates a [[SdkBindingData]] for a flyte Binary with the given value. + * + * @param value + * the simple value for this data + * @return + * the new [[SdkBindingData]] + */ + def of(value: Binary): SdkBindingData[Binary] = + SdkBindingData.literal(SdkLiteralTypes.binary(), value) + /** Creates a [[SdkBindingData]] for a flyte type with the given value. * * @param type diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala index 4ecfda6ce..fb128dd50 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkLiteralTypes.scala @@ -70,6 +70,8 @@ object SdkLiteralTypes { durations().asInstanceOf[SdkLiteralType[T]] case t if t =:= typeOf[Blob] => blobs(BlobType.DEFAULT).asInstanceOf[SdkLiteralType[T]] + case t if t =:= typeOf[Binary] => + binary().asInstanceOf[SdkLiteralType[T]] case t if t <:< typeOf[Product] && !(t =:= typeOf[Option[_]]) => generics().asInstanceOf[SdkLiteralType[T]] @@ -424,6 +426,14 @@ object SdkLiteralTypes { mapToProduct[T](structToMap(struct)) } + /** Returns a [[SdkLiteralType]] for binary. + * + * @return + * the [[SdkLiteralType]] + */ + def binary(): SdkLiteralType[Binary] = + SdkJavaLiteralTypes.binary() + /** Returns a [[SdkLiteralType]] for blob. * * @return diff --git a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala index a0a7fb107..00cbdea56 100644 --- a/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala +++ b/flytekit-scala_2.13/src/main/scala/org/flyte/flytekitscala/SdkScalaType.scala @@ -248,6 +248,11 @@ object SdkScalaType { SdkLiteralTypes.blobs(BlobType.DEFAULT) ) + implicit def binaryLiteralType: SdkScalaLiteralType[Binary] = + DelegateLiteralType( + SdkLiteralTypes.binary() + ) + // TODO we are forced to do this because SdkDataBinding.ofInteger returns a SdkBindingData // This makes Scala dev mad when they are forced to use the java types instead of scala types // We need to think what to do, maybe move the factory methods out of SdkDataBinding into their own class diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java index 1d542b0b9..bddca2a35 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProtoUtil.java @@ -28,6 +28,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; import com.google.protobuf.ListValue; import com.google.protobuf.NullValue; import com.google.protobuf.Timestamp; @@ -60,6 +61,7 @@ import java.util.Map; import java.util.function.Consumer; import java.util.regex.Pattern; +import org.flyte.api.v1.Binary; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BindingData; import org.flyte.api.v1.Blob; @@ -163,6 +165,8 @@ private static Scalar deserialize(Literals.Scalar scalar) { return Scalar.ofBlob(deserialize(scalar.getBlob())); case BINARY: + return Scalar.ofBinary(deserialize(scalar.getBinary())); + case ERROR: case NONE_TYPE: case SCHEMA: @@ -203,6 +207,10 @@ static Primitive deserialize(Literals.Primitive primitive) { throw new UnsupportedOperationException(String.format("Unsupported Primitive [%s]", primitive)); } + static Binary deserialize(Literals.Binary binary) { + return Binary.builder().tag(binary.getTag()).value(binary.getValue().toByteArray()).build(); + } + @VisibleForTesting static Blob deserialize(Literals.Blob blob) { BlobType type = @@ -479,6 +487,8 @@ private static Types.SimpleType serialize(SimpleType simpleType) { return Types.SimpleType.DURATION; case STRUCT: return Types.SimpleType.STRUCT; + case BINARY: + return Types.SimpleType.BINARY; } return Types.SimpleType.UNRECOGNIZED; @@ -500,8 +510,9 @@ private static SimpleType deserialize(Types.SimpleType proto) { return SimpleType.DURATION; case STRUCT: return SimpleType.STRUCT; - case ERROR: case BINARY: + return SimpleType.BINARY; + case ERROR: case NONE: throw new IllegalArgumentException("Unsupported SimpleType: " + proto); @@ -987,6 +998,11 @@ private static Literals.Scalar serialize(Scalar scalar) { Blob blob = scalar.blob(); return Literals.Scalar.newBuilder().setBlob(serialize(blob)).build(); + + case BINARY: + Binary binary = scalar.binary(); + + return Literals.Scalar.newBuilder().setBinary(serialize(binary)).build(); } throw new AssertionError("Unexpected Scalar.Kind: " + scalar.kind()); @@ -1112,6 +1128,14 @@ private static void dateTimeRangeCheck(Instant datetime) { } } + @VisibleForTesting + static Literals.Binary serialize(Binary binary) { + return Literals.Binary.newBuilder() + .setTag(binary.tag()) + .setValue(ByteString.copyFrom(binary.value())) + .build(); + } + @VisibleForTesting static Literals.Blob serialize(Blob blob) { Literals.BlobMetadata metadata = diff --git a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java index 9993b7e8b..0bc3b76ca 100644 --- a/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java +++ b/jflyte-utils/src/test/java/org/flyte/jflyte/utils/ProtoUtilTest.java @@ -34,6 +34,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; import com.google.protobuf.ListValue; import com.google.protobuf.NullValue; import com.google.protobuf.Value; @@ -49,12 +50,14 @@ import flyteidl.core.Types; import flyteidl.core.Types.SchemaType.SchemaColumn.SchemaColumnType; import flyteidl.core.Workflow; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Stream; +import org.flyte.api.v1.Binary; import org.flyte.api.v1.Binding; import org.flyte.api.v1.BindingData; import org.flyte.api.v1.Blob; @@ -900,6 +903,9 @@ void shouldDeserializeComplexLiteralTypes(LiteralType expected, Types.LiteralTyp static Stream createSerializeComplexLiteralArguments() { return Stream.of( + Arguments.of( + LiteralType.ofSimpleType(SimpleType.BINARY), + Types.LiteralType.newBuilder().setSimple(Types.SimpleType.BINARY).build()), Arguments.of( LiteralType.ofBlobType( BlobType.builder() @@ -1035,6 +1041,21 @@ void shouldSerializeContainerError() { .build())); } + @Test + void shouldSerializeBinary() { + String tag = "tag"; + byte[] data = "data".getBytes(StandardCharsets.UTF_8); + + Binary binary = Binary.builder().tag(tag).value(data).build(); + + Literals.Binary proto = ProtoUtil.serialize(binary); + + assertThat( + proto, + equalTo( + Literals.Binary.newBuilder().setTag(tag).setValue(ByteString.copyFrom(data)).build())); + } + @Test void shouldSerializeBlob() { BlobType type = @@ -1125,6 +1146,17 @@ void shouldDeserializeWorkflowId() { .build())); } + @Test + void shouldDeserializeBinary() { + String tag = "tag"; + byte[] data = "data".getBytes(StandardCharsets.UTF_8); + Literals.Binary binary = + Literals.Binary.newBuilder().setTag(tag).setValue(ByteString.copyFrom(data)).build(); + + assertThat( + ProtoUtil.deserialize(binary), equalTo(Binary.builder().tag(tag).value(data).build())); + } + @Test void shouldDeserializeBlob() { Types.BlobType type =