Skip to content

Commit

Permalink
Support struct in Scala layer
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Oct 14, 2023
1 parent 1b0153f commit 3b2775a
Show file tree
Hide file tree
Showing 15 changed files with 410 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
org.flyte.examples.flytekitscala.FibonacciLaunchPlan
org.flyte.examples.flytekitscala.LaunchPlanRegistry
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ org.flyte.examples.flytekitscala.SumTask
org.flyte.examples.flytekitscala.GreetTask
org.flyte.examples.flytekitscala.AddQuestionTask
org.flyte.examples.flytekitscala.NoInputsTask
org.flyte.examples.flytekitscala.NestedIOTask
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
org.flyte.examples.flytekitscala.FibonacciWorkflow
org.flyte.examples.flytekitscala.WelcomeWorkflow
org.flyte.examples.flytekitscala.NestedIOWorkflow
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ import org.flyte.flytekit.{SdkLaunchPlan, SimpleSdkLaunchPlanRegistry}
import org.flyte.flytekitscala.SdkScalaType

case class FibonacciLaunchPlanInput(fib0: Long, fib1: Long)
case class NestedIOLaunchPlanInput(name: String, generic: Nested)

class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry {
class LaunchPlanRegistry extends SimpleSdkLaunchPlanRegistry {
// Register default launch plans for all workflows
registerDefaultLaunchPlans()

Expand Down Expand Up @@ -53,4 +54,33 @@ class FibonacciLaunchPlan extends SimpleSdkLaunchPlanRegistry {
.withDefaultInput("fib0", 0L)
.withDefaultInput("fib1", 1L)
)

registerLaunchPlan(
SdkLaunchPlan
.of(new NestedIOWorkflow)
.withName("NestedIOWorkflowLaunchPlan")
.withDefaultInput(
SdkScalaType[NestedIOLaunchPlanInput],
NestedIOLaunchPlanInput(
"yo",
Nested(
boolean = true,
1.toByte,
2.toShort,
3,
4L,
5.toFloat,
6.toDouble,
"hello",
List("1", "2"),
Map("1" -> "1", "2" -> "2"),
Some(false),
None,
Some(List("3", "4")),
Some(Map("3" -> "3", "4" -> "4")),
NestedNested(7.toDouble, NestedNestedNested("world"))
)
)
)
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright 2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.examples.flytekitscala

import org.flyte.flytekit.{SdkBindingData, SdkRunnableTask, SdkTransform}
import org.flyte.flytekitscala.{
Description,
SdkBindingDataFactory,
SdkScalaType
}

case class NestedNestedNested(string: String)
case class NestedNested(double: Double, nested: NestedNestedNested)
case class Nested(
boolean: Boolean,
byte: Byte,
short: Short,
int: Int,
long: Long,
float: Float,
double: Double,
string: String,
list: List[String],
map: Map[String, String],
optBoolean: Option[Boolean],
optByte: Option[Byte],
optList: Option[List[String]],
optMap: Option[Map[String, String]],
nested: NestedNested
)
case class NestedIOTaskInput(
@Description("the name of the person to be greeted")
name: SdkBindingData[String],
@Description("a nested input")
generic: SdkBindingData[Nested]
)
case class NestedIOTaskOutput(
@Description("the name of the person to be greeted")
name: SdkBindingData[String],
@Description("a nested input")
generic: SdkBindingData[Nested]
)

/** Example Flyte task that takes a name as the input and outputs a simple
* greeting message.
*/
class NestedIOTask
extends SdkRunnableTask[
NestedIOTaskInput,
NestedIOTaskOutput
](
SdkScalaType[NestedIOTaskInput],
SdkScalaType[NestedIOTaskOutput]
) {

/** Defines task behavior. This task takes a name as the input, wraps it in a
* welcome message, and outputs the message.
*
* @param input
* the name of the person to be greeted
* @return
* the welcome message
*/
override def run(input: NestedIOTaskInput): NestedIOTaskOutput =
NestedIOTaskOutput(
input.name,
input.generic
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2020-2023 Flyte Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.flyte.examples.flytekitscala

import org.flyte.flytekitscala.{
SdkScalaType,
SdkScalaWorkflow,
SdkScalaWorkflowBuilder
}

class NestedIOWorkflow
extends SdkScalaWorkflow[NestedIOTaskInput, Unit](
SdkScalaType[NestedIOTaskInput],
SdkScalaType.unit
) {

override def expand(
builder: SdkScalaWorkflowBuilder,
input: NestedIOTaskInput
): Unit = {
builder.apply(new NestedIOTask(), input)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,18 @@ import org.flyte.flytekit.{

import java.time.{Duration, Instant}
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.{TypeTag, typeOf}
import scala.reflect.api.{Mirror, TypeCreator, Universe}
import scala.reflect.runtime.universe
import scala.reflect.{ClassTag, classTag}
import scala.reflect.runtime.universe.{
NoPrefix,
Symbol,
Type,
TypeTag,
runtimeMirror,
termNames,
typeOf
}

object SdkLiteralTypes {

Expand Down Expand Up @@ -202,6 +213,185 @@ object SdkLiteralTypes {
*/
def durations(): SdkLiteralType[Duration] = SdkJavaLiteralTypes.durations()

/** Returns a [[SdkLiteralType]] for products.
* @return
* the [[SdkLiteralType]]
*/
def generics[T <: Product: TypeTag: ClassTag](): SdkLiteralType[T] = {
ScalaLiteralType[T](
LiteralType.ofSimpleType(SimpleType.STRUCT),
(value: T) => Literal.ofScalar(Scalar.ofGeneric(toStruct(value))),
(x: Literal) => toProduct(x.scalar().generic()),
(v: T) => BindingData.ofScalar(Scalar.ofGeneric(toStruct(v))),
"generics"
)
}

private def toStruct(product: Product): Struct = {
def productToMap(product: Product): Map[String, Any] = {
// by spec getDeclaredFields is not ordered but in practice it works fine
// it's a lot better since Scala 2.13 because productElementNames was introduced
// (product.productElementNames zip product.productIterator).toMap
product.getClass.getDeclaredFields
.map(_.getName)
.zip(product.productIterator.toList)
.toMap
}

def mapToStruct(map: Map[String, Any]): Struct = {
val fields = map.map({ case (key, value) =>
(key, anyToStructValue(value))
})
Struct.of(fields.asJava)
}

def anyToStructValue(value: Any): Struct.Value = {
def anyToStructureValue0(value: Any): Struct.Value = {
value match {
case s: String => Struct.Value.ofStringValue(s)
case n @ (_: Byte | _: Short | _: Int | _: Long | _: Float |
_: Double) =>
Struct.Value.ofNumberValue(n.toString.toDouble)
case b: Boolean => Struct.Value.ofBoolValue(b)
case l: List[Any] =>
Struct.Value.ofListValue(l.map(anyToStructValue).asJava)
case m: Map[_, _] =>
Struct.Value.ofStructValue(
mapToStruct(m.asInstanceOf[Map[String, Any]])
)
case null => Struct.Value.ofNullValue()
case p: Product =>
Struct.Value.ofStructValue(mapToStruct(productToMap(p)))
case _ =>
throw new IllegalArgumentException(
s"Unsupported type: ${value.getClass}"
)
}
}

value match {
case Some(v) => anyToStructureValue0(v)
case None => Struct.Value.ofNullValue()
case _ => anyToStructureValue0(value)
}
}

mapToStruct(productToMap(product))
}

private def toProduct[T <: Product: TypeTag: ClassTag](
struct: Struct
): T = {
def structToMap(struct: Struct): Map[String, Any] = {
struct
.fields()
.asScala
.map({ case (key, value) =>
(key, structValueToAny(value))
})
.toMap
}

def mapToProduct[S <: Product: TypeTag: ClassTag](
map: Map[String, Any]
): S = {
val mirror = runtimeMirror(classTag[S].runtimeClass.getClassLoader)

def valueToParamValue(value: Any, param: Symbol): Any = {
def valueToParamValue0(value: Any, param: Symbol): Any = {
if (param.typeSignature =:= typeOf[Byte]) {
value.asInstanceOf[Double].toByte
} else if (param.typeSignature =:= typeOf[Short]) {
value.asInstanceOf[Double].toShort
} else if (param.typeSignature =:= typeOf[Int]) {
value.asInstanceOf[Double].toInt
} else if (param.typeSignature =:= typeOf[Long]) {
value.asInstanceOf[Double].toLong
} else if (param.typeSignature =:= typeOf[Float]) {
value.asInstanceOf[Double].toFloat
} else if (param.typeSignature <:< typeOf[Product]) {
val typeTag = createTypeTag(param.typeSignature)
val classTag = ClassTag(
typeTag.mirror.runtimeClass(param.typeSignature)
)
mapToProduct(value.asInstanceOf[Map[String, Any]])(
typeTag,
classTag
)
} else {
value
}
}

if (param.typeSignature <:< typeOf[Option[Any]]) {
Some(
valueToParamValue0(
value,
param.typeSignature.dealias.typeArgs.head.typeSymbol
)
)
} else {
valueToParamValue0(value, param)
}
}

def createTypeTag[U <: Product](tpe: Type): TypeTag[U] = {
val typSym = mirror.staticClass(tpe.typeSymbol.fullName)
// note: this uses internal API, otherwise we will need to depend on scala-compiler at runtime
val typeRef =
universe.internal.typeRef(NoPrefix, typSym, List.empty)

TypeTag(
mirror,
new TypeCreator {
override def apply[V <: Universe with Singleton](
m: Mirror[V]
): V#Type = {
assert(
m == mirror,
s"TypeTag[$typeRef] defined in $mirror cannot be migrated to $m."
)
typeRef.asInstanceOf[V#Type]
}
}
)
}

val clazz = typeOf[S].typeSymbol.asClass
val classMirror = mirror.reflectClass(clazz)
val constructor = typeOf[S].decl(termNames.CONSTRUCTOR).asMethod
val constructorMirror = classMirror.reflectConstructor(constructor)

val constructorArgs =
constructor.paramLists.flatten.map((param: Symbol) => {
val paramName = param.name.toString
val value = map.getOrElse(
paramName,
throw new IllegalArgumentException(
s"Map is missing required parameter named $paramName"
)
)
valueToParamValue(value, param)
})

constructorMirror(constructorArgs: _*).asInstanceOf[S]
}

def structValueToAny(value: Struct.Value): Any = {
value.kind() match {
case Struct.Value.Kind.STRING_VALUE => value.stringValue()
case Struct.Value.Kind.NUMBER_VALUE => value.numberValue()
case Struct.Value.Kind.BOOL_VALUE => value.boolValue()
case Struct.Value.Kind.LIST_VALUE =>
value.listValue().asScala.map(structValueToAny).toList
case Struct.Value.Kind.STRUCT_VALUE => structToMap(value.structValue())
case Struct.Value.Kind.NULL_VALUE => None
}
}

mapToProduct[T](structToMap(struct))
}

/** Returns a [[SdkLiteralType]] for blob.
*
* @return
Expand Down
Loading

0 comments on commit 3b2775a

Please sign in to comment.