@@ -12,12 +12,13 @@ import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
12
12
import software.amazon.smithy.codegen.core.Symbol
13
13
import software.amazon.smithy.model.knowledge.HttpBindingIndex
14
14
import software.amazon.smithy.model.node.ExpectationNotMetException
15
+ import software.amazon.smithy.model.shapes.BooleanShape
15
16
import software.amazon.smithy.model.shapes.CollectionShape
17
+ import software.amazon.smithy.model.shapes.NumberShape
16
18
import software.amazon.smithy.model.shapes.OperationShape
17
19
import software.amazon.smithy.model.shapes.Shape
18
20
import software.amazon.smithy.model.shapes.StringShape
19
21
import software.amazon.smithy.model.shapes.StructureShape
20
- import software.amazon.smithy.model.traits.EnumTrait
21
22
import software.amazon.smithy.model.traits.ErrorTrait
22
23
import software.amazon.smithy.model.traits.HttpErrorTrait
23
24
import software.amazon.smithy.model.traits.HttpTrait
@@ -57,15 +58,13 @@ import software.amazon.smithy.rust.codegen.smithy.protocols.HttpBoundProtocolPay
57
58
import software.amazon.smithy.rust.codegen.smithy.protocols.HttpLocation
58
59
import software.amazon.smithy.rust.codegen.smithy.protocols.Protocol
59
60
import software.amazon.smithy.rust.codegen.smithy.protocols.parse.StructuredDataParserGenerator
60
- import software.amazon.smithy.rust.codegen.smithy.rustType
61
61
import software.amazon.smithy.rust.codegen.smithy.toOptional
62
62
import software.amazon.smithy.rust.codegen.smithy.wrapOptional
63
63
import software.amazon.smithy.rust.codegen.util.dq
64
64
import software.amazon.smithy.rust.codegen.util.expectTrait
65
65
import software.amazon.smithy.rust.codegen.util.findStreamingMember
66
66
import software.amazon.smithy.rust.codegen.util.getTrait
67
67
import software.amazon.smithy.rust.codegen.util.hasStreamingMember
68
- import software.amazon.smithy.rust.codegen.util.hasTrait
69
68
import software.amazon.smithy.rust.codegen.util.inputShape
70
69
import software.amazon.smithy.rust.codegen.util.isStreaming
71
70
import software.amazon.smithy.rust.codegen.util.outputShape
@@ -120,14 +119,14 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
120
119
" AsyncTrait" to ServerCargoDependency .AsyncTrait .asType(),
121
120
" Cow" to ServerRuntimeType .Cow ,
122
121
" DateTime" to RuntimeType .DateTime (runtimeConfig),
122
+ " FormUrlEncoded" to ServerCargoDependency .FormUrlEncoded .asType(),
123
123
" HttpBody" to CargoDependency .HttpBody .asType(),
124
124
" header_util" to CargoDependency .SmithyHttp (runtimeConfig).asType().member(" header" ),
125
125
" Hyper" to CargoDependency .Hyper .asType(),
126
126
" LazyStatic" to CargoDependency .LazyStatic .asType(),
127
127
" Nom" to ServerCargoDependency .Nom .asType(),
128
128
" PercentEncoding" to CargoDependency .PercentEncoding .asType(),
129
129
" Regex" to CargoDependency .Regex .asType(),
130
- " SerdeUrlEncoded" to ServerCargoDependency .SerdeUrlEncoded .asType(),
131
130
" SmithyHttp" to CargoDependency .SmithyHttp (runtimeConfig).asType(),
132
131
" SmithyHttpServer" to ServerCargoDependency .SmithyHttpServer (runtimeConfig).asType(),
133
132
" RuntimeError" to ServerRuntimeType .RuntimeError (runtimeConfig),
@@ -775,7 +774,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
775
774
.forEachIndexed { index, segment ->
776
775
val binding = pathBindings.find { it.memberName == segment.content }
777
776
if (binding != null && segment.isLabel) {
778
- val deserializer = generateParsePercentEncodedStrFn (binding)
777
+ val deserializer = generateParseFn (binding, true )
779
778
rustTemplate(
780
779
"""
781
780
input = input.${binding.member.setterName()} (
@@ -847,7 +846,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
847
846
rustTemplate(
848
847
"""
849
848
let query_string = request.uri().query().unwrap_or("");
850
- let pairs = #{SerdeUrlEncoded }::from_str::<Vec<(#{Cow}<'_, str>, #{Cow}<'_, str>)>>( query_string)? ;
849
+ let pairs = #{FormUrlEncoded }::parse( query_string.as_bytes()) ;
851
850
""" .trimIndent(),
852
851
* codegenScope
853
852
)
@@ -870,7 +869,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
870
869
871
870
rustBlock(" for (k, v) in pairs" ) {
872
871
queryBindingsTargettingSimple.forEach {
873
- val deserializer = generateParsePercentEncodedStrFn (it)
872
+ val deserializer = generateParseFn (it, false )
874
873
val memberName = symbolProvider.toMemberName(it.member)
875
874
rustTemplate(
876
875
"""
@@ -891,25 +890,15 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
891
890
892
891
when {
893
892
memberShape.isStringShape -> {
894
- // `<_>::from()/try_from()` is necessary to convert the `&str` into:
895
- // * the Rust enum in case the `string` shape has the `enum` trait; or
896
- // * `String` in case it doesn't.
897
- if (memberShape.hasTrait<EnumTrait >()) {
898
- rustTemplate(
899
- """
900
- let v = <#{memberShape}>::try_from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref())?;
901
- """ ,
902
- * codegenScope,
903
- " memberShape" to symbolProvider.toSymbol(memberShape),
904
- )
905
- } else {
906
- rustTemplate(
907
- """
908
- let v = <_>::from(#{PercentEncoding}::percent_decode_str(&v).decode_utf8()?.as_ref());
909
- """ .trimIndent(),
910
- * codegenScope
911
- )
912
- }
893
+ // NOTE: This path is traversed with or without @enum applied. The `try_from` is used
894
+ // as a common conversion.
895
+ rustTemplate(
896
+ """
897
+ let v = <#{memberShape}>::try_from(v.as_ref())?;
898
+ """ ,
899
+ * codegenScope,
900
+ " memberShape" to symbolProvider.toSymbol(memberShape),
901
+ )
913
902
}
914
903
memberShape.isTimestampShape -> {
915
904
val index = HttpBindingIndex .of(model)
@@ -922,7 +911,6 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
922
911
val timestampFormatType = RuntimeType .TimestampFormat (runtimeConfig, timestampFormat)
923
912
rustTemplate(
924
913
"""
925
- let v = #{PercentEncoding}::percent_decode_str(&v).decode_utf8()?;
926
914
let v = #{DateTime}::from_str(&v, #{format})?;
927
915
""" .trimIndent(),
928
916
* codegenScope,
@@ -1013,21 +1001,7 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
1013
1001
)
1014
1002
}
1015
1003
1016
- // TODO(https://github.com/awslabs/smithy-rs/issues/1231): If this function was called to parse a query string
1017
- // key value pair, we don't need to percent-decode it _again_.
1018
- private fun generateParsePercentEncodedStrFn (binding : HttpBindingDescriptor ): RuntimeType {
1019
- // HTTP bindings we support that contain percent-encoded data.
1020
- check(binding.location == HttpLocation .LABEL || binding.location == HttpLocation .QUERY )
1021
-
1022
- val target = model.expectShape(binding.member.target)
1023
- return when {
1024
- target.isStringShape -> generateParsePercentEncodedStrAsStringFn(binding)
1025
- target.isTimestampShape -> generateParsePercentEncodedStrAsTimestampFn(binding)
1026
- else -> generateParseStrAsPrimitiveFn(binding)
1027
- }
1028
- }
1029
-
1030
- private fun generateParsePercentEncodedStrAsStringFn (binding : HttpBindingDescriptor ): RuntimeType {
1004
+ private fun generateParseFn (binding : HttpBindingDescriptor , percentDecoding : Boolean ): RuntimeType {
1031
1005
val output = symbolProvider.toSymbol(binding.member)
1032
1006
val fnName = generateParseStrFnName(binding)
1033
1007
val symbol = output.extractSymbolFromOption()
@@ -1037,85 +1011,74 @@ private class ServerHttpBoundProtocolTraitImplGenerator(
1037
1011
* codegenScope,
1038
1012
" O" to output,
1039
1013
) {
1040
- // `<_>::from()` is necessary to convert the `&str` into:
1041
- // * the Rust enum in case the `string` shape has the `enum` trait; or
1042
- // * `String` in case it doesn't.
1043
- when (symbol.rustType()) {
1044
- RustType .String ->
1045
- rustTemplate(
1046
- """
1047
- let value = <#{T}>::from(#{PercentEncoding}::percent_decode_str(value).decode_utf8()?.as_ref());
1048
- """ ,
1049
- * codegenScope,
1050
- " T" to symbol,
1051
- )
1052
- else -> { // RustType.Opaque, the Enum
1053
- check(symbol.rustType() is RustType .Opaque )
1014
+ val target = model.expectShape(binding.member.target)
1015
+
1016
+ when {
1017
+ target.isStringShape -> {
1018
+ // NOTE: This path is traversed with or without @enum applied. The `try_from` is used as a
1019
+ // common conversion.
1020
+ if (percentDecoding) {
1021
+ rustTemplate(
1022
+ """
1023
+ let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
1024
+ let value = #{T}::try_from(value.as_ref())?;
1025
+ """ ,
1026
+ * codegenScope,
1027
+ " T" to symbol,
1028
+ )
1029
+ } else {
1030
+ rustTemplate(
1031
+ """
1032
+ let value = #{T}::try_from(value)?;
1033
+ """ ,
1034
+ " T" to symbol,
1035
+ )
1036
+ }
1037
+ }
1038
+ target.isTimestampShape -> {
1039
+ val index = HttpBindingIndex .of(model)
1040
+ val timestampFormat =
1041
+ index.determineTimestampFormat(
1042
+ binding.member,
1043
+ binding.location,
1044
+ protocol.defaultTimestampFormat,
1045
+ )
1046
+ val timestampFormatType = RuntimeType .TimestampFormat (runtimeConfig, timestampFormat)
1047
+
1048
+ if (percentDecoding) {
1049
+ rustTemplate(
1050
+ """
1051
+ let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
1052
+ let value = #{DateTime}::from_str(value.as_ref(), #{format})?;
1053
+ """ ,
1054
+ * codegenScope,
1055
+ " format" to timestampFormatType,
1056
+ )
1057
+ } else {
1058
+ rustTemplate(
1059
+ """
1060
+ let value = #{DateTime}::from_str(value, #{format})?;
1061
+ """ ,
1062
+ * codegenScope,
1063
+ " format" to timestampFormatType,
1064
+ )
1065
+ }
1066
+ }
1067
+ else -> {
1068
+ check(target is NumberShape || target is BooleanShape )
1054
1069
rustTemplate(
1055
1070
"""
1056
- let value = <#{T}>::try_from(#{PercentEncoding}::percent_decode_str (value).decode_utf8()?.as_ref() )?;
1071
+ let value = std::str::FromStr::from_str (value)?;
1057
1072
""" ,
1058
1073
* codegenScope,
1059
- " T" to symbol,
1060
1074
)
1061
1075
}
1062
1076
}
1063
- writer.write(
1064
- """
1065
- Ok(${symbolProvider.wrapOptional(binding.member, " value" )} )
1066
- """
1067
- )
1068
- }
1069
- }
1070
- }
1071
1077
1072
- private fun generateParsePercentEncodedStrAsTimestampFn (binding : HttpBindingDescriptor ): RuntimeType {
1073
- val output = symbolProvider.toSymbol(binding.member)
1074
- val fnName = generateParseStrFnName(binding)
1075
- val index = HttpBindingIndex .of(model)
1076
- val timestampFormat =
1077
- index.determineTimestampFormat(
1078
- binding.member,
1079
- binding.location,
1080
- protocol.defaultTimestampFormat,
1081
- )
1082
- val timestampFormatType = RuntimeType .TimestampFormat (runtimeConfig, timestampFormat)
1083
- return RuntimeType .forInlineFun(fnName, operationDeserModule) { writer ->
1084
- writer.rustBlockTemplate(
1085
- " pub fn $fnName (value: &str) -> std::result::Result<#{O}, #{RequestRejection}>" ,
1086
- * codegenScope,
1087
- " O" to output,
1088
- ) {
1089
- rustTemplate(
1078
+ writer.write(
1090
1079
"""
1091
- let value = #{PercentEncoding}::percent_decode_str(value).decode_utf8()?;
1092
- let value = #{DateTime}::from_str(&value, #{format})?;
1093
1080
Ok(${symbolProvider.wrapOptional(binding.member, " value" )} )
1094
- """ .trimIndent(),
1095
- * codegenScope,
1096
- " format" to timestampFormatType,
1097
- )
1098
- }
1099
- }
1100
- }
1101
-
1102
- // Function to parse a string as the data type generated for boolean, byte, short, integer, long, float, or double shapes.
1103
- // TODO(https://github.com/awslabs/smithy-rs/issues/1232): This function can be replaced by https://docs.rs/aws-smithy-types/latest/aws_smithy_types/primitive/trait.Parse.html
1104
- private fun generateParseStrAsPrimitiveFn (binding : HttpBindingDescriptor ): RuntimeType {
1105
- val output = symbolProvider.toSymbol(binding.member)
1106
- val fnName = generateParseStrFnName(binding)
1107
- return RuntimeType .forInlineFun(fnName, operationDeserModule) { writer ->
1108
- writer.rustBlockTemplate(
1109
- " pub fn $fnName (value: &str) -> std::result::Result<#{O}, #{RequestRejection}>" ,
1110
- * codegenScope,
1111
- " O" to output,
1112
- ) {
1113
- rustTemplate(
1114
1081
"""
1115
- let value = std::str::FromStr::from_str(value)?;
1116
- Ok(${symbolProvider.wrapOptional(binding.member, " value" )} )
1117
- """ .trimIndent(),
1118
- * codegenScope,
1119
1082
)
1120
1083
}
1121
1084
}
0 commit comments