diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index a108361d3..110cd7b6b 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -8125,9 +8125,9 @@ - [ ] send_task_success - [X] start_execution - [X] stop_execution -- [ ] tag_resource -- [ ] untag_resource -- [ ] update_state_machine +- [X] tag_resource +- [X] untag_resource +- [X] update_state_machine ## storagegateway diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py index b5fd2ddb9..c184e2cc7 100644 --- a/moto/stepfunctions/exceptions.py +++ b/moto/stepfunctions/exceptions.py @@ -54,3 +54,11 @@ class InvalidToken(AWSError): def __init__(self, message="Invalid token"): super(InvalidToken, self).__init__("Invalid Token: {}".format(message)) + + +class ResourceNotFound(AWSError): + TYPE = "ResourceNotFound" + STATUS = 400 + + def __init__(self, arn): + super(ResourceNotFound, self).__init__("Resource not found: '{}'".format(arn)) diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index 9dfa33ba8..86c76c98a 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -13,6 +13,7 @@ from .exceptions import ( InvalidArn, InvalidExecutionInput, InvalidName, + ResourceNotFound, StateMachineDoesNotExist, ) from .utils import paginate @@ -21,11 +22,41 @@ from .utils import paginate class StateMachine(CloudFormationModel): def __init__(self, arn, name, definition, roleArn, tags=None): self.creation_date = iso_8601_datetime_with_milliseconds(datetime.now()) + self.update_date = self.creation_date self.arn = arn self.name = name self.definition = definition self.roleArn = roleArn - self.tags = tags + self.tags = [] + if tags: + self.add_tags(tags) + + def update(self, **kwargs): + for key, value in kwargs.items(): + if value is not None: + setattr(self, key, value) + self.update_date = iso_8601_datetime_with_milliseconds(datetime.now()) + + def add_tags(self, tags): + merged_tags = [] + for tag in self.tags: + replacement_index = next( + (index for (index, d) in enumerate(tags) if d["key"] == tag["key"]), + None, + ) + if replacement_index is not None: + replacement = tags.pop(replacement_index) + merged_tags.append(replacement) + else: + merged_tags.append(tag) + for tag in tags: + merged_tags.append(tag) + self.tags = merged_tags + return self.tags + + def remove_tags(self, tag_keys): + self.tags = [tag_set for tag_set in self.tags if tag_set["key"] not in tag_keys] + return self.tags @property def physical_resource_id(self): @@ -249,6 +280,15 @@ class StepFunctionBackend(BaseBackend): if sm: self.state_machines.remove(sm) + def update_state_machine(self, arn, definition=None, role_arn=None): + sm = self.describe_state_machine(arn) + updates = { + "definition": definition, + "roleArn": role_arn, + } + sm.update(**updates) + return sm + def start_execution(self, state_machine_arn, name=None, execution_input=None): state_machine_name = self.describe_state_machine(state_machine_arn).name self._ensure_execution_name_doesnt_exist(name) @@ -296,6 +336,20 @@ class StepFunctionBackend(BaseBackend): raise ExecutionDoesNotExist("Execution Does Not Exist: '" + arn + "'") return exctn + def tag_resource(self, resource_arn, tags): + try: + state_machine = self.describe_state_machine(resource_arn) + state_machine.add_tags(tags) + except StateMachineDoesNotExist: + raise ResourceNotFound(resource_arn) + + def untag_resource(self, resource_arn, tag_keys): + try: + state_machine = self.describe_state_machine(resource_arn) + state_machine.remove_tags(tag_keys) + except StateMachineDoesNotExist: + raise ResourceNotFound(resource_arn) + def reset(self): region_name = self.region_name self.__dict__ = {} diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py index 7106d81d0..7eae8091b 100644 --- a/moto/stepfunctions/responses.py +++ b/moto/stepfunctions/responses.py @@ -83,6 +83,22 @@ class StepFunctionResponse(BaseResponse): except AWSError as err: return err.response() + @amzn_request_id + def update_state_machine(self): + arn = self._get_param("stateMachineArn") + definition = self._get_param("definition") + role_arn = self._get_param("roleArn") + try: + state_machine = self.stepfunction_backend.update_state_machine( + arn=arn, definition=definition, role_arn=role_arn + ) + response = { + "updateDate": state_machine.update_date, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + @amzn_request_id def list_tags_for_resource(self): arn = self._get_param("resourceArn") @@ -94,6 +110,26 @@ class StepFunctionResponse(BaseResponse): response = {"tags": tags} return 200, {}, json.dumps(response) + @amzn_request_id + def tag_resource(self): + arn = self._get_param("resourceArn") + tags = self._get_param("tags", []) + try: + self.stepfunction_backend.tag_resource(arn, tags) + except AWSError as err: + return err.response() + return 200, {}, json.dumps({}) + + @amzn_request_id + def untag_resource(self): + arn = self._get_param("resourceArn") + tag_keys = self._get_param("tagKeys", []) + try: + self.stepfunction_backend.untag_resource(arn, tag_keys) + except AWSError as err: + return err.response() + return 200, {}, json.dumps({}) + @amzn_request_id def start_execution(self): arn = self._get_param("stateMachineArn") diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py index 1c961b882..0bea43084 100644 --- a/tests/test_stepfunctions/test_stepfunctions.py +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -155,6 +155,33 @@ def test_state_machine_creation_requires_valid_role_arn(): ) +@mock_stepfunctions +@mock_sts +def test_update_state_machine(): + client = boto3.client("stepfunctions", region_name=region) + + resp = client.create_state_machine( + name="test", definition=str(simple_definition), roleArn=_get_default_role() + ) + state_machine_arn = resp["stateMachineArn"] + + updated_role = _get_default_role() + "-updated" + updated_definition = str(simple_definition).replace( + "DefaultState", "DefaultStateUpdated" + ) + resp = client.update_state_machine( + stateMachineArn=state_machine_arn, + definition=updated_definition, + roleArn=updated_role, + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["updateDate"].should.be.a(datetime) + + desc = client.describe_state_machine(stateMachineArn=state_machine_arn) + desc["definition"].should.equal(updated_definition) + desc["roleArn"].should.equal(updated_role) + + @mock_stepfunctions def test_state_machine_list_returns_empty_list_by_default(): client = boto3.client("stepfunctions", region_name=region) @@ -326,6 +353,85 @@ def test_state_machine_can_deleted_nonexisting_machine(): sm_list["stateMachines"].should.have.length_of(0) +@mock_stepfunctions +def test_state_machine_tagging_non_existent_resource_fails(): + client = boto3.client("stepfunctions", region_name=region) + non_existent_arn = "arn:aws:states:{region}:{account}:stateMachine:non-existent".format( + region=region, account=ACCOUNT_ID + ) + with assert_raises(ClientError) as ex: + client.tag_resource(resourceArn=non_existent_arn, tags=[]) + ex.exception.response["Error"]["Code"].should.equal("ResourceNotFound") + ex.exception.response["Error"]["Message"].should.contain(non_existent_arn) + + +@mock_stepfunctions +def test_state_machine_untagging_non_existent_resource_fails(): + client = boto3.client("stepfunctions", region_name=region) + non_existent_arn = "arn:aws:states:{region}:{account}:stateMachine:non-existent".format( + region=region, account=ACCOUNT_ID + ) + with assert_raises(ClientError) as ex: + client.untag_resource(resourceArn=non_existent_arn, tagKeys=[]) + ex.exception.response["Error"]["Code"].should.equal("ResourceNotFound") + ex.exception.response["Error"]["Message"].should.contain(non_existent_arn) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_tagging(): + client = boto3.client("stepfunctions", region_name=region) + tags = [ + {"key": "tag_key1", "value": "tag_value1"}, + {"key": "tag_key2", "value": "tag_value2"}, + ] + machine = client.create_state_machine( + name="test", definition=str(simple_definition), roleArn=_get_default_role(), + ) + client.tag_resource(resourceArn=machine["stateMachineArn"], tags=tags) + resp = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + resp["tags"].should.equal(tags) + + tags_update = [ + {"key": "tag_key1", "value": "tag_value1_new"}, + {"key": "tag_key3", "value": "tag_value3"}, + ] + client.tag_resource(resourceArn=machine["stateMachineArn"], tags=tags_update) + resp = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + tags_expected = [ + tags_update[0], + tags[1], + tags_update[1], + ] + resp["tags"].should.equal(tags_expected) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_untagging(): + client = boto3.client("stepfunctions", region_name=region) + tags = [ + {"key": "tag_key1", "value": "tag_value1"}, + {"key": "tag_key2", "value": "tag_value2"}, + {"key": "tag_key3", "value": "tag_value3"}, + ] + machine = client.create_state_machine( + name="test", + definition=str(simple_definition), + roleArn=_get_default_role(), + tags=tags, + ) + resp = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + resp["tags"].should.equal(tags) + tags_to_delete = ["tag_key1", "tag_key2"] + client.untag_resource( + resourceArn=machine["stateMachineArn"], tagKeys=tags_to_delete + ) + resp = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + expected_tags = [tag for tag in tags if tag["key"] not in tags_to_delete] + resp["tags"].should.equal(expected_tags) + + @mock_stepfunctions @mock_sts def test_state_machine_list_tags_for_created_machine():