Skip to content

Commit c503deb

Browse files
authored
Remove usage of percent_decode_str for decoding query string key-value pairs (#1417)
* Use `form_urlencoding::parse` over `serde_urlencoded::from_str`, this removes our dependency on `serde_urlencoded`. * Remove `percent_encoding::percent_decode_str` where possible. * Contract `generateParsePercentEncodedStrFn` and its children into one function and make it optionally apply percent encoding. * Add `impl From<std::convert::Infallible> for RequestRejection` to remove codegen branch.
1 parent fe92efb commit c503deb

File tree

3 files changed

+93
-111
lines changed

3 files changed

+93
-111
lines changed

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ import software.amazon.smithy.rust.codegen.smithy.RuntimeConfig
1616
*/
1717
object ServerCargoDependency {
1818
val AsyncTrait: CargoDependency = CargoDependency("async-trait", CratesIo("0.1"))
19+
val FormUrlEncoded: CargoDependency = CargoDependency("form_urlencoded", CratesIo("1"))
1920
val FuturesUtil: CargoDependency = CargoDependency("futures-util", CratesIo("0.3"))
2021
val Nom: CargoDependency = CargoDependency("nom", CratesIo("7"))
2122
val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2"))
22-
val SerdeUrlEncoded: CargoDependency = CargoDependency("serde_urlencoded", CratesIo("0.7"))
2323
val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4"))
2424

2525
fun SmithyHttpServer(runtimeConfig: RuntimeConfig) = runtimeConfig.runtimeCrate("http-server")

codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/protocols/ServerHttpBoundProtocolGenerator.kt

Lines changed: 73 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
1212
import software.amazon.smithy.codegen.core.Symbol
1313
import software.amazon.smithy.model.knowledge.HttpBindingIndex
1414
import software.amazon.smithy.model.node.ExpectationNotMetException
15+
import software.amazon.smithy.model.shapes.BooleanShape
1516
import software.amazon.smithy.model.shapes.CollectionShape
17+
import software.amazon.smithy.model.shapes.NumberShape
1618
import software.amazon.smithy.model.shapes.OperationShape
1719
import software.amazon.smithy.model.shapes.Shape
1820
import software.amazon.smithy.model.shapes.StringShape
1921
import software.amazon.smithy.model.shapes.StructureShape
20-
import software.amazon.smithy.model.traits.EnumTrait
2122
import software.amazon.smithy.model.traits.ErrorTrait
2223
import software.amazon.smithy.model.traits.HttpErrorTrait
2324
import software.amazon.smithy.model.traits.HttpTrait
@@ -57,15 +58,13 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolPay
5758
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
5859
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
5960
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
60-
import software.amazon.smithy.rust.codegen.smithy.rustType
6161
import software.amazon.smithy.rust.codegen.smithy.toOptional
6262
import software.amazon.smithy.rust.codegen.smithy.wrapOptional
6363
import software.amazon.smithy.rust.codegen.util.dq
6464
import software.amazon.smithy.rust.codegen.util.expectTrait
6565
import software.amazon.smithy.rust.codegen.util.findStreamingMember
6666
import software.amazon.smithy.rust.codegen.util.getTrait
6767
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
68-
import software.amazon.smithy.rust.codegen.util.hasTrait
6968
import software.amazon.smithy.rust.codegen.util.inputShape
7069
import software.amazon.smithy.rust.codegen.util.isStreaming
7170
import software.amazon.smithy.rust.codegen.util.outputShape
@@ -120,14 +119,14 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
120119
"AsyncTrait" to ServerCargoDependency.AsyncTrait.asType(),
121120
"Cow" to ServerRuntimeType.Cow,
122121
"DateTime" to RuntimeType.DateTime(runtimeConfig),
122+
"FormUrlEncoded" to ServerCargoDependency.FormUrlEncoded.asType(),
123123
"HttpBody" to CargoDependency.HttpBody.asType(),
124124
"header_util" to CargoDependency.SmithyHttp(runtimeConfig).asType().member("header"),
125125
"Hyper" to CargoDependency.Hyper.asType(),
126126
"LazyStatic" to CargoDependency.LazyStatic.asType(),
127127
"Nom" to ServerCargoDependency.Nom.asType(),
128128
"PercentEncoding" to CargoDependency.PercentEncoding.asType(),
129129
"Regex" to CargoDependency.Regex.asType(),
130-
"SerdeUrlEncoded" to ServerCargoDependency.SerdeUrlEncoded.asType(),
131130
"SmithyHttp" to CargoDependency.SmithyHttp(runtimeConfig).asType(),
132131
"SmithyHttpServer" to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(),
133132
"RuntimeError" to ServerRuntimeType.RuntimeError(runtimeConfig),
@@ -775,7 +774,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
775774
.forEachIndexed { index, segment ->
776775
val binding = pathBindings.find { it.memberName == segment.content }
777776
if (binding != null && segment.isLabel) {
778-
val deserializer = generateParsePercentEncodedStrFn(binding)
777+
val deserializer = generateParseFn(binding, true)
779778
rustTemplate(
780779
"""
781780
input = input.${binding.member.setterName()}(
@@ -847,7 +846,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
847846
rustTemplate(
848847
"""
849848
let query_string = request.uri().query().unwrap_or("");
850-
let pairs = #{SerdeUrlEncoded}::from_str::<Vec<(#{Cow}<'_, str>, #{Cow}<'_, str>)>>(query_string)?;
849+
let pairs = #{FormUrlEncoded}::parse(query_string.as_bytes());
851850
""".trimIndent(),
852851
*codegenScope
853852
)
@@ -870,7 +869,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
870869

871870
rustBlock("for (k, v) in pairs") {
872871
queryBindingsTargettingSimple.forEach {
873-
val deserializer = generateParsePercentEncodedStrFn(it)
872+
val deserializer = generateParseFn(it, false)
874873
val memberName = symbolProvider.toMemberName(it.member)
875874
rustTemplate(
876875
"""
@@ -891,25 +890,15 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
891890

892891
when {
893892
memberShape.isStringShape -> {
894-
// `<_>::from()/try_from()` is necessary to convert the `&str` into:
895-
// * the Rust enum in case the `string` shape has the `enum` trait; or
896-
// * `String` in case it doesn't.
897-
if (memberShape.hasTrait<EnumTrait>()) {
898-
rustTemplate(
899-
"""
900-
let v = <#{memberShape}>::try_from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref())?;
901-
""",
902-
*codegenScope,
903-
"memberShape" to symbolProvider.toSymbol(memberShape),
904-
)
905-
} else {
906-
rustTemplate(
907-
"""
908-
let v = <_>::from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref());
909-
""".trimIndent(),
910-
*codegenScope
911-
)
912-
}
893+
// NOTE: This path is traversed with or without @enum applied. The `try_from` is used
894+
// as a common conversion.
895+
rustTemplate(
896+
"""
897+
let v = <#{memberShape}>::try_from(v.as_ref())?;
898+
""",
899+
*codegenScope,
900+
"memberShape" to symbolProvider.toSymbol(memberShape),
901+
)
913902
}
914903
memberShape.isTimestampShape -> {
915904
val index = HttpBindingIndex.of(model)
@@ -922,7 +911,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
922911
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
923912
rustTemplate(
924913
"""
925-
let v = #{PercentEncoding}::percent_decode_str(&v).decode_utf8()?;
926914
let v = #{DateTime}::from_str(&v, #{format})?;
927915
""".trimIndent(),
928916
*codegenScope,
@@ -1013,21 +1001,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
10131001
)
10141002
}
10151003

1016-
// TODO(https://github.com/awslabs/smithy-rs/issues/1231): If this function was called to parse a query string
1017-
// key value pair, we don't need to percent-decode it _again_.
1018-
private fun generateParsePercentEncodedStrFn(binding: HttpBindingDescriptor): RuntimeType {
1019-
// HTTP bindings we support that contain percent-encoded data.
1020-
check(binding.location == HttpLocation.LABEL || binding.location == HttpLocation.QUERY)
1021-
1022-
val target = model.expectShape(binding.member.target)
1023-
return when {
1024-
target.isStringShape -> generateParsePercentEncodedStrAsStringFn(binding)
1025-
target.isTimestampShape -> generateParsePercentEncodedStrAsTimestampFn(binding)
1026-
else -> generateParseStrAsPrimitiveFn(binding)
1027-
}
1028-
}
1029-
1030-
private fun generateParsePercentEncodedStrAsStringFn(binding: HttpBindingDescriptor): RuntimeType {
1004+
private fun generateParseFn(binding: HttpBindingDescriptor, percentDecoding: Boolean): RuntimeType {
10311005
val output = symbolProvider.toSymbol(binding.member)
10321006
val fnName = generateParseStrFnName(binding)
10331007
val symbol = output.extractSymbolFromOption()
@@ -1037,85 +1011,74 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
10371011
*codegenScope,
10381012
"O" to output,
10391013
) {
1040-
// `<_>::from()` is necessary to convert the `&str` into:
1041-
// * the Rust enum in case the `string` shape has the `enum` trait; or
1042-
// * `String` in case it doesn't.
1043-
when (symbol.rustType()) {
1044-
RustType.String ->
1045-
rustTemplate(
1046-
"""
1047-
let value = <#{T}>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref());
1048-
""",
1049-
*codegenScope,
1050-
"T" to symbol,
1051-
)
1052-
else -> { // RustType.Opaque, the Enum
1053-
check(symbol.rustType() is RustType.Opaque)
1014+
val target = model.expectShape(binding.member.target)
1015+
1016+
when {
1017+
target.isStringShape -> {
1018+
// NOTE: This path is traversed with or without @enum applied. The `try_from` is used as a
1019+
// common conversion.
1020+
if (percentDecoding) {
1021+
rustTemplate(
1022+
"""
1023+
let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
1024+
let value = #{T}::try_from(value.as_ref())?;
1025+
""",
1026+
*codegenScope,
1027+
"T" to symbol,
1028+
)
1029+
} else {
1030+
rustTemplate(
1031+
"""
1032+
let value = #{T}::try_from(value)?;
1033+
""",
1034+
"T" to symbol,
1035+
)
1036+
}
1037+
}
1038+
target.isTimestampShape -> {
1039+
val index = HttpBindingIndex.of(model)
1040+
val timestampFormat =
1041+
index.determineTimestampFormat(
1042+
binding.member,
1043+
binding.location,
1044+
protocol.defaultTimestampFormat,
1045+
)
1046+
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
1047+
1048+
if (percentDecoding) {
1049+
rustTemplate(
1050+
"""
1051+
let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
1052+
let value = #{DateTime}::from_str(value.as_ref(), #{format})?;
1053+
""",
1054+
*codegenScope,
1055+
"format" to timestampFormatType,
1056+
)
1057+
} else {
1058+
rustTemplate(
1059+
"""
1060+
let value = #{DateTime}::from_str(value, #{format})?;
1061+
""",
1062+
*codegenScope,
1063+
"format" to timestampFormatType,
1064+
)
1065+
}
1066+
}
1067+
else -> {
1068+
check(target is NumberShape || target is BooleanShape)
10541069
rustTemplate(
10551070
"""
1056-
let value = <#{T}>::try_from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref())?;
1071+
let value = std::str::FromStr::from_str(value)?;
10571072
""",
10581073
*codegenScope,
1059-
"T" to symbol,
10601074
)
10611075
}
10621076
}
1063-
writer.write(
1064-
"""
1065-
Ok(${symbolProvider.wrapOptional(binding.member, "value")})
1066-
"""
1067-
)
1068-
}
1069-
}
1070-
}
10711077

1072-
private fun generateParsePercentEncodedStrAsTimestampFn(binding: HttpBindingDescriptor): RuntimeType {
1073-
val output = symbolProvider.toSymbol(binding.member)
1074-
val fnName = generateParseStrFnName(binding)
1075-
val index = HttpBindingIndex.of(model)
1076-
val timestampFormat =
1077-
index.determineTimestampFormat(
1078-
binding.member,
1079-
binding.location,
1080-
protocol.defaultTimestampFormat,
1081-
)
1082-
val timestampFormatType = RuntimeType.TimestampFormat(runtimeConfig, timestampFormat)
1083-
return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer ->
1084-
writer.rustBlockTemplate(
1085-
"pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{RequestRejection}>",
1086-
*codegenScope,
1087-
"O" to output,
1088-
) {
1089-
rustTemplate(
1078+
writer.write(
10901079
"""
1091-
let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
1092-
let value = #{DateTime}::from_str(&value, #{format})?;
10931080
Ok(${symbolProvider.wrapOptional(binding.member, "value")})
1094-
""".trimIndent(),
1095-
*codegenScope,
1096-
"format" to timestampFormatType,
1097-
)
1098-
}
1099-
}
1100-
}
1101-
1102-
// Function to parse a string as the data type generated for boolean, byte, short, integer, long, float, or double shapes.
1103-
// TODO(https://github.com/awslabs/smithy-rs/issues/1232): This function can be replaced by https://docs.rs/aws-smithy-types/latest/aws_smithy_types/primitive/trait.Parse.html
1104-
private fun generateParseStrAsPrimitiveFn(binding: HttpBindingDescriptor): RuntimeType {
1105-
val output = symbolProvider.toSymbol(binding.member)
1106-
val fnName = generateParseStrFnName(binding)
1107-
return RuntimeType.forInlineFun(fnName, operationDeserModule) { writer ->
1108-
writer.rustBlockTemplate(
1109-
"pub fn $fnName(value: &str) -> std::result::Result<#{O}, #{RequestRejection}>",
1110-
*codegenScope,
1111-
"O" to output,
1112-
) {
1113-
rustTemplate(
11141081
"""
1115-
let value = std::str::FromStr::from_str(value)?;
1116-
Ok(${symbolProvider.wrapOptional(binding.member, "value")})
1117-
""".trimIndent(),
1118-
*codegenScope,
11191082
)
11201083
}
11211084
}

rust-runtime/aws-smithy-http-server/src/rejection.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,25 @@ pub enum RequestRejection {
194194

195195
impl std::error::Error for RequestRejection {}
196196

197+
// Consider a conversion between `T` and `U` followed by a bubbling up of the conversion error
198+
// through `Result<_, RequestRejection>`. This [`From`] implementation accomodates the special case
199+
// where `T` and `U` are equal, in such cases `T`/`U` a enjoy `TryFrom<T>` with
200+
// `Err = std::convert::Infallible`.
201+
//
202+
// Note that when `!` stabilizes `std::convert::Infallible` will become an alias for `!` and there
203+
// will be a blanket `impl From<!> for T`. This will remove the need for this implementation.
204+
//
205+
// More details on this can be found in the following links:
206+
// - https://doc.rust-lang.org/std/primitive.never.html
207+
// - https://doc.rust-lang.org/std/convert/enum.Infallible.html#future-compatibility
208+
impl From<std::convert::Infallible> for RequestRejection {
209+
fn from(_err: std::convert::Infallible) -> Self {
210+
// We opt for this `match` here rather than [`unreachable`] to assure the reader that this
211+
// code path is dead.
212+
match _err {}
213+
}
214+
}
215+
197216
// These converters are solely to make code-generation simpler. They convert from a specific error
198217
// type (from a runtime/third-party crate or the standard library) into a variant of the
199218
// [`crate::rejection::RequestRejection`] enum holding the type-erased boxed [`crate::Error`]

0 commit comments

Comments
 (0)