Skip to content

Commit

Permalink
fix: remove global booleans previously used for emitting business met…
Browse files Browse the repository at this point in the history
…rics (#1104)

Co-authored-by: 0marperez <[email protected]>
  • Loading branch information
lauzadis and 0marperez authored Jun 21, 2024
1 parent 72bc814 commit 97709d7
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule
import software.amazon.smithy.rulesengine.language.syntax.rule.ErrorRule
import software.amazon.smithy.rulesengine.language.syntax.rule.Rule
import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule
import java.util.stream.Collectors

/**
* The core set of standard library functions available to the rules language.
Expand All @@ -49,24 +50,9 @@ typealias EndpointPropertyRenderer = (KotlinWriter, Expression, ExpressionRender
* An expression renderer generates code for an endpoint expression construct.
*/
fun interface ExpressionRenderer {
fun renderExpression(expr: Expression)
fun renderExpression(expr: Expression): EndpointInfo
}

/**
* Will be toggled to true if it is determined an endpoint is account ID based then to false again
*/
private var hasAccountIdBasedEndpoint = false

/**
* Will be toggled to true if determined an endpoint comes from a service endpoint override then to false again
*/
private var hasServiceEndpointOverride = false

/**
* Will be toggled to true when rendering an endpoint URL then to false again
*/
private var renderingEndpointUrl = false

/**
* Renders the default endpoint provider based on the provided rule set.
*/
Expand Down Expand Up @@ -121,9 +107,7 @@ class DefaultEndpointProviderGenerator(
}
}

override fun renderExpression(expr: Expression) {
expr.accept(expressionGenerator)
}
override fun renderExpression(expr: Expression): EndpointInfo = expr.accept(expressionGenerator) ?: EndpointInfo.Empty

private fun renderDocumentation() {
writer.dokka {
Expand Down Expand Up @@ -185,11 +169,11 @@ class DefaultEndpointProviderGenerator(
withConditions(rule.conditions) {
writer.withBlock("return #T(", ")", RuntimeTypes.SmithyClient.Endpoints.Endpoint) {
writeInline("#T.parse(", RuntimeTypes.Core.Net.Url.Url)
renderingEndpointUrl = true
renderExpression(rule.endpoint.url)
renderingEndpointUrl = false
val endpointInfo = renderExpression(rule.endpoint.url)
write("),")

val hasAccountIdBasedEndpoint = "accountId" in endpointInfo.params
val hasServiceEndpointOverride = "endpoint" in endpointInfo.params
val needAdditionalEndpointProperties = hasAccountIdBasedEndpoint || hasServiceEndpointOverride

if (rule.endpoint.headers.isNotEmpty()) {
Expand Down Expand Up @@ -226,11 +210,9 @@ class DefaultEndpointProviderGenerator(

if (hasAccountIdBasedEndpoint) {
writer.write("#T to params.accountId", RuntimeTypes.Core.BusinessMetrics.AccountIdBasedEndpointAccountId)
hasAccountIdBasedEndpoint = false
}
if (hasServiceEndpointOverride) {
writer.write("#T to true", RuntimeTypes.Core.BusinessMetrics.ServiceEndpointOverride)
hasServiceEndpointOverride = false
}
}
}
Expand All @@ -253,16 +235,24 @@ class DefaultEndpointProviderGenerator(
}
}

data class EndpointInfo(val params: MutableSet<String>) {
companion object {
val Empty = EndpointInfo(params = mutableSetOf())
}

operator fun plus(other: EndpointInfo) = EndpointInfo(
params = (this.params + other.params).toMutableSet(),
)
}

class ExpressionGenerator(
private val writer: KotlinWriter,
private val rules: EndpointRuleSet,
private val functions: Map<String, Symbol>,
) : ExpressionVisitor<Unit>, LiteralVisitor<Unit>, TemplateVisitor<Unit> {
override fun visitLiteral(literal: Literal) {
literal.accept(this as LiteralVisitor<Unit>)
}
) : ExpressionVisitor<EndpointInfo?>, LiteralVisitor<EndpointInfo?>, TemplateVisitor<EndpointInfo?> {
override fun visitLiteral(literal: Literal): EndpointInfo? = literal.accept(this as LiteralVisitor<EndpointInfo?>)

override fun visitRef(reference: Reference) {
override fun visitRef(reference: Reference): EndpointInfo {
val referenceName = reference.name.defaultName()
val isParamReference = isParamRef(reference)

Expand All @@ -271,90 +261,112 @@ class ExpressionGenerator(
}
writer.writeInline(referenceName)

if (renderingEndpointUrl) {
if (isParamReference && referenceName == "accountId") hasAccountIdBasedEndpoint = true
if (isParamReference && referenceName == "endpoint") hasServiceEndpointOverride = true
return if (isParamReference) {
EndpointInfo(params = mutableSetOf(referenceName))
} else {
EndpointInfo.Empty
}
}

override fun visitGetAttr(getAttr: GetAttr) {
getAttr.target.accept(this)
override fun visitGetAttr(getAttr: GetAttr): EndpointInfo? {
val endpointInfo = getAttr.target.accept(this)
getAttr.path.forEach {
when (it) {
is GetAttr.Part.Key -> writer.writeInline("?.#L", it.key().toString())
is GetAttr.Part.Index -> writer.writeInline("?.getOrNull(#L)", it.index())
else -> throw CodegenException("unexpected path")
}
}
return endpointInfo
}

override fun visitIsSet(target: Expression) {
target.accept(this)
override fun visitIsSet(target: Expression): EndpointInfo? {
val endpointInfo = target.accept(this)
writer.writeInline(" != null")
return endpointInfo
}

override fun visitNot(target: Expression) {
override fun visitNot(target: Expression): EndpointInfo? {
writer.writeInline("!(")
target.accept(this)
val endpointInfo = target.accept(this)
writer.writeInline(")")
return endpointInfo
}

override fun visitBoolEquals(left: Expression, right: Expression) {
visitEquals(left, right)
}
override fun visitBoolEquals(left: Expression, right: Expression): EndpointInfo? = visitEquals(left, right)

override fun visitStringEquals(left: Expression, right: Expression) {
visitEquals(left, right)
}
override fun visitStringEquals(left: Expression, right: Expression): EndpointInfo? = visitEquals(left, right)

private fun visitEquals(left: Expression, right: Expression) {
left.accept(this)
private fun visitEquals(left: Expression, right: Expression): EndpointInfo? {
val leftEndpointInfo = left.accept(this)
writer.writeInline(" == ")
right.accept(this)
val rightEndpointInfo = right.accept(this)

return when {
leftEndpointInfo != null && rightEndpointInfo != null -> leftEndpointInfo + rightEndpointInfo
leftEndpointInfo != null -> leftEndpointInfo
else -> rightEndpointInfo
}
}

override fun visitLibraryFunction(fn: FunctionDefinition, args: MutableList<Expression>) {
override fun visitLibraryFunction(fn: FunctionDefinition, args: MutableList<Expression>): EndpointInfo? {
writer.writeInline("#T(", functions.getValue(fn.id))
args.forEachIndexed { index, it ->
it.accept(this)

val endpointInfo = args.foldIndexed(EndpointInfo.Empty) { index, acc, curr ->
val currEndpointInfo = curr.accept(this)
if (index < args.lastIndex) {
writer.writeInline(", ")
}
currEndpointInfo?.let { acc + it } ?: acc
}
writer.writeInline(")")
return endpointInfo
}

override fun visitInteger(value: Int) {
override fun visitInteger(value: Int): EndpointInfo? {
writer.writeInline("#L", value)
return null
}

override fun visitString(value: Template) {
override fun visitString(value: Template): EndpointInfo? {
writer.writeInline("\"")
value.accept(this).forEach {} // must "consume" the stream to actually generate everything
val endpointInfo = value.accept(this)
.collect(Collectors.toList())
.fold(EndpointInfo.Empty) { acc, curr ->
curr?.let { acc + it } ?: acc
}
writer.writeInline("\"")
return endpointInfo
}

override fun visitBoolean(value: Boolean) {
override fun visitBoolean(value: Boolean): EndpointInfo? {
writer.writeInline("#L", value)
return null
}

override fun visitRecord(value: MutableMap<Identifier, Literal>) {
override fun visitRecord(value: MutableMap<Identifier, Literal>): EndpointInfo? {
var endpointInfo: EndpointInfo? = null
writer.withInlineBlock("#T {", "}", RuntimeTypes.Core.Content.buildDocument) {
value.entries.forEachIndexed { index, (k, v) ->
endpointInfo = value.entries.foldIndexed(EndpointInfo.Empty) { index, acc, (k, v) ->
writeInline("#S to ", k.toString())
v.accept(this@ExpressionGenerator as LiteralVisitor<Unit>)
val currInfo = v.accept(this@ExpressionGenerator as LiteralVisitor<EndpointInfo?>)
if (index < value.size - 1) write("")
currInfo?.let { acc + it } ?: acc
}
}
return endpointInfo
}

override fun visitTuple(value: MutableList<Literal>) {
override fun visitTuple(value: MutableList<Literal>): EndpointInfo? {
var endpointInfo: EndpointInfo? = null
writer.withInlineBlock("listOf(", ")") {
value.forEachIndexed { index, it ->
it.accept(this@ExpressionGenerator as LiteralVisitor<Unit>)
endpointInfo = value.foldIndexed(EndpointInfo.Empty) { index, acc, curr ->
val localInfo = curr.accept(this@ExpressionGenerator as LiteralVisitor<EndpointInfo?>)
if (index < value.size - 1) write(",") else writeInline(",")
localInfo?.let { acc + it } ?: acc
}
}
return endpointInfo
}

override fun visitStaticTemplate(value: String) = writeTemplateString(value)
Expand All @@ -363,17 +375,19 @@ class ExpressionGenerator(
override fun visitDynamicElement(value: Expression) = writeTemplateExpression(value)

// no-ops for kotlin codegen
override fun startMultipartTemplate() {}
override fun finishMultipartTemplate() {}
override fun startMultipartTemplate(): EndpointInfo? = null
override fun finishMultipartTemplate(): EndpointInfo? = null

private fun writeTemplateString(value: String) {
private fun writeTemplateString(value: String): EndpointInfo? {
writer.writeInline(value.replace("\"", "\\\""))
return null
}

private fun writeTemplateExpression(expr: Expression) {
private fun writeTemplateExpression(expr: Expression): EndpointInfo? {
writer.writeInline("\${")
expr.accept(this)
val endpointInfo = expr.accept(this)
writer.writeInline("}")
return endpointInfo
}

private fun isParamRef(ref: Reference): Boolean = rules.parameters.toList().any { it.name == ref.name }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ class DefaultEndpointProviderTestGenerator(
}
}

override fun renderExpression(expr: Expression) {
expr.accept(expressionGenerator)
}
override fun renderExpression(expr: Expression): EndpointInfo = expr.accept(expressionGenerator) ?: EndpointInfo.Empty

private fun renderTestCase(index: Int, case: EndpointTestCase) {
case.documentation.ifPresent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ class DefaultEndpointProviderGeneratorTest {
"QuxName": {
"type": "string",
"required": false
},
"accountId": {
"type": "string",
"required": false
},
"endpoint": {
"type": "string",
"required": false
}
},
"rules": [
Expand Down Expand Up @@ -118,6 +126,57 @@ class DefaultEndpointProviderGeneratorTest {
"fooheader": ["barheader"]
}
}
},
{
"documentation": "account id based endpoint and service endpoint override",
"type": "endpoint",
"conditions": [
{
"fn": "isSet",
"argv": [
{"ref": "accountId"}
]
},
{
"fn": "isSet",
"argv": [
{"ref": "endpoint"}
]
}
],
"endpoint": {
"url": "https://{accountId}.{endpoint}"
}
},
{
"documentation": "service endpoint override",
"type": "endpoint",
"conditions": [
{
"fn": "isSet",
"argv": [
{"ref": "endpoint"}
]
}
],
"endpoint": {
"url": "https://{endpoint}"
}
},
{
"documentation": "account id based endpoint",
"type": "endpoint",
"conditions": [
{
"fn": "isSet",
"argv": [
{"ref": "accountId"}
]
}
],
"endpoint": {
"url": "https://{accountId}"
}
}
]
}
Expand Down Expand Up @@ -216,4 +275,41 @@ class DefaultEndpointProviderGeneratorTest {
""".formatForTest(indent = " ")
generatedClass.shouldContainOnlyOnceWithDiff(expected)
}

@Test
fun testBusinessMetrics() {
val moneySign = "$"

val accountIdAndEndpoint = """
return Endpoint(
Url.parse("https://$moneySign{params.accountId}.$moneySign{params.endpoint}"),
attributes = attributesOf {
AccountIdBasedEndpointAccountId to params.accountId
ServiceEndpointOverride to true
},
)
"""

val accountId = """
return Endpoint(
Url.parse("https://$moneySign{params.accountId}"),
attributes = attributesOf {
AccountIdBasedEndpointAccountId to params.accountId
},
)
"""

val endpoint = """
return Endpoint(
Url.parse("https://$moneySign{params.endpoint}"),
attributes = attributesOf {
ServiceEndpointOverride to true
},
)
"""

generatedClass.shouldContainOnlyOnceWithDiff(accountIdAndEndpoint)
generatedClass.shouldContainOnlyOnceWithDiff(accountId)
generatedClass.shouldContainOnlyOnceWithDiff(endpoint)
}
}

0 comments on commit 97709d7

Please sign in to comment.