Skip to content

Commit

Permalink
Bring back Blob support
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Oct 7, 2023
1 parent c892ceb commit 72a7781
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 45 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.flytekit.jackson;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/** Applied to a blob property to annotate its type. */
@Target({ElementType.FIELD, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface BlobTypeDescription {
/**
* Describes the blob's format.
*
* @return format, not {@code null}
*/
String format();

/**
* Describes the blob's dimensionality.
*
* @return dimensionality, not {@code null}
*/
String dimensionality();
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.api.v1.Variable;
import org.flyte.flytekit.SdkBindingData;
import org.flyte.flytekit.SdkLiteralType;
Expand Down Expand Up @@ -63,11 +66,7 @@ public void property(BeanProperty prop) {
String propName = prop.getName();
AnnotatedMember member = prop.getMember();
SdkLiteralType<?> literalType =
toLiteralType(
handledType,
/*rootLevel=*/ true,
propName,
member.getMember().getDeclaringClass().getName());
toLiteralType(handledType, /* rootLevel= */ true, propName, member);

String description = getDescription(member);

Expand Down Expand Up @@ -132,18 +131,17 @@ private String getDescription(AnnotatedMember member) {

@SuppressWarnings("AlreadyChecked")
private SdkLiteralType<?> toLiteralType(
JavaType javaType, boolean rootLevel, String propName, String declaringClassName) {
JavaType javaType, boolean rootLevel, String propName, AnnotatedMember member) {
Class<?> type = javaType.getRawClass();

if (SdkBindingData.class.isAssignableFrom(type)) {
return toLiteralType(
javaType.getBindings().getBoundType(0), false, propName, declaringClassName);
return toLiteralType(javaType.getBindings().getBoundType(0), false, propName, member);
} else if (rootLevel) {
throw new UnsupportedOperationException(
String.format(
"Field '%s' from class '%s' is declared as '%s' and it is not matching any of the supported types. "
+ "Please make sure your variable declared type is wrapped in 'SdkBindingData<>'.",
propName, declaringClassName, type));
propName, member.getMember().getDeclaringClass().getName(), type));
} else if (isPrimitiveAssignableFrom(Long.class, type)) {
return SdkLiteralTypes.integers();
} else if (isPrimitiveAssignableFrom(Double.class, type)) {
Expand All @@ -159,8 +157,7 @@ private SdkLiteralType<?> toLiteralType(
} else if (List.class.isAssignableFrom(type)) {
JavaType elementType = javaType.getBindings().getBoundType(0);

return SdkLiteralTypes.collections(
toLiteralType(elementType, false, propName, declaringClassName));
return SdkLiteralTypes.collections(toLiteralType(elementType, false, propName, member));
} else if (Map.class.isAssignableFrom(type)) {
JavaType keyType = javaType.getBindings().getBoundType(0);
JavaType valueType = javaType.getBindings().getBoundType(1);
Expand All @@ -170,9 +167,22 @@ private SdkLiteralType<?> toLiteralType(
"Only Map<String, ?> is supported, got [" + javaType.getGenericSignature() + "]");
}

return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, declaringClassName));
return SdkLiteralTypes.maps(toLiteralType(valueType, false, propName, member));
} else if (Blob.class.isAssignableFrom(type)) {
BlobTypeDescription annotation = member.getAnnotation(BlobTypeDescription.class);
if (annotation == null) {
throw new UnsupportedOperationException(
String.format(
"Field '%s' from class '%s' is declared as '%s' and it must be annotated",
propName, member.getMember().getDeclaringClass().getName(), type));
}
return SdkLiteralTypes.blobs(
BlobType.builder()
.format(annotation.format())
.dimensionality(BlobDimensionality.valueOf(annotation.dimensionality()))
.build());
}
// TODO: Support blobs and structs
// TODO: Support structs
throw new UnsupportedOperationException(
String.format("Unsupported type: [%s]", type.getName()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.deser.std.StdDeserializer;
import java.io.IOException;
import java.io.Serializable;
import java.time.Duration;
import java.time.Instant;
import java.util.Iterator;
Expand All @@ -39,6 +38,10 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobMetadata;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.BlobType.BlobDimensionality;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Primitive;
Expand Down Expand Up @@ -80,7 +83,7 @@ private SdkBindingData<?> transform(JsonNode tree) {
}
}

private static SdkBindingData<? extends Serializable> transformScalar(JsonNode tree) {
private static SdkBindingData<?> transformScalar(JsonNode tree) {
Scalar.Kind scalarKind = Scalar.Kind.valueOf(tree.get(SCALAR).asText());
switch (scalarKind) {
case PRIMITIVE:
Expand All @@ -102,14 +105,33 @@ private static SdkBindingData<? extends Serializable> transformScalar(JsonNode t
throw new UnsupportedOperationException(
"Type contains an unsupported primitive: " + primitiveKind);

case GENERIC:
case BLOB:
return transformBlob(tree);

case GENERIC:
default:
throw new UnsupportedOperationException(
"Type contains an unsupported scalar: " + scalarKind);
}
}

private static SdkBindingData<Blob> transformBlob(JsonNode tree) {
JsonNode value = tree.get(VALUE);
String uri = value.get("uri").asText();
JsonNode type = value.get("metadata").get("type");
String format = type.get("format").asText();
BlobDimensionality dimensionality =
BlobDimensionality.valueOf(type.get("dimensionality").asText());
return SdkBindingDataFactory.of(
Blob.builder()
.uri(uri)
.metadata(
BlobMetadata.builder()
.type(BlobType.builder().format(format).dimensionality(dimensionality).build())
.build())
.build());
}

@SuppressWarnings("unchecked")
private <T> SdkBindingData<List<T>> transformCollection(JsonNode tree) {
SdkLiteralType<T> literalType = (SdkLiteralType<T>) readLiteralType(tree.get(TYPE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
*/
package org.flyte.flytekit.jackson.serializers;

import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.SCALAR;
import static org.flyte.flytekit.jackson.serializers.SdkBindingDataSerializationProtocol.VALUE;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.SerializerProvider;
import java.io.IOException;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
import org.flyte.api.v1.Scalar;
import org.flyte.api.v1.Scalar.Kind;

public class BlobSerializer extends ScalarSerializer {
public BlobSerializer(
Expand All @@ -38,8 +38,8 @@ public BlobSerializer(

@Override
void serializeScalar() throws IOException {
gen.writeFieldName(SCALAR);
gen.writeObject(Scalar.Kind.BLOB);
gen.writeObject(Kind.BLOB);
gen.writeFieldName(VALUE);
serializerProvider
.findValueSerializer(Blob.class)
.serialize(value.scalar().blob(), gen, serializerProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ public class SdkBindingDataSerializationProtocol {
public static final String TYPE = "type";
public static final String KIND = "kind";
public static final String PRIMITIVE = "primitive";
public static final String BLOB = "blob";
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
import org.flyte.api.v1.Blob;
import org.flyte.api.v1.BlobMetadata;
import org.flyte.api.v1.BlobType;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.LiteralType;
Expand All @@ -64,7 +66,7 @@ public static AutoValueInput createAutoValueInput(
boolean b,
Instant t,
Duration d,
// Blob blob,
Blob blob,
List<String> l,
Map<String, String> m,
List<List<String>> ll,
Expand All @@ -78,6 +80,7 @@ public static AutoValueInput createAutoValueInput(
SdkBindingDataFactory.of(b),
SdkBindingDataFactory.of(t),
SdkBindingDataFactory.of(d),
SdkBindingDataFactory.of(blob),
SdkBindingDataFactory.ofStringCollection(l),
SdkBindingDataFactory.ofStringMap(m),
SdkBindingDataFactory.of(SdkLiteralTypes.collections(SdkLiteralTypes.strings()), ll),
Expand Down Expand Up @@ -119,19 +122,19 @@ public void testVariableMap() {
void testFromLiteralMap() {
Instant datetime = Instant.ofEpochSecond(12, 34);
Duration duration = Duration.ofSeconds(56, 78);
// Blob blob =
// Blob.builder()
// .metadata(BlobMetadata.builder().type(BLOB_TYPE).build())
// .uri("file://test")
// .build();
Blob blob =
Blob.builder()
.metadata(BlobMetadata.builder().type(BLOB_TYPE).build())
.uri("file://test")
.build();
Map<String, Literal> literalMap = new HashMap<>();
literalMap.put("i", literalOf(Primitive.ofIntegerValue(123L)));
literalMap.put("f", literalOf(Primitive.ofFloatValue(123.0)));
literalMap.put("s", literalOf(Primitive.ofStringValue("123")));
literalMap.put("b", literalOf(Primitive.ofBooleanValue(true)));
literalMap.put("t", literalOf(Primitive.ofDatetime(datetime)));
literalMap.put("d", literalOf(Primitive.ofDuration(duration)));
// literalMap.put("blob", literalOf(blob));
literalMap.put("blob", literalOf(blob));
literalMap.put("l", Literal.ofCollection(List.of(literalOf(Primitive.ofStringValue("123")))));
literalMap.put("m", Literal.ofMap(Map.of("marco", literalOf(Primitive.ofStringValue("polo")))));
literalMap.put(
Expand Down Expand Up @@ -159,9 +162,9 @@ void testFromLiteralMap() {
Literal.ofMap(
Map.of(
"math",
Literal.ofMap(
Map.of("pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))),
"pokemon", Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu"))))));
Literal.ofMap(Map.of("pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))),
"pokemon",
Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu"))))));

AutoValueInput input = JacksonSdkType.of(AutoValueInput.class).fromLiteralMap(literalMap);

Expand All @@ -175,7 +178,7 @@ void testFromLiteralMap() {
/* b= */ true,
/* t= */ datetime,
/* d= */ duration,
/// * blob= */ blob,
/* blob= */ blob,
/* l= */ List.of("123"),
/* m= */ Map.of("marco", "polo"),
/* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")),
Expand All @@ -194,11 +197,11 @@ private static Literal stringLiteralOf(String string) {

@Test
void testToLiteralMap() {
// Blob blob =
// Blob.builder()
// .metadata(BlobMetadata.builder().type(BLOB_TYPE).build())
// .uri("file://test")
// .build();
Blob blob =
Blob.builder()
.metadata(BlobMetadata.builder().type(BLOB_TYPE).build())
.uri("file://test")
.build();
Map<String, Literal> literalMap =
JacksonSdkType.of(AutoValueInput.class)
.toLiteralMap(
Expand All @@ -209,7 +212,7 @@ void testToLiteralMap() {
/* b= */ false,
/* t= */ Instant.ofEpochSecond(42, 1),
/* d= */ Duration.ofSeconds(1, 42),
/// * blob= */ blob,
/* blob= */ blob,
/* l= */ List.of("foo"),
/* m= */ Map.of("marco", "polo"),
/* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")),
Expand Down Expand Up @@ -271,13 +274,17 @@ void testToLiteralMap() {
Map.of(
"pi", stringLiteralOf("3.14"), "e", stringLiteralOf("2.72"))),
"pokemon",
Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu"))))))
// hasEntry("blob", literalOf(blob))
)));
Literal.ofMap(Map.of("ash", stringLiteralOf("pikachu")))))),
hasEntry("blob", literalOf(blob)))));
}

@Test
public void testToSdkBindingDataMap() {
Blob blob =
Blob.builder()
.metadata(BlobMetadata.builder().type(BLOB_TYPE).build())
.uri("file://test")
.build();
AutoValueInput input =
createAutoValueInput(
/* i= */ 42L,
Expand All @@ -286,7 +293,7 @@ public void testToSdkBindingDataMap() {
/* b= */ false,
/* t= */ Instant.ofEpochSecond(42, 1),
/* d= */ Duration.ofSeconds(1, 42),
/// * blob= */ blob,
/* blob= */ blob,
/* l= */ List.of("foo"),
/* m= */ Map.of("marco", "polo"),
/* ll= */ List.of(List.of("foo", "bar"), List.of("a", "b", "c")),
Expand All @@ -305,6 +312,7 @@ public void testToSdkBindingDataMap() {
expected.put("b", input.b());
expected.put("t", input.t());
expected.put("d", input.d());
expected.put("blob", input.blob());
expected.put("l", input.l());
expected.put("m", input.m());
expected.put("ll", input.ll());
Expand Down Expand Up @@ -536,8 +544,8 @@ public abstract static class AutoValueInput {

public abstract SdkBindingData<Duration> d();

// TODO add blobs to sdkbinding data
// public abstract SdkBindingData<Blob> blob();
@BlobTypeDescription(format = "", dimensionality = "SINGLE")
public abstract SdkBindingData<Blob> blob();

public abstract SdkBindingData<List<String>> l();

Expand All @@ -558,15 +566,15 @@ public static AutoValueInput create(
SdkBindingData<Boolean> b,
SdkBindingData<Instant> t,
SdkBindingData<Duration> d,
// Blob blob,
SdkBindingData<Blob> blob,
SdkBindingData<List<String>> l,
SdkBindingData<Map<String, String>> m,
SdkBindingData<List<List<String>>> ll,
SdkBindingData<List<Map<String, String>>> lm,
SdkBindingData<Map<String, List<String>>> ml,
SdkBindingData<Map<String, Map<String, String>>> mm) {
return new AutoValue_JacksonSdkTypeTest_AutoValueInput(
i, f, s, b, t, d, l, m, ll, lm, ml, mm);
i, f, s, b, t, d, blob, l, m, ll, lm, ml, mm);
}
}

Expand Down Expand Up @@ -701,4 +709,8 @@ private static Variable createVar(LiteralType literalType, String description) {
private static Literal literalOf(Primitive primitive) {
return Literal.ofScalar(Scalar.ofPrimitive(primitive));
}

private static Literal literalOf(Blob blob) {
return Literal.ofScalar(Scalar.ofBlob(blob));
}
}
Loading

0 comments on commit 72a7781

Please sign in to comment.