Skip to content

Commit

Permalink
Remove usage of percent_decode_str for decoding query string key-valu…
Browse files Browse the repository at this point in the history
…e pairs (#1417)

* Use `form_urlencoding::parse` over `serde_urlencoded::from_str`, this removes our dependency on `serde_urlencoded`.

* Remove `percent_encoding::percent_decode_str` where possible.

* Contract `generateParsePercentEncodedStrFn` and its children into one function and make it optionally apply percent encoding.

* Add `impl From<std::convert::Infallible> for RequestRejection` to remove codegen branch.
  • Loading branch information
hlbarber authored May 30, 2022
1 parent fe92efb commit c503deb
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
*/
object ServerCargoDependency {
val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1"))
val FormUrlEncoded: CargoDependency = CargoDependency("form_urlencoded", CratesIo("1"))
val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3"))
val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
val SerdeUrlEncoded: CargoDependency = CargoDependency("serde_urlencoded", CratesIo("0.7"))
val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))

fun SmithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-server")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.node.ExpectationNotMetException
import software.amazon.smithy.model.shapes.BooleanShape
import software.amazon.smithy.model.shapes.CollectionShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
import software.amazon.smithy.model.traits.HttpTrait
Expand Down Expand Up @@ -57,15 +58,13 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolPay
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
import software.amazon.smithy.rust.codegen.smithy.rustType
import software.amazon.smithy.rust.codegen.smithy.toOptional
import software.amazon.smithy.rust.codegen.smithy.wrapOptional
import software.amazon.smithy.rust.codegen.util.dq
import software.amazon.smithy.rust.codegen.util.expectTrait
import software.amazon.smithy.rust.codegen.util.findStreamingMember
import software.amazon.smithy.rust.codegen.util.getTrait
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
import software.amazon.smithy.rust.codegen.util.hasTrait
import software.amazon.smithy.rust.codegen.util.inputShape
import software.amazon.smithy.rust.codegen.util.isStreaming
import software.amazon.smithy.rust.codegen.util.outputShape
Expand Down Expand Up @@ -120,14 +119,14 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
"AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(),
"Cow" to ServerRuntimeType.Cow,
"DateTime" to RuntimeType.DateTime(runtimeConfig),
"FormUrlEncoded" to ServerCargoDependency.FormUrlEncoded.asType(),
"HttpBody" to CargoDependency.HttpBody.asType(),
"header_util" to CargoDependency.SmithyHttp(runtimeConfig).asType().member("header"),
"Hyper" to CargoDependency.Hyper.asType(),
"LazyStatic" to CargoDependency.LazyStatic.asType(),
"Nom" to ServerCargoDependency.Nom.asType(),
"PercentEncoding" to CargoDependency.PercentEncoding.asType(),
"Regex" to CargoDependency.Regex.asType(),
"SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
"SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(),
"SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(),
"RuntimeError" to ServerRuntimeType.RuntimeError(runtimeConfig),
Expand Down Expand Up @@ -775,7 +774,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
.forEachIndexed { index, segment ->
val binding = pathBindings.find { it.memberName == segment.content }
if (binding != null && segment.isLabel) {
val deserializer = generateParsePercentEncodedStrFn(binding)
val deserializer = generateParseFn(binding, true)
rustTemplate(
"""
input = input.${binding.member.setterName()}(
Expand Down Expand Up @@ -847,7 +846,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
rustTemplate(
"""
let query_string = request.uri().query().unwrap_or("");
let pairs = #{SerdeUrlEncoded}::from_str::<Vec<(#{Cow}<'_, str>, #{Cow}<'_, str>)>>(query_string)?;
let pairs = #{FormUrlEncoded}::parse(query_string.as_bytes());
""".trimIndent(),
*codegenScope
)
Expand All @@ -870,7 +869,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(

rustBlock("for (k, v) in pairs") {
queryBindingsTargettingSimple.forEach {
val deserializer = generateParsePercentEncodedStrFn(it)
val deserializer = generateParseFn(it, false)
val memberName = symbolProvider.toMemberName(it.member)
rustTemplate(
"""
Expand All @@ -891,25 +890,15 @@ private class ServerHttpBoundProtocolTraitImplGenerator(

when {
memberShape.isStringShape -> {
// `<_>::from()/try_from()` is necessary to convert the `&str` into:
// * the Rust enum in case the `string` shape has the `enum` trait; or
// * `String` in case it doesn't.
if (memberShape.hasTrait<EnumTrait>()) {
rustTemplate(
"""
let v = <#{memberShape}>::try_from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref())?;
""",
*codegenScope,
"memberShape" to symbolProvider.toSymbol(memberShape),
)
} else {
rustTemplate(
"""
let v = <_>::from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref());
""".trimIndent(),
*codegenScope
)
}
// NOTE: This path is traversed with or without @enum applied. The `try_from` is used
// as a common conversion.
rustTemplate(
"""
let v = <#{memberShape}>::try_from(v.as_ref())?;
""",
*codegenScope,
"memberShape" to symbolProvider.toSymbol(memberShape),
)
}
memberShape.isTimestampShape -> {
val index = HttpBindingIndex.of(model)
Expand All @@ -922,7 +911,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
rustTemplate(
"""
let v = #{PercentEncoding}::percent_decode_str(&v).decode_utf8()?;
let v = #{DateTime}::from_str(&v, #{format})?;
""".trimIndent(),
*codegenScope,
Expand Down Expand Up @@ -1013,21 +1001,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
)
}

// TODO(https://github.com/awslabs/smithy-rs/issues/1231): If this function was called to parse a query string
// key value pair, we don't need to percent-decode it _again_.
private fun generateParsePercentEncodedStrFn(binding: HttpBindingDescriptor): RuntimeType {
// HTTP bindings we support that contain percent-encoded data.
check(binding.location == HttpLocation.LABEL || binding.location == HttpLocation.QUERY)

val target = model.expectShape(binding.member.target)
return when {
target.isStringShape -> generateParsePercentEncodedStrAsStringFn(binding)
target.isTimestampShape -> generateParsePercentEncodedStrAsTimestampFn(binding)
else -> generateParseStrAsPrimitiveFn(binding)
}
}

private fun generateParsePercentEncodedStrAsStringFn(binding: HttpBindingDescriptor): RuntimeType {
private fun generateParseFn(binding: HttpBindingDescriptor, percentDecoding: Boolean): RuntimeType {
val output = symbolProvider.toSymbol(binding.member)
val fnName = generateParseStrFnName(binding)
val symbol = output.extractSymbolFromOption()
Expand All @@ -1037,85 +1011,74 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
*codegenScope,
"O" to output,
) {
// `<_>::from()` is necessary to convert the `&str` into:
// * the Rust enum in case the `string` shape has the `enum` trait; or
// * `String` in case it doesn't.
when (symbol.rustType()) {
RustType.String ->
rustTemplate(
"""
let value = <#{T}>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref());
""",
*codegenScope,
"T" to symbol,
)
else -> { // RustType.Opaque, the Enum
check(symbol.rustType() is RustType.Opaque)
val target = model.expectShape(binding.member.target)

when {
target.isStringShape -> {
// NOTE: This path is traversed with or without @enum applied. The `try_from` is used as a
// common conversion.
if (percentDecoding) {
rustTemplate(
"""
let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
let value = #{T}::try_from(value.as_ref())?;
""",
*codegenScope,
"T" to symbol,
)
} else {
rustTemplate(
"""
let value = #{T}::try_from(value)?;
""",
"T" to symbol,
)
}
}
target.isTimestampShape -> {
val index = HttpBindingIndex.of(model)
val timestampFormat =
index.determineTimestampFormat(
binding.member,
binding.location,
protocol.defaultTimestampFormat,
)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)

if (percentDecoding) {
rustTemplate(
"""
let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
let value = #{DateTime}::from_str(value.as_ref(), #{format})?;
""",
*codegenScope,
"format" to timestampFormatType,
)
} else {
rustTemplate(
"""
let value = #{DateTime}::from_str(value, #{format})?;
""",
*codegenScope,
"format" to timestampFormatType,
)
}
}
else -> {
check(target is NumberShape || target is BooleanShape)
rustTemplate(
"""
let value = <#{T}>::try_from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref())?;
let value = std::str::FromStr::from_str(value)?;
""",
*codegenScope,
"T" to symbol,
)
}
}
writer.write(
"""
Ok(${symbolProvider.wrapOptional(binding.member, "value")})
"""
)
}
}
}

private fun generateParsePercentEncodedStrAsTimestampFn(binding: HttpBindingDescriptor): RuntimeType {
val output = symbolProvider.toSymbol(binding.member)
val fnName = generateParseStrFnName(binding)
val index = HttpBindingIndex.of(model)
val timestampFormat =
index.determineTimestampFormat(
binding.member,
binding.location,
protocol.defaultTimestampFormat,
)
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer ->
writer.rustBlockTemplate(
"pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{RequestRejection}>",
*codegenScope,
"O" to output,
) {
rustTemplate(
writer.write(
"""
let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
let value = #{DateTime}::from_str(&value, #{format})?;
Ok(${symbolProvider.wrapOptional(binding.member, "value")})
""".trimIndent(),
*codegenScope,
"format" to timestampFormatType,
)
}
}
}

// Function to parse a string as the data type generated for boolean, byte, short, integer, long, float, or double shapes.
// TODO(https://github.com/awslabs/smithy-rs/issues/1232): This function can be replaced by https://docs.rs/aws-smithy-types/latest/aws_smithy_types/primitive/trait.Parse.html
private fun generateParseStrAsPrimitiveFn(binding: HttpBindingDescriptor): RuntimeType {
val output = symbolProvider.toSymbol(binding.member)
val fnName = generateParseStrFnName(binding)
return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer ->
writer.rustBlockTemplate(
"pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{RequestRejection}>",
*codegenScope,
"O" to output,
) {
rustTemplate(
"""
let value = std::str::FromStr::from_str(value)?;
Ok(${symbolProvider.wrapOptional(binding.member, "value")})
""".trimIndent(),
*codegenScope,
)
}
}
Expand Down
19 changes: 19 additions & 0 deletions rust-runtime/aws-smithy-http-server/src/rejection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@ pub enum RequestRejection {

impl std::error::Error for RequestRejection {}

// Consider a conversion between `T` and `U` followed by a bubbling up of the conversion error
// through `Result<_, RequestRejection>`. This [`From`] implementation accomodates the special case
// where `T` and `U` are equal, in such cases `T`/`U` a enjoy `TryFrom<T>` with
// `Err = std::convert::Infallible`.
//
// Note that when `!` stabilizes `std::convert::Infallible` will become an alias for `!` and there
// will be a blanket `impl From<!> for T`. This will remove the need for this implementation.
//
// More details on this can be found in the following links:
// - https://doc.rust-lang.org/std/primitive.never.html
// - https://doc.rust-lang.org/std/convert/enum.Infallible.html#future-compatibility
impl From<std::convert::Infallible> for RequestRejection {
fn from(_err: std::convert::Infallible) -> Self {
// We opt for this `match` here rather than [`unreachable`] to assure the reader that this
// code path is dead.
match _err {}
}
}

// These converters are solely to make code-generation simpler. They convert from a specific error
// type (from a runtime/third-party crate or the standard library) into a variant of the
// [`crate::rejection::RequestRejection`] enum holding the type-erased boxed [`crate::Error`]
Expand Down

0 comments on commit c503deb

Please sign in to comment.