Skip to content

Commit

Permalink
StepFunctions: describe_state_machine() now takes Version ARN's (#8441)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Dec 27, 2024
1 parent 249d8e2 commit d773bb2
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 22 deletions.
85 changes: 73 additions & 12 deletions moto/stepfunctions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@
from .utils import PAGINATION_MODEL, api_to_cfn_tags, cfn_to_api_tags


class StateMachine(CloudFormationModel):
class StateMachineInstance:
def __init__(
self,
arn: str,
name: str,
definition: str,
roleArn: str,
tags: Optional[List[Dict[str, str]]] = None,
encryptionConfiguration: Optional[Dict[str, Any]] = None,
loggingConfiguration: Optional[Dict[str, Any]] = None,
tracingConfiguration: Optional[Dict[str, Any]] = None,
Expand All @@ -45,10 +44,6 @@ def __init__(
self.definition = definition
self.roleArn = roleArn
self.executions: List[Execution] = []
self.tags: List[Dict[str, str]] = []
if tags:
self.add_tags(tags)
self.version = 0
self.type = "STANDARD"
self.encryptionConfiguration = encryptionConfiguration or {
"type": "AWS_OWNED_KEY"
Expand All @@ -57,6 +52,61 @@ def __init__(
self.tracingConfiguration = tracingConfiguration or {"enabled": False}
self.sm_type = "STANDARD" # or express


class StateMachineVersion(StateMachineInstance, CloudFormationModel):
def __init__(self, source: StateMachineInstance, version: int):
version_arn = f"{source.arn}:{version}"
StateMachineInstance.__init__(
self,
arn=version_arn,
name=source.name,
definition=source.definition,
roleArn=source.roleArn,
encryptionConfiguration=source.encryptionConfiguration,
loggingConfiguration=source.loggingConfiguration,
tracingConfiguration=source.tracingConfiguration,
)
self.source_arn = source.arn
self.version = version


class StateMachine(StateMachineInstance, CloudFormationModel):
def __init__(
self,
arn: str,
name: str,
definition: str,
roleArn: str,
tags: Optional[List[Dict[str, str]]] = None,
encryptionConfiguration: Optional[Dict[str, Any]] = None,
loggingConfiguration: Optional[Dict[str, Any]] = None,
tracingConfiguration: Optional[Dict[str, Any]] = None,
):
StateMachineInstance.__init__(
self,
arn=arn,
name=name,
definition=definition,
roleArn=roleArn,
encryptionConfiguration=encryptionConfiguration,
loggingConfiguration=loggingConfiguration,
tracingConfiguration=tracingConfiguration,
)
self.tags: List[Dict[str, str]] = []
if tags:
self.add_tags(tags)

self.latest_version_number = 0
self.versions: Dict[int, StateMachineVersion] = {}
self.latest_version: Optional[StateMachineVersion] = None

def publish(self) -> None:
new_version_number = self.latest_version_number + 1
new_version = StateMachineVersion(source=self, version=new_version_number)
self.versions[new_version_number] = new_version
self.latest_version = new_version
self.latest_version_number = new_version_number

def start_execution(
self,
region_name: str,
Expand Down Expand Up @@ -554,7 +604,7 @@ def create_state_machine(
tracingConfiguration,
)
if publish:
state_machine.version += 1
state_machine.publish()
self.state_machines.append(state_machine)
return state_machine

Expand All @@ -566,10 +616,21 @@ def describe_state_machine(self, arn: str) -> StateMachine:
self._validate_machine_arn(arn)
sm = next((x for x in self.state_machines if x.arn == arn), None)
if not sm:
raise StateMachineDoesNotExist(
"State Machine Does Not Exist: '" + arn + "'"
)
return sm
if (
(arn_parts := arn.split(":"))
and len(arn_parts) > 7
and arn_parts[-1].isnumeric()
):
# we might have a versioned arn, ending in :stateMachine:name:version_nr
source_arn = ":".join(arn_parts[:-1])
source_sm = next(
(x for x in self.state_machines if x.arn == source_arn), None
)
if source_sm:
sm = source_sm.versions.get(int(arn_parts[-1])) # type: ignore[assignment]
if not sm:
raise StateMachineDoesNotExist(f"State Machine Does Not Exist: '{arn}'")
return sm # type: ignore[return-value]

def delete_state_machine(self, arn: str) -> None:
self._validate_machine_arn(arn)
Expand Down Expand Up @@ -600,7 +661,7 @@ def update_state_machine(
updates["tracingConfiguration"] = tracing_configuration
sm.update(**updates)
if publish:
sm.version += 1
sm.publish()
return sm

def start_execution(
Expand Down
14 changes: 4 additions & 10 deletions moto/stepfunctions/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,8 @@ def create_state_machine(self) -> TYPE_RESPONSE:
"creationDate": state_machine.creation_date,
"stateMachineArn": state_machine.arn,
}
if publish:
response["stateMachineVersionArn"] = (
f"{state_machine.arn}:{state_machine.version}"
)
if state_machine.latest_version:
response["stateMachineVersionArn"] = state_machine.latest_version.arn
return 200, {}, json.dumps(response)

def list_state_machines(self) -> TYPE_RESPONSE:
Expand Down Expand Up @@ -116,13 +114,9 @@ def update_state_machine(self) -> TYPE_RESPONSE:
logging_configuration=logging_config,
publish=publish,
)
response = {
"updateDate": state_machine.update_date,
}
response = {"updateDate": state_machine.update_date}
if publish:
response["stateMachineVersionArn"] = (
f"{state_machine.arn}:{state_machine.version}"
)
response["stateMachineVersionArn"] = state_machine.latest_version.arn # type: ignore
return 200, {}, json.dumps(response)

def list_tags_for_resource(self) -> TYPE_RESPONSE:
Expand Down
53 changes: 53 additions & 0 deletions tests/test_stepfunctions/test_stepfunctions_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import json
from uuid import uuid4

import boto3
import pytest

from moto import mock_aws
from tests.test_stepfunctions.parser import sfn_role_policy
from tests.test_stepfunctions.test_stepfunctions import simple_definition


@pytest.mark.parametrize("use_parser", [True, False], ids=["use_parser", "use_mock"])
def test_describe_state_machine_using_version_arn(use_parser):
with mock_aws(config={"stepfunctions": {"execute_state_machine": use_parser}}):
iam = boto3.client("iam", region_name="us-east-1")
role_name = f"sfn_role_{str(uuid4())[0:6]}"
sfn_role = iam.create_role(
RoleName=role_name,
AssumeRolePolicyDocument=json.dumps(sfn_role_policy),
Path="/",
)["Role"]["Arn"]

client = boto3.client("stepfunctions", region_name="us-east-1")

name1 = f"sfn_name_{str(uuid4())[0:6]}"
response = client.create_state_machine(
name=name1, definition=simple_definition, roleArn=sfn_role, publish=True
)
arn = response["stateMachineArn"]
version_arn1 = response["stateMachineVersionArn"]

# Use the initial version to describe the state machine
version1 = client.describe_state_machine(stateMachineArn=version_arn1)
assert version1["loggingConfiguration"] == {"level": "OFF"}

# Update the state machine
update = client.update_state_machine(
stateMachineArn=arn,
loggingConfiguration={"level": "ALL"},
publish=True,
)
version_arn2 = update["stateMachineVersionArn"]
assert version_arn1 != version_arn2

# Assert that we can retrieve the latest configuration, either by the regular ARN or by the version ARN
latest = client.describe_state_machine(stateMachineArn=arn)
assert latest["loggingConfiguration"] == {"level": "ALL"}
version2 = client.describe_state_machine(stateMachineArn=version_arn2)
assert version2["loggingConfiguration"] == {"level": "ALL"}

# Assert that we can still describe the first version of the state machine
version1 = client.describe_state_machine(stateMachineArn=version_arn1)
assert version1["loggingConfiguration"] == {"level": "OFF"}

0 comments on commit d773bb2

Please sign in to comment.