Skip to content

Commit

Permalink
refactor: decrease generated artifact size (#1057)
Browse files Browse the repository at this point in the history
  • Loading branch information
aajtodd authored Apr 2, 2024
1 parent 281d415 commit 5acf7ef
Show file tree
Hide file tree
Showing 34 changed files with 380 additions and 160 deletions.
5 changes: 5 additions & 0 deletions .changes/1332be89-09d8-4b30-9e42-6f7a353c4c72.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "1332be89-09d8-4b30-9e42-6f7a353c4c72",
"type": "misc",
"description": "Decrease generated client artifact sizes by reducing the number of suspension points for operations and inlining commonly used HTTP builders"
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class AwsQuery : QueryHttpBindingProtocolGenerator() {
writer: KotlinWriter,
) {
writer.write("""checkNotNull(payload){ "unable to parse error from empty response" }""")
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponse)
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponseNoSuspend)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Ec2Query : QueryHttpBindingProtocolGenerator() {
writer: KotlinWriter,
) {
writer.write("""checkNotNull(payload){ "unable to parse error from empty response" }""")
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseEc2QueryErrorResponse)
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseEc2QueryErrorResponseNoSuspend)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ open class RestXml : AwsHttpBindingProtocolGenerator() {
writer: KotlinWriter,
) {
writer.write("""checkNotNull(payload){ "unable to parse error from empty response" }""")
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponse)
writer.write("#T(payload)", RuntimeTypes.AwsXmlProtocols.parseRestXmlErrorResponseNoSuspend)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
override fun operationErrorHandler(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Symbol =
op.errorHandler(ctx.settings) { writer ->
writer.withBlock(
"private suspend fun ${op.errorHandlerName()}(context: #T, call: #T): #Q {",
"private fun ${op.errorHandlerName()}(context: #T, call: #T, payload: #T?): #Q {",
"}",
RuntimeTypes.Core.ExecutionContext,
RuntimeTypes.Http.HttpCall,
KotlinTypes.ByteArray,
KotlinTypes.Nothing,
) {
renderThrowOperationError(ctx, op, writer)
Expand All @@ -107,8 +108,7 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
),
) {
val exceptionBaseSymbol = ExceptionBaseClassGenerator.baseExceptionSymbol(ctx.settings)
writer.write("val payload = call.response.body.#T()", RuntimeTypes.Http.readAll)
.write("val wrappedResponse = call.response.#T(payload)", RuntimeTypes.AwsProtocolCore.withPayload)
writer.write("val wrappedResponse = call.response.#T(payload)", RuntimeTypes.AwsProtocolCore.withPayload)
.write("val wrappedCall = call.copy(response = wrappedResponse)")
.write("")
.declareSection(
Expand Down Expand Up @@ -151,7 +151,7 @@ abstract class AwsHttpBindingProtocolGenerator : HttpBindingProtocolGenerator()
name = "${errSymbol.name}Deserializer"
namespace = ctx.settings.pkg.serde
}
writer.write("#S -> #T().deserialize(context, wrappedCall)", getErrorCode(ctx, err), errDeserializerSymbol)
writer.write("#S -> #T().deserialize(context, wrappedCall, payload)", getErrorCode(ctx, err), errDeserializerSymbol)
}
write("else -> #T(errorDetails.message)", exceptionBaseSymbol)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ object RuntimeTypes {
val EndpointResolver = symbol("EndpointResolver")
val ResolveEndpointRequest = symbol("ResolveEndpointRequest")
val execute = symbol("execute")
val HttpDeserialize = symbol("HttpDeserialize")
val HttpDeserializer = symbol("HttpDeserializer")
val HttpOperationContext = symbol("HttpOperationContext")
val HttpSerialize = symbol("HttpSerialize")
val HttpSerializer = symbol("HttpSerializer")
val OperationAuthConfig = symbol("OperationAuthConfig")
val OperationMetrics = symbol("OperationMetrics")
val OperationRequest = symbol("OperationRequest")
Expand Down Expand Up @@ -407,8 +407,8 @@ object RuntimeTypes {
val RestJsonErrorDeserializer = symbol("RestJsonErrorDeserializer")
}
object AwsXmlProtocols : RuntimeTypePackage(KotlinDependency.AWS_XML_PROTOCOLS) {
val parseRestXmlErrorResponse = symbol("parseRestXmlErrorResponse")
val parseEc2QueryErrorResponse = symbol("parseEc2QueryErrorResponse")
val parseRestXmlErrorResponseNoSuspend = symbol("parseRestXmlErrorResponseNoSuspend")
val parseEc2QueryErrorResponseNoSuspend = symbol("parseEc2QueryErrorResponseNoSuspend")
}

object AwsEventStream : RuntimeTypePackage(KotlinDependency.AWS_EVENT_STREAM) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
* The function should have the following signature:
*
* ```
* suspend fun throwFooOperationError(context: ExecutionContext, call: HttpCall): Nothing {
* fun throwFooOperationError(context: ExecutionContext, call: HttpCall, payload: ByteArray?): Nothing {
* <-- CURRENT WRITER CONTEXT -->
* }
* ```
Expand Down Expand Up @@ -169,20 +169,25 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
val operationSerializerSymbols = setOf(
RuntimeTypes.Http.HttpBody,
RuntimeTypes.Http.HttpMethod,
RuntimeTypes.HttpClient.Operation.HttpSerialize,
RuntimeTypes.Http.Request.HttpRequestBuilder,
RuntimeTypes.Http.Request.url,
)

val serdeMeta = HttpSerdeMeta(op.isInputEventStream(ctx.model))

ctx.delegator.useSymbolWriter(serializerSymbol) { writer ->
// import all of http, http.request, and serde packages. All serializers requires one or more of the symbols
// and most require quite a few. Rather than try and figure out which specific ones are used just take them
// all to ensure all the various DSL builders are available, etc
writer
.addImport(operationSerializerSymbols)
.write("")
.openBlock("internal class #T: #T<#T> {", serializerSymbol, RuntimeTypes.HttpClient.Operation.HttpSerialize, inputSymbol)
.openBlock("internal class #T: #T.#L<#T> {", serializerSymbol, RuntimeTypes.HttpClient.Operation.HttpSerializer, serdeMeta.variantName, inputSymbol)
.call {
writer.openBlock("override suspend fun serialize(context: #T, input: #T): #T {", RuntimeTypes.Core.ExecutionContext, inputSymbol, RuntimeTypes.Http.Request.HttpRequestBuilder)
val modifier = if (serdeMeta.isStreaming) "suspend " else ""
writer.openBlock(
"override #Lfun serialize(context: #T, input: #T): #T {",
modifier,
RuntimeTypes.Core.ExecutionContext,
inputSymbol,
RuntimeTypes.Http.Request.HttpRequestBuilder,
)
.write("val builder = #T()", RuntimeTypes.Http.Request.HttpRequestBuilder)
.call {
renderHttpSerialize(ctx, op, writer)
Expand Down Expand Up @@ -546,18 +551,22 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {

val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service)
val responseBindings = resolver.responseBindings(op)

val serdeMeta = httpDeserializerInfo(ctx, op)

ctx.delegator.useSymbolWriter(deserializerSymbol) { writer ->
writer
.write("")
.openBlock(
"internal class #T: #T<#T> {",
"internal class #T: #T.#L<#T> {",
deserializerSymbol,
RuntimeTypes.HttpClient.Operation.HttpDeserialize,
RuntimeTypes.HttpClient.Operation.HttpDeserializer,
serdeMeta.variantName,
outputSymbol,
)
.write("")
.call {
renderHttpDeserialize(ctx, outputSymbol, responseBindings, op, writer)
renderHttpDeserialize(ctx, outputSymbol, responseBindings, serdeMeta, op, writer)
}
.closeBlock("}")
}
Expand All @@ -569,8 +578,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
protected open fun renderIsHttpError(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
writer.addImport(RuntimeTypes.Http.isSuccess)
writer.withBlock("if (!response.status.#T()) {", "}", RuntimeTypes.Http.isSuccess) {
val serdeMeta = httpDeserializerInfo(ctx, op)
if (serdeMeta.isStreaming) {
writer.write("val payload = response.body.#T()", RuntimeTypes.Http.readAll)
}
val errorHandlerFn = operationErrorHandler(ctx, op)
write("#T(context, call)", errorHandlerFn)
write("#T(context, call, payload)", errorHandlerFn)
}
}

Expand All @@ -587,7 +600,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
RuntimeTypes.Serde.SerialKind,
RuntimeTypes.Serde.deserializeStruct,
RuntimeTypes.Http.Response.HttpResponse,
RuntimeTypes.HttpClient.Operation.HttpDeserialize,
)

val deserializerSymbol = buildSymbol {
Expand All @@ -598,16 +610,19 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
reference(outputSymbol, SymbolReference.ContextOption.DECLARE)
}

// exception deserializers are never streaming
val serdeMeta = HttpSerdeMeta(false)

ctx.delegator.useSymbolWriter(deserializerSymbol) { writer ->
val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service)
val responseBindings = resolver.responseBindings(shape)
writer
.addImport(exceptionDeserializerSymbols)
.write("")
.openBlock("internal class #T: #T<#T> {", deserializerSymbol, RuntimeTypes.HttpClient.Operation.HttpDeserialize, outputSymbol)
.openBlock("internal class #T: #T.NonStreaming<#T> {", deserializerSymbol, RuntimeTypes.HttpClient.Operation.HttpDeserializer, outputSymbol)
.write("")
.call {
renderHttpDeserialize(ctx, outputSymbol, responseBindings, null, writer)
renderHttpDeserialize(ctx, outputSymbol, responseBindings, serdeMeta, null, writer)
}
.closeBlock("}")
}
Expand All @@ -617,18 +632,31 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
ctx: ProtocolGenerator.GenerationContext,
outputSymbol: Symbol,
responseBindings: List<HttpBindingDescriptor>,
serdeMeta: HttpSerdeMeta,
// this method is shared between operation and exception deserialization. In the case of operations this MUST be set
op: OperationShape?,
writer: KotlinWriter,
) {
writer
.openBlock(
"override suspend fun deserialize(context: #T, call: #T): #T {",
RuntimeTypes.Core.ExecutionContext,
RuntimeTypes.Http.HttpCall,
outputSymbol,
)
.write("val response = call.response")
if (serdeMeta.isStreaming) {
writer
.openBlock(
"override suspend fun deserialize(context: #T, call: #T): #T {",
RuntimeTypes.Core.ExecutionContext,
RuntimeTypes.Http.HttpCall,
outputSymbol,
)
} else {
writer
.openBlock(
"override fun deserialize(context: #T, call: #T, payload: #T?): #T {",
RuntimeTypes.Core.ExecutionContext,
RuntimeTypes.Http.HttpCall,
KotlinTypes.ByteArray,
outputSymbol,
)
}

writer.write("val response = call.response")
.call {
if (outputSymbol.shape?.isError == false && op != null) {
// handle operation errors
Expand Down Expand Up @@ -657,7 +685,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
if (op != null && op.isOutputEventStream(ctx.model)) {
deserializeViaEventStream(ctx, op, writer)
} else {
deserializeViaPayload(ctx, outputSymbol, responseBindings, op, writer)
deserializeViaPayload(ctx, outputSymbol, responseBindings, serdeMeta, op, writer)
}
}
.call {
Expand All @@ -681,6 +709,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
ctx: ProtocolGenerator.GenerationContext,
outputSymbol: Symbol,
responseBindings: List<HttpBindingDescriptor>,
serdeMeta: HttpSerdeMeta,
// this method is shared between operation and exception deserialization. In the case of operations this MUST be set
op: OperationShape?,
writer: KotlinWriter,
Expand All @@ -707,10 +736,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
sdg.errorDeserializer(ctx, outputSymbol.shape as StructureShape, documentMembers)
}

writer.write("val payload = response.body.#T()", RuntimeTypes.Http.readAll)
.withBlock("if (payload != null) {", "}") {
if (!serdeMeta.isStreaming) {
writer.withBlock("if (payload != null) {", "}") {
write("#T(builder, payload)", bodyDeserializerFn)
}
}
}
}
}
Expand Down Expand Up @@ -872,7 +902,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
""
}

// writer.addImport("${KotlinDependency.CLIENT_RT_HTTP.namespace}.util", splitFn)
writer
.addImport(splitFn, KotlinDependency.HTTP, subpackage = "util")
.write("builder.#L = response.headers.getAll(#S)?.flatMap(::$splitFn)$mapFn", memberName, headerName)
Expand Down Expand Up @@ -940,9 +969,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
val memberName = binding.member.defaultName()
val target = ctx.model.expectShape(binding.member.target)
val targetSymbol = ctx.symbolProvider.toSymbol(target)

// NOTE: we don't need serde metadata to know what to do here. Everything is non-streaming except streaming
// blob payloads.
when (target.type) {
ShapeType.STRING -> {
writer.write("val contents = response.body.#T()?.decodeToString()", RuntimeTypes.Http.readAll)
writer.write("val contents = payload?.decodeToString()")
if (target.isEnum) {
writer.write("builder.$memberName = contents?.let { #T.fromValue(it) }", targetSymbol)
} else {
Expand All @@ -951,36 +983,32 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
}

ShapeType.ENUM -> {
writer.write("val contents = response.body.#T()?.decodeToString()", RuntimeTypes.Http.readAll)
writer.write("val contents = payload?.decodeToString()")
writer.write("builder.#L = contents?.let { #T.fromValue(it) }", memberName, targetSymbol)
}

ShapeType.INT_ENUM -> {
writer.write("val contents = response.body.#T()?.decodeToString()", RuntimeTypes.Http.readAll)
writer.write("val contents = payload?.decodeToString()")
writer.write("builder.#L = contents?.let { #T.fromValue(it.toInt()) }", memberName, targetSymbol)
}

ShapeType.BLOB -> {
val isBinaryStream = target.hasTrait<StreamingTrait>()
val conversion = if (isBinaryStream) {
writer.addImport(RuntimeTypes.Http.toByteStream)
"toByteStream()"
if (isBinaryStream) {
writer.write("builder.#L = response.body.#T()", memberName, RuntimeTypes.Http.toByteStream)
} else {
writer.addImport(RuntimeTypes.Http.readAll)
"readAll()"
writer.write("builder.#L = payload", memberName)
}
writer.write("builder.$memberName = response.body.$conversion")
}

ShapeType.STRUCTURE, ShapeType.UNION, ShapeType.DOCUMENT -> {
// delegate to the payload deserializer
val sdg = structuredDataParser(ctx)
val payloadDeserializerFn = sdg.payloadDeserializer(ctx, binding.member)

writer.write("val payload = response.body.#T()", RuntimeTypes.Http.readAll)
.withBlock("if (payload != null) {", "}") {
write("builder.#L = #T(payload)", memberName, payloadDeserializerFn)
}
writer.withBlock("if (payload != null) {", "}") {
write("builder.#L = #T(payload)", memberName, payloadDeserializerFn)
}
}

else ->
Expand Down Expand Up @@ -1061,3 +1089,18 @@ private fun renderNonBlankGuard(ctx: ProtocolGenerator.GenerationContext, member
private fun MemberShape.isNonBlankInStruct(ctx: ProtocolGenerator.GenerationContext): Boolean =
ctx.model.expectShape(target).isStringShape &&
getTrait<LengthTrait>()?.min?.getOrNull()?.takeIf { it > 0 } != null

private data class HttpSerdeMeta(val isStreaming: Boolean) {
/**
* The name of the HttpSerializer<T>/HttpDeserializer<T> variant
*/
val variantName: String
get() = if (isStreaming) "Streaming" else "NonStreaming"
}

private fun httpDeserializerInfo(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): HttpSerdeMeta {
val isStreaming = ctx.model.expectShape<StructureShape>(op.output.get()).hasStreamingMember(ctx.model) ||
op.isOutputEventStream(ctx.model)

return HttpSerdeMeta(isStreaming)
}
Loading

0 comments on commit 5acf7ef

Please sign in to comment.