diff --git a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala index 93014c55..92cafad9 100644 --- a/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala +++ b/flytekit-scala-tests/src/test/scala/org/flyte/flytekitscala/SdkScalaTypeTest.scala @@ -82,7 +82,8 @@ class SdkScalaTypeTest { floats: SdkBindingData[List[Double]], booleans: SdkBindingData[List[Boolean]], datetimes: SdkBindingData[List[Instant]], - durations: SdkBindingData[List[Duration]] + durations: SdkBindingData[List[Duration]], + generics: SdkBindingData[List[ScalarNested]] ) case class MapInput( @@ -91,7 +92,8 @@ class SdkScalaTypeTest { floatMap: SdkBindingData[Map[String, Double]], booleanMap: SdkBindingData[Map[String, Boolean]], datetimeMap: SdkBindingData[Map[String, Instant]], - durationMap: SdkBindingData[Map[String, Duration]] + durationMap: SdkBindingData[Map[String, Duration]], + genericMap: SdkBindingData[Map[String, ScalarNested]] ) case class ComplexInput( @@ -352,7 +354,8 @@ class SdkScalaTypeTest { "floats" -> createCollectionVar(SimpleType.FLOAT), "booleans" -> createCollectionVar(SimpleType.BOOLEAN), "datetimes" -> createCollectionVar(SimpleType.DATETIME), - "durations" -> createCollectionVar(SimpleType.DURATION) + "durations" -> createCollectionVar(SimpleType.DURATION), + "generics" -> createCollectionVar(SimpleType.STRUCT) ).asJava val output = SdkScalaType[CollectionInput].getVariableMap @@ -448,6 +451,25 @@ class SdkScalaTypeTest { ), durations = SdkBindingDataFactory.of( List(Duration.ofSeconds(123, 456), Duration.ofSeconds(543, 21)) + ), + generics = SdkBindingDataFactory.of( + SdkLiteralTypes.generics[ScalarNested](), + List( + ScalarNested( + "foo", + Some("bar"), + Some(ScalarNestedNested("foo", Some("bar"))), + List(ScalarNestedNested("foo", Some("bar"))), + Map("foo" -> ScalarNestedNested("foo", Some("bar"))) + ), + ScalarNested( + "foo2", + Some("bar2"), + Some(ScalarNestedNested("foo2", Some("bar2"))), + List(ScalarNestedNested("foo2", Some("bar2"))), + Map("foo2" -> ScalarNestedNested("foo2", Some("bar2"))) + ) + ) ) ) @@ -466,7 +488,8 @@ class SdkScalaTypeTest { "floatMap" -> createMapVar(SimpleType.FLOAT), "booleanMap" -> createMapVar(SimpleType.BOOLEAN), "datetimeMap" -> createMapVar(SimpleType.DATETIME), - "durationMap" -> createMapVar(SimpleType.DURATION) + "durationMap" -> createMapVar(SimpleType.DURATION), + "genericMap" -> createMapVar(SimpleType.STRUCT) ).asJava val output = SdkScalaType[MapInput].getVariableMap @@ -495,7 +518,26 @@ class SdkScalaTypeTest { datetimeMap = SdkBindingDataFactory.of(Map("k5" -> Instant.ofEpochMilli(321L))), durationMap = - SdkBindingDataFactory.of(Map("k6" -> Duration.ofSeconds(543, 21))) + SdkBindingDataFactory.of(Map("k6" -> Duration.ofSeconds(543, 21))), + genericMap = SdkBindingDataFactory.of( + SdkLiteralTypes.generics[ScalarNested](), + Map( + "a" -> ScalarNested( + "foo2", + Some("bar2"), + Some(ScalarNestedNested("foo2", Some("bar2"))), + List(ScalarNestedNested("foo2", Some("bar2"))), + Map("foo2" -> ScalarNestedNested("foo2", Some("bar2"))) + ), + "b" -> ScalarNested( + "foo2", + Some("bar2"), + Some(ScalarNestedNested("foo2", Some("bar2"))), + List(ScalarNestedNested("foo2", Some("bar2"))), + Map("foo2" -> ScalarNestedNested("foo2", Some("bar2"))) + ) + ) + ) ) val output = SdkScalaType[MapInput].fromLiteralMap(