Skip to content

Commit

Permalink
fix: correct deserialization of nested map/list/types in unions (#1144)
Browse files Browse the repository at this point in the history
  • Loading branch information
ianbotsf authored Aug 29, 2024
1 parent 932b579 commit 277a77c
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 14 deletions.
8 changes: 8 additions & 0 deletions .changes/b786851c-9427-40cd-b3fa-ee375011d931.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "b786851c-9427-40cd-b3fa-ee375011d931",
"type": "bugfix",
"description": "Correct deserialization of nested map/list types in unions",
"issues": [
"awslabs/smithy-kotlin#1126"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ open class DeserializeStructGenerator(
/**
* Enables overriding the codegen output of the final value resulting
* from the deserialization of a non-primitive type.
* @param memberShape [MemberShape] associated with entry
* @param forMemberShape [MemberShape] associated with entry, if any
* @param defaultCollectionName the default value produced by this class.
*/
open fun collectionReturnExpression(memberShape: MemberShape, defaultCollectionName: String): String = defaultCollectionName
open fun collectionReturnExpression(forMemberShape: MemberShape?, defaultCollectionName: String): String =
defaultCollectionName

/**
* Enables overriding of the lhs expression into which a deserialization operation's
Expand Down Expand Up @@ -292,7 +293,7 @@ open class DeserializeStructGenerator(
val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
val nextNestingLevel = nestingLevel + 1
val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP)
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
val collectionReturnExpression = collectionReturnExpression(null, memberName)

writeKeyVal(keyShape, keySymbol, keyName)
writer.withBlock("val $valueName =", "") {
Expand Down Expand Up @@ -346,7 +347,7 @@ open class DeserializeStructGenerator(
val descriptorName = rootMemberShape.descriptorName(nestingLevel.nestedDescriptorName())
val nextNestingLevel = nestingLevel + 1
val memberName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, memberName)
val collectionReturnExpression = collectionReturnExpression(null, memberName)

writeKeyVal(keyShape, keySymbol, keyName)
writer.withBlock("val $valueName =", "") {
Expand Down Expand Up @@ -516,7 +517,7 @@ open class DeserializeStructGenerator(
val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
val nextNestingLevel = nestingLevel + 1
val mapName = nextNestingLevel.variableNameFor(NestedIdentifierType.MAP)
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, mapName)
val collectionReturnExpression = collectionReturnExpression(null, mapName)

writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeMap) {
write(
Expand Down Expand Up @@ -555,7 +556,7 @@ open class DeserializeStructGenerator(
val elementName = nestingLevel.variableNameFor(NestedIdentifierType.ELEMENT)
val nextNestingLevel = nestingLevel + 1
val listName = nextNestingLevel.variableNameFor(NestedIdentifierType.COLLECTION)
val collectionReturnExpression = collectionReturnExpression(rootMemberShape, listName)
val collectionReturnExpression = collectionReturnExpression(null, listName)

writer.withBlock("val $elementName = deserializer.#T($descriptorName) {", "}", RuntimeTypes.Serde.deserializeList) {
write(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,13 @@ class DeserializeUnionGenerator(
override fun deserializationResultName(defaultName: String): String = "value"

// Return the type that deserializes the incoming value. Example: `MyAggregateUnion.IntList`
override fun collectionReturnExpression(memberShape: MemberShape, defaultCollectionName: String): String {
val unionTypeName = memberShape.unionTypeName(ctx)
return "$unionTypeName($defaultCollectionName)"
}
override fun collectionReturnExpression(forMemberShape: MemberShape?, defaultCollectionName: String) =
if (forMemberShape != null && forMemberShape in members) {
// We're returning a top-level collection for a member value—nest it inside a union variant
val unionTypeName = forMemberShape.unionTypeName(ctx)
"$unionTypeName($defaultCollectionName)"
} else {
// We're returning a nested collection type—don't nest it inside a union variant
super.collectionReturnExpression(null, defaultCollectionName)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,11 @@ class DeserializeUnionGeneratorTest {
val v2 = if (nextHasValue()) { deserializeBarUnionDocument(deserializer) } else { deserializeNull(); continue }
map2[k2] = v2
}
FooUnion.StrMapVal(map2)
map2
}
col1.add(el1)
}
FooUnion.StrMapVal(col1)
col1
}
} else { deserializeNull(); continue }
Expand Down Expand Up @@ -269,7 +269,7 @@ class DeserializeUnionGeneratorTest {
val el1 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
col1.add(el1)
}
MyAggregateUnion.ListOfIntList(col1)
col1
}
col0.add(el0)
}
Expand All @@ -288,7 +288,7 @@ class DeserializeUnionGeneratorTest {
val el1 = if (nextHasValue()) { deserializeInt() } else { deserializeNull(); continue }
col1.add(el1)
}
MyAggregateUnion.MapOfLists(col1)
col1
}
} else { deserializeNull(); continue }
Expand Down

0 comments on commit 277a77c

Please sign in to comment.