From d773bb2184fbd17bd5e1b3aa14b3d5c0b4655955 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Fri, 27 Dec 2024 22:05:10 -0100 Subject: [PATCH] StepFunctions: describe_state_machine() now takes Version ARN's (#8441) --- moto/stepfunctions/models.py | 85 ++++++++++++++++--- moto/stepfunctions/responses.py | 14 +-- .../test_stepfunctions_versions.py | 53 ++++++++++++ 3 files changed, 130 insertions(+), 22 deletions(-) create mode 100644 tests/test_stepfunctions/test_stepfunctions_versions.py diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index 5a0ee9a1a4de..c432b9bd0b8f 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -26,14 +26,13 @@ from .utils import PAGINATION_MODEL, api_to_cfn_tags, cfn_to_api_tags -class StateMachine(CloudFormationModel): +class StateMachineInstance: def __init__( self, arn: str, name: str, definition: str, roleArn: str, - tags: Optional[List[Dict[str, str]]] = None, encryptionConfiguration: Optional[Dict[str, Any]] = None, loggingConfiguration: Optional[Dict[str, Any]] = None, tracingConfiguration: Optional[Dict[str, Any]] = None, @@ -45,10 +44,6 @@ def __init__( self.definition = definition self.roleArn = roleArn self.executions: List[Execution] = [] - self.tags: List[Dict[str, str]] = [] - if tags: - self.add_tags(tags) - self.version = 0 self.type = "STANDARD" self.encryptionConfiguration = encryptionConfiguration or { "type": "AWS_OWNED_KEY" @@ -57,6 +52,61 @@ def __init__( self.tracingConfiguration = tracingConfiguration or {"enabled": False} self.sm_type = "STANDARD" # or express + +class StateMachineVersion(StateMachineInstance, CloudFormationModel): + def __init__(self, source: StateMachineInstance, version: int): + version_arn = f"{source.arn}:{version}" + StateMachineInstance.__init__( + self, + arn=version_arn, + name=source.name, + definition=source.definition, + roleArn=source.roleArn, + encryptionConfiguration=source.encryptionConfiguration, + loggingConfiguration=source.loggingConfiguration, + tracingConfiguration=source.tracingConfiguration, + ) + self.source_arn = source.arn + self.version = version + + +class StateMachine(StateMachineInstance, CloudFormationModel): + def __init__( + self, + arn: str, + name: str, + definition: str, + roleArn: str, + tags: Optional[List[Dict[str, str]]] = None, + encryptionConfiguration: Optional[Dict[str, Any]] = None, + loggingConfiguration: Optional[Dict[str, Any]] = None, + tracingConfiguration: Optional[Dict[str, Any]] = None, + ): + StateMachineInstance.__init__( + self, + arn=arn, + name=name, + definition=definition, + roleArn=roleArn, + encryptionConfiguration=encryptionConfiguration, + loggingConfiguration=loggingConfiguration, + tracingConfiguration=tracingConfiguration, + ) + self.tags: List[Dict[str, str]] = [] + if tags: + self.add_tags(tags) + + self.latest_version_number = 0 + self.versions: Dict[int, StateMachineVersion] = {} + self.latest_version: Optional[StateMachineVersion] = None + + def publish(self) -> None: + new_version_number = self.latest_version_number + 1 + new_version = StateMachineVersion(source=self, version=new_version_number) + self.versions[new_version_number] = new_version + self.latest_version = new_version + self.latest_version_number = new_version_number + def start_execution( self, region_name: str, @@ -554,7 +604,7 @@ def create_state_machine( tracingConfiguration, ) if publish: - state_machine.version += 1 + state_machine.publish() self.state_machines.append(state_machine) return state_machine @@ -566,10 +616,21 @@ def describe_state_machine(self, arn: str) -> StateMachine: self._validate_machine_arn(arn) sm = next((x for x in self.state_machines if x.arn == arn), None) if not sm: - raise StateMachineDoesNotExist( - "State Machine Does Not Exist: '" + arn + "'" - ) - return sm + if ( + (arn_parts := arn.split(":")) + and len(arn_parts) > 7 + and arn_parts[-1].isnumeric() + ): + # we might have a versioned arn, ending in :stateMachine:name:version_nr + source_arn = ":".join(arn_parts[:-1]) + source_sm = next( + (x for x in self.state_machines if x.arn == source_arn), None + ) + if source_sm: + sm = source_sm.versions.get(int(arn_parts[-1])) # type: ignore[assignment] + if not sm: + raise StateMachineDoesNotExist(f"State Machine Does Not Exist: '{arn}'") + return sm # type: ignore[return-value] def delete_state_machine(self, arn: str) -> None: self._validate_machine_arn(arn) @@ -600,7 +661,7 @@ def update_state_machine( updates["tracingConfiguration"] = tracing_configuration sm.update(**updates) if publish: - sm.version += 1 + sm.publish() return sm def start_execution( diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py index 1b26327e4fed..1b136026de6b 100644 --- a/moto/stepfunctions/responses.py +++ b/moto/stepfunctions/responses.py @@ -47,10 +47,8 @@ def create_state_machine(self) -> TYPE_RESPONSE: "creationDate": state_machine.creation_date, "stateMachineArn": state_machine.arn, } - if publish: - response["stateMachineVersionArn"] = ( - f"{state_machine.arn}:{state_machine.version}" - ) + if state_machine.latest_version: + response["stateMachineVersionArn"] = state_machine.latest_version.arn return 200, {}, json.dumps(response) def list_state_machines(self) -> TYPE_RESPONSE: @@ -116,13 +114,9 @@ def update_state_machine(self) -> TYPE_RESPONSE: logging_configuration=logging_config, publish=publish, ) - response = { - "updateDate": state_machine.update_date, - } + response = {"updateDate": state_machine.update_date} if publish: - response["stateMachineVersionArn"] = ( - f"{state_machine.arn}:{state_machine.version}" - ) + response["stateMachineVersionArn"] = state_machine.latest_version.arn # type: ignore return 200, {}, json.dumps(response) def list_tags_for_resource(self) -> TYPE_RESPONSE: diff --git a/tests/test_stepfunctions/test_stepfunctions_versions.py b/tests/test_stepfunctions/test_stepfunctions_versions.py new file mode 100644 index 000000000000..715033cb0463 --- /dev/null +++ b/tests/test_stepfunctions/test_stepfunctions_versions.py @@ -0,0 +1,53 @@ +import json +from uuid import uuid4 + +import boto3 +import pytest + +from moto import mock_aws +from tests.test_stepfunctions.parser import sfn_role_policy +from tests.test_stepfunctions.test_stepfunctions import simple_definition + + +@pytest.mark.parametrize("use_parser", [True, False], ids=["use_parser", "use_mock"]) +def test_describe_state_machine_using_version_arn(use_parser): + with mock_aws(config={"stepfunctions": {"execute_state_machine": use_parser}}): + iam = boto3.client("iam", region_name="us-east-1") + role_name = f"sfn_role_{str(uuid4())[0:6]}" + sfn_role = iam.create_role( + RoleName=role_name, + AssumeRolePolicyDocument=json.dumps(sfn_role_policy), + Path="/", + )["Role"]["Arn"] + + client = boto3.client("stepfunctions", region_name="us-east-1") + + name1 = f"sfn_name_{str(uuid4())[0:6]}" + response = client.create_state_machine( + name=name1, definition=simple_definition, roleArn=sfn_role, publish=True + ) + arn = response["stateMachineArn"] + version_arn1 = response["stateMachineVersionArn"] + + # Use the initial version to describe the state machine + version1 = client.describe_state_machine(stateMachineArn=version_arn1) + assert version1["loggingConfiguration"] == {"level": "OFF"} + + # Update the state machine + update = client.update_state_machine( + stateMachineArn=arn, + loggingConfiguration={"level": "ALL"}, + publish=True, + ) + version_arn2 = update["stateMachineVersionArn"] + assert version_arn1 != version_arn2 + + # Assert that we can retrieve the latest configuration, either by the regular ARN or by the version ARN + latest = client.describe_state_machine(stateMachineArn=arn) + assert latest["loggingConfiguration"] == {"level": "ALL"} + version2 = client.describe_state_machine(stateMachineArn=version_arn2) + assert version2["loggingConfiguration"] == {"level": "ALL"} + + # Assert that we can still describe the first version of the state machine + version1 = client.describe_state_machine(stateMachineArn=version_arn1) + assert version1["loggingConfiguration"] == {"level": "OFF"}