Skip to content

Commit

Permalink
feat: add support for smithy.protocols#rpcv2Cbor protocol (#1103)
Browse files Browse the repository at this point in the history
  • Loading branch information
lauzadis authored Jul 12, 2024
1 parent 823372c commit 14cf870
Show file tree
Hide file tree
Showing 82 changed files with 6,704 additions and 134 deletions.
8 changes: 8 additions & 0 deletions .changes/4f6cf597-a267-4bca-a8b9-98aa055a9a72.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "4f6cf597-a267-4bca-a8b9-98aa055a9a72",
"type": "feature",
"description": "Add support for `smithy.protocols#rpcv2Cbor` protocol",
"issues": [
"https://github.com/awslabs/aws-sdk-kotlin/issues/1302"
]
}
8 changes: 8 additions & 0 deletions .changes/d079d678-08d3-437a-b638-2ef11e339938.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "d079d678-08d3-437a-b638-2ef11e339938",
"type": "feature",
"description": "Add support for prioritized protocol resolution",
"issues": [
"https://github.com/smithy-lang/smithy-kotlin/issues/843"
]
}
2 changes: 2 additions & 0 deletions codegen/protocol-tests/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ val enabledProtocols = listOf(
ProtocolTest("aws-restxml", "aws.protocoltests.restxml#RestXml"),
ProtocolTest("aws-restxml-xmlns", "aws.protocoltests.restxml.xmlns#RestXmlWithNamespace"),
ProtocolTest("aws-query", "aws.protocoltests.query#AwsQuery"),
ProtocolTest("smithy-rpcv2-cbor", "smithy.protocoltests.rpcv2Cbor#RpcV2Protocol"),

// Custom hand written tests
ProtocolTest("error-correction-json", "aws.protocoltests.errorcorrection#RequiredValueJson"),
Expand Down Expand Up @@ -82,6 +83,7 @@ dependencies {
// the aws-protocol-tests dependency is found when generating code such that the `includeServices` transform
// actually works
codegen(libs.smithy.aws.protocol.tests)
codegen(libs.smithy.protocol.tests)
}

tasks.generateSmithyProjections {
Expand Down
1 change: 1 addition & 0 deletions codegen/smithy-aws-kotlin-codegen/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies {
api(libs.smithy.aws.iam.traits)
api(libs.smithy.aws.cloudformation.traits)
api(libs.smithy.protocol.test.traits)
api(libs.smithy.protocol.traits)
implementation(libs.smithy.aws.endpoints)

testImplementation(libs.junit.jupiter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ class SdkProtocolGeneratorSupplier : KotlinIntegration {
RestXml(),
AwsQuery(),
Ec2Query(),
RpcV2Cbor(),
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.kotlin.codegen.aws.protocols

import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AwsHttpBindingProtocolGenerator
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.StaticHttpBindingResolver
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.model.traits.SyntheticClone
import software.amazon.smithy.kotlin.codegen.rendering.protocol.*
import software.amazon.smithy.kotlin.codegen.rendering.serde.CborParserGenerator
import software.amazon.smithy.kotlin.codegen.rendering.serde.CborSerializerGenerator
import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataParserGenerator
import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataSerializerGenerator
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.pattern.UriPattern
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.model.traits.UnitTypeTrait
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait

class RpcV2Cbor : AwsHttpBindingProtocolGenerator() {
override val protocol: ShapeId = Rpcv2CborTrait.ID

// TODO Timestamp format is not used in RpcV2Cbor since it's a binary protocol. We seem to be missing an abstraction
// between text-based and binary-based protocols
override val defaultTimestampFormat = TimestampFormatTrait.Format.UNKNOWN

override fun getProtocolHttpBindingResolver(model: Model, serviceShape: ServiceShape): HttpBindingResolver =
RpcV2CborHttpBindingResolver(model, serviceShape)

override fun structuredDataSerializer(ctx: ProtocolGenerator.GenerationContext): StructuredDataSerializerGenerator =
CborSerializerGenerator(this)

override fun structuredDataParser(ctx: ProtocolGenerator.GenerationContext): StructuredDataParserGenerator =
CborParserGenerator(this)

override fun renderDeserializeErrorDetails(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
writer.write("#T.deserialize(payload)", RuntimeTypes.SmithyRpcV2Protocols.Cbor.RpcV2CborErrorDeserializer)
}

override fun getDefaultHttpMiddleware(ctx: ProtocolGenerator.GenerationContext): List<ProtocolMiddleware> {
// Every request MUST contain a `smithy-protocol` header with the value of `rpc-v2-cbor`
val smithyProtocolHeaderMiddleware = MutateHeadersMiddleware(overrideHeaders = mapOf("smithy-protocol" to "rpc-v2-cbor"))

// Every response MUST contain the same `smithy-protocol` header, otherwise it's considered invalid
val validateSmithyProtocolHeaderMiddleware = object : ProtocolMiddleware {
override val name: String = "RpcV2CborValidateSmithyProtocolResponseHeader"
override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
val interceptorSymbol = RuntimeTypes.SmithyRpcV2Protocols.Cbor.RpcV2CborSmithyProtocolResponseHeaderInterceptor
writer.write("op.interceptors.add(#T)", interceptorSymbol)
}
}

// Requests with event stream responses MUST include an `Accept` header set to the value `application/vnd.amazon.eventstream`
val eventStreamsAcceptHeaderMiddleware = object : ProtocolMiddleware {
private val mutateHeadersMiddleware = MutateHeadersMiddleware(extraHeaders = mapOf("Accept" to "application/vnd.amazon.eventstream"))

override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean = op.isOutputEventStream(ctx.model)
override val name: String = "RpcV2CborEventStreamsAcceptHeaderMiddleware"
override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) = mutateHeadersMiddleware.render(ctx, op, writer)
}

// Emit a metric to track usage of RpcV2Cbor
val businessMetricsMiddleware = object : ProtocolMiddleware {
override val name: String = "RpcV2CborBusinessMetricsMiddleware"
override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
writer.write("op.context.#T(#T.PROTOCOL_RPC_V2_CBOR)", RuntimeTypes.Core.BusinessMetrics.emitBusinessMetric, RuntimeTypes.Core.BusinessMetrics.SmithyBusinessMetric)
}
}

return super.getDefaultHttpMiddleware(ctx) + listOf(
smithyProtocolHeaderMiddleware,
validateSmithyProtocolHeaderMiddleware,
eventStreamsAcceptHeaderMiddleware,
businessMetricsMiddleware,
)
}

/**
* Exact copy of [HttpBindingProtocolGenerator.renderSerializeHttpBody] but with a custom
* [OperationShape.hasHttpBody] function to handle protocol-specific serialization rules.
*/
override fun renderSerializeHttpBody(
ctx: ProtocolGenerator.GenerationContext,
op: OperationShape,
writer: KotlinWriter,
) {
val resolver = getProtocolHttpBindingResolver(ctx.model, ctx.service)
if (!op.hasHttpBody(ctx)) return

// payload member(s)
val requestBindings = resolver.requestBindings(op)
val httpPayload = requestBindings.firstOrNull { it.location == HttpBinding.Location.PAYLOAD }
if (httpPayload != null) {
renderExplicitHttpPayloadSerializer(ctx, httpPayload, writer)
} else {
val documentMembers = requestBindings.filterDocumentBoundMembers()
// Unbound document members that should be serialized into the document format for the protocol.
// delegate to the generate operation body serializer function
val sdg = structuredDataSerializer(ctx)
val opBodySerializerFn = sdg.operationSerializer(ctx, op, documentMembers)
writer.write("builder.body = #T(context, input)", opBodySerializerFn)
}
renderContentTypeHeader(ctx, op, writer, resolver)
}

/**
* @return whether the operation input does _not_ target the unit shape ([UnitTypeTrait.UNIT])
*/
private fun OperationShape.hasHttpBody(ctx: ProtocolGenerator.GenerationContext): Boolean {
val input = ctx.model.expectShape(inputShape).targetOrSelf(ctx.model).let {
// If the input has been synthetically cloned from the original (most likely),
// pull the archetype and check _that_
it.getTrait<SyntheticClone>()?.let { clone ->
ctx.model.expectShape(clone.archetype).targetOrSelf(ctx.model)
} ?: it
}

return input.id != UnitTypeTrait.UNIT
}

override fun renderContentTypeHeader(
ctx: ProtocolGenerator.GenerationContext,
op: OperationShape,
writer: KotlinWriter,
resolver: HttpBindingResolver,
) {
writer.write("builder.headers.setMissing(\"Content-Type\", #S)", resolver.determineRequestContentType(op))
}

class RpcV2CborHttpBindingResolver(
model: Model,
val serviceShape: ServiceShape,
) : StaticHttpBindingResolver(
model,
serviceShape,
HttpTrait.builder().code(200).method("POST").uri(UriPattern.parse("/")).build(),
"application/cbor",
TimestampFormatTrait.Format.UNKNOWN,
) {

override fun httpTrait(operationShape: OperationShape): HttpTrait = HttpTrait
.builder()
.code(200)
.method("POST")
.uri(UriPattern.parse("/service/${serviceShape.id.name}/operation/${operationShape.id.name}"))
.build()

override fun determineRequestContentType(operationShape: OperationShape): String = when {
operationShape.isInputEventStream(model) -> "application/vnd.amazon.eventstream"
else -> "application/cbor"
}
}
}
1 change: 1 addition & 0 deletions codegen/smithy-kotlin-codegen-testutils/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ version = codegenVersion
dependencies {
implementation(kotlin("stdlib-jdk8"))
implementation(libs.smithy.aws.traits)
implementation(libs.smithy.protocol.traits)
api(project(":codegen:smithy-kotlin-codegen"))

// Test dependencies
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.utils.StringUtils

// This file houses test classes and functions relating to the code generator (protocols, serializers, etc)
Expand Down Expand Up @@ -150,6 +151,7 @@ private val allProtocols = setOf(
Ec2QueryTrait.ID,
RestJson1Trait.ID,
RestXmlTrait.ID,
Rpcv2CborTrait.ID,
)

/** An HttpBindingProtocolGenerator for testing (nothing is rendered for serializing/deserializing payload bodies) */
Expand Down
1 change: 1 addition & 0 deletions codegen/smithy-kotlin-codegen/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies {
api(libs.smithy.waiters)
implementation(libs.smithy.rules.engine)
implementation(libs.smithy.aws.traits)
implementation(libs.smithy.protocol.traits)
implementation(libs.smithy.protocol.test.traits)
implementation(libs.jsoup)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

package software.amazon.smithy.kotlin.codegen

import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.kotlin.codegen.lang.isValidPackageName
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
Expand All @@ -17,6 +23,7 @@ import software.amazon.smithy.model.node.StringNode
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import java.util.Optional
import java.util.logging.Logger
import kotlin.IllegalArgumentException
Expand All @@ -34,6 +41,17 @@ private const val API_SETTINGS = "api"
// Optional specification of sdkId for models that provide them, otherwise Service's shape id name is used
private const val SDK_ID = "sdkId"

// Prioritized list of protocols supported for code generation
private val DEFAULT_PROTOCOL_RESOLUTION_PRIORITY = setOf<ShapeId>(
Rpcv2CborTrait.ID,
AwsJson1_0Trait.ID,
AwsJson1_1Trait.ID,
RestJson1Trait.ID,
RestXmlTrait.ID,
AwsQueryTrait.ID,
Ec2QueryTrait.ID,
)

/**
* Settings used by [KotlinCodegenPlugin]
*/
Expand Down Expand Up @@ -133,9 +151,10 @@ data class KotlinSettings(
supportedProtocolTraits: Set<ShapeId>,
): ShapeId {
val resolvedProtocols: Set<ShapeId> = serviceIndex.getProtocols(service).keys
val protocol = resolvedProtocols.firstOrNull(supportedProtocolTraits::contains)
val protocol = api.protocolResolutionPriority.firstOrNull { it in resolvedProtocols && supportedProtocolTraits.contains(it) }
return protocol ?: throw UnresolvableProtocolException(
"The ${service.id} service supports the following unsupported protocols $resolvedProtocols. " +
"They were evaluated using the prioritized list: ${api.protocolResolutionPriority.joinToString()}. " +
"The following protocol generators were found on the class path: $supportedProtocolTraits",
)
}
Expand Down Expand Up @@ -195,7 +214,6 @@ data class BuildSettings(
}.orNull()
}
}.orNull()

BuildSettings(generateFullProject, generateBuildFiles, annotations, generateMultiplatformProject)
}.orElse(Default)

Expand Down Expand Up @@ -275,12 +293,14 @@ data class ApiSettings(
val nullabilityCheckMode: CheckMode = CheckMode.CLIENT_CAREFUL,
val defaultValueSerializationMode: DefaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT,
val enableEndpointAuthProvider: Boolean = false,
val protocolResolutionPriority: Set<ShapeId> = DEFAULT_PROTOCOL_RESOLUTION_PRIORITY,
) {
companion object {
const val VISIBILITY = "visibility"
const val NULLABILITY_CHECK_MODE = "nullabilityCheckMode"
const val DEFAULT_VALUE_SERIALIZATION_MODE = "defaultValueSerializationMode"
const val ENABLE_ENDPOINT_AUTH_PROVIDER = "enableEndpointAuthProvider"
const val PROTOCOL_RESOLUTION_PRIORITY = "protocolResolutionPriority"

fun fromNode(node: Optional<ObjectNode>): ApiSettings = node.map {
val visibility = node.get()
Expand All @@ -299,7 +319,14 @@ data class ApiSettings(
),
)
val enableEndpointAuthProvider = node.get().getBooleanMemberOrDefault(ENABLE_ENDPOINT_AUTH_PROVIDER, false)
ApiSettings(visibility, checkMode, defaultValueSerializationMode, enableEndpointAuthProvider)

val protocolResolutionPriority = node.get()
.getArrayMember(PROTOCOL_RESOLUTION_PRIORITY).getOrNull()
?.map { ShapeId.from(it.asStringNode().get().value) }?.toSet() ?: run {
DEFAULT_PROTOCOL_RESOLUTION_PRIORITY
}

ApiSettings(visibility, checkMode, defaultValueSerializationMode, enableEndpointAuthProvider, protocolResolutionPriority)
}.orElse(Default)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ data class KotlinDependency(
val SERDE_JSON = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS.serde.json", RUNTIME_GROUP, "serde-json", RUNTIME_VERSION)
val SERDE_XML = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS.serde.xml", RUNTIME_GROUP, "serde-xml", RUNTIME_VERSION)
val SERDE_FORM_URL = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS.serde.formurl", RUNTIME_GROUP, "serde-form-url", RUNTIME_VERSION)
val SERDE_CBOR = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS.serde.cbor", RUNTIME_GROUP, "serde-cbor", RUNTIME_VERSION)
val SMITHY_CLIENT = KotlinDependency(GradleConfiguration.Api, "$RUNTIME_ROOT_NS.client", RUNTIME_GROUP, "smithy-client", RUNTIME_VERSION)
val SMITHY_TEST = KotlinDependency(GradleConfiguration.TestImplementation, "$RUNTIME_ROOT_NS.smithy.test", RUNTIME_GROUP, "smithy-test", RUNTIME_VERSION)
val DEFAULT_HTTP_ENGINE = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS.http.engine", RUNTIME_GROUP, "http-client-engine-default", RUNTIME_VERSION)
Expand All @@ -125,6 +126,8 @@ data class KotlinDependency(
val HTTP_AUTH = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS.http.auth", RUNTIME_GROUP, "http-auth", RUNTIME_VERSION)
val HTTP_AUTH_AWS = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS.http.auth", RUNTIME_GROUP, "http-auth-aws", RUNTIME_VERSION)
val IDENTITY_API = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS", RUNTIME_GROUP, "identity-api", RUNTIME_VERSION)
val SMITHY_RPCV2_PROTOCOLS = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS.awsprotocol.rpcv2", RUNTIME_GROUP, "smithy-rpcv2-protocols", RUNTIME_VERSION)
val SMITHY_RPCV2_PROTOCOLS_CBOR = KotlinDependency(GradleConfiguration.Implementation, "$RUNTIME_ROOT_NS.awsprotocol.rpcv2.cbor", RUNTIME_GROUP, "smithy-rpcv2-protocols", RUNTIME_VERSION)

// External third-party dependencies
val KOTLIN_STDLIB = KotlinDependency(GradleConfiguration.Implementation, "kotlin", "org.jetbrains.kotlin", "kotlin-stdlib", KOTLIN_COMPILER_VERSION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,12 @@ object RuntimeTypes {
val QueryLiteral = symbol("QueryLiteral")
val FormUrlSerializer = symbol("FormUrlSerializer")
}

object SerdeCbor : RuntimeTypePackage(KotlinDependency.SERDE_CBOR) {
val CborSerializer = symbol("CborSerializer")
val CborDeserializer = symbol("CborDeserializer")
val CborSerialName = symbol("CborSerialName")
}
}

object Auth {
Expand Down Expand Up @@ -422,6 +428,13 @@ object RuntimeTypes {
val parseEc2QueryErrorResponseNoSuspend = symbol("parseEc2QueryErrorResponseNoSuspend")
}

object SmithyRpcV2Protocols : RuntimeTypePackage(KotlinDependency.SMITHY_RPCV2_PROTOCOLS) {
object Cbor : RuntimeTypePackage(KotlinDependency.SMITHY_RPCV2_PROTOCOLS_CBOR) {
val RpcV2CborErrorDeserializer = symbol("RpcV2CborErrorDeserializer")
val RpcV2CborSmithyProtocolResponseHeaderInterceptor = symbol("RpcV2CborSmithyProtocolResponseHeaderInterceptor")
}
}

object AwsEventStream : RuntimeTypePackage(KotlinDependency.AWS_EVENT_STREAM) {
val HeaderValue = symbol("HeaderValue")
val Message = symbol("Message")
Expand Down
Loading

0 comments on commit 14cf870

Please sign in to comment.