Skip to content

Commit

Permalink
StepFunctions: Support Version Descriptions (#8466)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Jan 4, 2025
1 parent 4886ac6 commit 87214ab
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 9 deletions.
18 changes: 13 additions & 5 deletions moto/stepfunctions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,13 @@ def __init__(
self.loggingConfiguration = loggingConfiguration or {"level": "OFF"}
self.tracingConfiguration = tracingConfiguration or {"enabled": False}
self.sm_type = "STANDARD" # or express
self.description: Optional[str] = None


class StateMachineVersion(StateMachineInstance, CloudFormationModel):
def __init__(self, source: StateMachineInstance, version: int):
def __init__(
self, source: StateMachineInstance, version: int, description: Optional[str]
):
version_arn = f"{source.arn}:{version}"
StateMachineInstance.__init__(
self,
Expand All @@ -68,6 +71,7 @@ def __init__(self, source: StateMachineInstance, version: int):
)
self.source_arn = source.arn
self.version = version
self.description = description


class StateMachine(StateMachineInstance, CloudFormationModel):
Expand Down Expand Up @@ -100,9 +104,11 @@ def __init__(
self.versions: Dict[int, StateMachineVersion] = {}
self.latest_version: Optional[StateMachineVersion] = None

def publish(self) -> None:
def publish(self, description: Optional[str]) -> None:
new_version_number = self.latest_version_number + 1
new_version = StateMachineVersion(source=self, version=new_version_number)
new_version = StateMachineVersion(
source=self, version=new_version_number, description=description
)
self.versions[new_version_number] = new_version
self.latest_version = new_version
self.latest_version_number = new_version_number
Expand Down Expand Up @@ -586,6 +592,7 @@ def create_state_machine(
loggingConfiguration: Optional[Dict[str, Any]] = None,
tracingConfiguration: Optional[Dict[str, Any]] = None,
encryptionConfiguration: Optional[Dict[str, Any]] = None,
version_description: Optional[str] = None,
) -> StateMachine:
self._validate_name(name)
self._validate_role_arn(roleArn)
Expand All @@ -604,7 +611,7 @@ def create_state_machine(
tracingConfiguration,
)
if publish:
state_machine.publish()
state_machine.publish(description=version_description)
self.state_machines.append(state_machine)
return state_machine

Expand Down Expand Up @@ -647,6 +654,7 @@ def update_state_machine(
tracing_configuration: Optional[Dict[str, bool]] = None,
encryption_configuration: Optional[Dict[str, Any]] = None,
publish: Optional[bool] = None,
version_description: Optional[str] = None,
) -> StateMachine:
sm = self.describe_state_machine(arn)
updates: Dict[str, Any] = {
Expand All @@ -661,7 +669,7 @@ def update_state_machine(
updates["tracingConfiguration"] = tracing_configuration
sm.update(**updates)
if publish:
sm.publish()
sm.publish(version_description)
return sm

def start_execution(
Expand Down
6 changes: 4 additions & 2 deletions moto/stepfunctions/parser/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
TaskToken,
TraceHeader,
TracingConfiguration,
VersionDescription,
)
from moto.stepfunctions.parser.asl.component.state.exec.state_map.iteration.itemprocessor.map_run_record import (
MapRunRecord,
Expand Down Expand Up @@ -86,6 +85,7 @@ def create_state_machine(
loggingConfiguration: Optional[LoggingConfiguration] = None,
tracingConfiguration: Optional[TracingConfiguration] = None,
encryptionConfiguration: Optional[EncryptionConfiguration] = None,
version_description: Optional[str] = None,
) -> StateMachine:
StepFunctionsParserBackend._validate_definition(definition=definition)

Expand All @@ -98,6 +98,7 @@ def create_state_machine(
loggingConfiguration=loggingConfiguration,
tracingConfiguration=tracingConfiguration,
encryptionConfiguration=encryptionConfiguration,
version_description=version_description,
)

def send_task_heartbeat(self, task_token: TaskToken) -> SendTaskHeartbeatOutput:
Expand Down Expand Up @@ -216,7 +217,7 @@ def update_state_machine(
tracing_configuration: TracingConfiguration = None,
encryption_configuration: EncryptionConfiguration = None,
publish: Optional[bool] = None,
version_description: VersionDescription = None,
version_description: str = None,
) -> StateMachine:
if not any(
[
Expand All @@ -242,6 +243,7 @@ def update_state_machine(
tracing_configuration=tracing_configuration,
encryption_configuration=encryption_configuration,
publish=publish,
version_description=version_description,
)

def describe_map_run(self, map_run_arn: str) -> Dict[str, Any]:
Expand Down
19 changes: 19 additions & 0 deletions moto/stepfunctions/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from moto.core.responses import BaseResponse
from moto.core.utils import iso_8601_datetime_with_milliseconds

from .exceptions import ValidationException
from .models import StepFunctionBackend, stepfunctions_backends
from .parser.api import ExecutionStatus

Expand Down Expand Up @@ -33,6 +34,13 @@ def create_state_machine(self) -> TYPE_RESPONSE:
encryptionConfiguration = self._get_param("encryptionConfiguration")
loggingConfiguration = self._get_param("loggingConfiguration")
tracingConfiguration = self._get_param("tracingConfiguration")
version_description = self._get_param("versionDescription")

if version_description and not publish:
raise ValidationException(
"Version description can only be set when publish is true"
)

state_machine = self.stepfunction_backend.create_state_machine(
name=name,
definition=definition,
Expand All @@ -42,6 +50,7 @@ def create_state_machine(self) -> TYPE_RESPONSE:
loggingConfiguration=loggingConfiguration,
tracingConfiguration=tracingConfiguration,
encryptionConfiguration=encryptionConfiguration,
version_description=version_description,
)
response = {
"creationDate": state_machine.creation_date,
Expand Down Expand Up @@ -90,6 +99,8 @@ def _describe_state_machine(self, state_machine_arn: str) -> TYPE_RESPONSE:
"tracingConfiguration": state_machine.tracingConfiguration,
"loggingConfiguration": state_machine.loggingConfiguration,
}
if state_machine.description:
response["description"] = state_machine.description
return 200, {}, json.dumps(response)

def delete_state_machine(self) -> TYPE_RESPONSE:
Expand All @@ -105,6 +116,13 @@ def update_state_machine(self) -> TYPE_RESPONSE:
encryption_config = self._get_param("encryptionConfiguration")
logging_config = self._get_param("loggingConfiguration")
publish = self._get_param("publish")
version_description = self._get_param("versionDescription")

if version_description and not publish:
raise ValidationException(
"Version description can only be set when publish is true"
)

state_machine = self.stepfunction_backend.update_state_machine(
arn=arn,
definition=definition,
Expand All @@ -113,6 +131,7 @@ def update_state_machine(self) -> TYPE_RESPONSE:
encryption_configuration=encryption_config,
logging_configuration=logging_config,
publish=publish,
version_description=version_description,
)
response = {"updateDate": state_machine.update_date}
if publish:
Expand Down
87 changes: 85 additions & 2 deletions tests/test_stepfunctions/test_stepfunctions_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import boto3
import pytest
from botocore.exceptions import ClientError

from moto import mock_aws
from tests import aws_verified
from tests.test_stepfunctions.parser import sfn_role_policy
from tests.test_stepfunctions.test_stepfunctions import simple_definition

Expand All @@ -22,9 +24,9 @@ def test_describe_state_machine_using_version_arn(use_parser):

client = boto3.client("stepfunctions", region_name="us-east-1")

name1 = f"sfn_name_{str(uuid4())[0:6]}"
name = f"sfn_name_{str(uuid4())[0:6]}"
response = client.create_state_machine(
name=name1, definition=simple_definition, roleArn=sfn_role, publish=True
name=name, definition=simple_definition, roleArn=sfn_role, publish=True
)
arn = response["stateMachineArn"]
version_arn1 = response["stateMachineVersionArn"]
Expand All @@ -51,3 +53,84 @@ def test_describe_state_machine_using_version_arn(use_parser):
# Assert that we can still describe the first version of the state machine
version1 = client.describe_state_machine(stateMachineArn=version_arn1)
assert version1["loggingConfiguration"] == {"level": "OFF"}


@aws_verified
@pytest.mark.aws_verified
def test_create_state_machine_with_version_description():
iam = boto3.client("iam", region_name="us-east-1")
role_name = f"sfn_role_{str(uuid4())[0:6]}"
sfn_role = iam.create_role(
RoleName=role_name,
AssumeRolePolicyDocument=json.dumps(sfn_role_policy),
Path="/",
)["Role"]["Arn"]

client = boto3.client("stepfunctions", region_name="us-east-1")

try:
name = f"sfn_name_{str(uuid4())[0:6]}"

response = client.create_state_machine(
name=name,
definition=simple_definition,
roleArn=sfn_role,
versionDescription="first version",
publish=True,
)
arn = response["stateMachineArn"]
version_arn1 = response["stateMachineVersionArn"]

# Use the initial version to describe the state machine
sm = client.describe_state_machine(stateMachineArn=arn)
assert "description" not in sm

version = client.describe_state_machine(stateMachineArn=version_arn1)
assert version["description"] == "first version"

update = client.update_state_machine(
stateMachineArn=arn,
definition=simple_definition,
versionDescription="second version",
publish=True,
)
version_arn2 = update["stateMachineVersionArn"]

version = client.describe_state_machine(stateMachineArn=version_arn2)
assert version["description"] == "second version"
finally:
client.delete_state_machine(stateMachineArn=arn)
iam.delete_role(RoleName=role_name)


@aws_verified
@pytest.mark.aws_verified
def test_create_unpublished_state_machine_with_version_description():
iam = boto3.client("iam", region_name="us-east-1")
role_name = f"sfn_role_{str(uuid4())[0:6]}"
sfn_role = iam.create_role(
RoleName=role_name,
AssumeRolePolicyDocument=json.dumps(sfn_role_policy),
Path="/",
)["Role"]["Arn"]

client = boto3.client("stepfunctions", region_name="us-east-1")

try:
name = f"sfn_name_{str(uuid4())[0:6]}"

with pytest.raises(ClientError) as exc:
client.create_state_machine(
name=name,
definition=simple_definition,
roleArn=sfn_role,
versionDescription="first version of statemachine",
)
err = exc.value.response["Error"]
assert err["Code"] == "ValidationException"
assert (
err["Message"] == "Version description can only be set when publish is true"
)

finally:
iam.delete_role(RoleName=role_name)

0 comments on commit 87214ab

Please sign in to comment.