From 6a1a8df7ccd172f67308b99f6ccf7b1d2d4d1f6d Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sat, 7 Sep 2019 16:37:55 +0100 Subject: [PATCH] Step Functions - Simplify tests --- moto/stepfunctions/exceptions.py | 21 ++-- moto/stepfunctions/models.py | 22 ++-- .../test_stepfunctions/test_stepfunctions.py | 116 +++++++----------- 3 files changed, 59 insertions(+), 100 deletions(-) diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py index 133d0cc83..8af4686c7 100644 --- a/moto/stepfunctions/exceptions.py +++ b/moto/stepfunctions/exceptions.py @@ -3,38 +3,33 @@ import json class AWSError(Exception): - CODE = None + TYPE = None STATUS = 400 - def __init__(self, message, code=None, status=None): + def __init__(self, message, type=None, status=None): self.message = message - self.code = code if code is not None else self.CODE + self.type = type if type is not None else self.TYPE self.status = status if status is not None else self.STATUS def response(self): - return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) - - -class AccessDeniedException(AWSError): - CODE = 'AccessDeniedException' - STATUS = 400 + return json.dumps({'__type': self.type, 'message': self.message}), dict(status=self.status) class ExecutionDoesNotExist(AWSError): - CODE = 'ExecutionDoesNotExist' + TYPE = 'ExecutionDoesNotExist' STATUS = 400 class InvalidArn(AWSError): - CODE = 'InvalidArn' + TYPE = 'InvalidArn' STATUS = 400 class InvalidName(AWSError): - CODE = 'InvalidName' + TYPE = 'InvalidName' STATUS = 400 class StateMachineDoesNotExist(AWSError): - CODE = 'StateMachineDoesNotExist' + TYPE = 'StateMachineDoesNotExist' STATUS = 400 diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index fd272624f..8db9db1a1 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -5,7 +5,7 @@ from datetime import datetime from moto.core import BaseBackend from moto.core.utils import iso_8601_datetime_without_milliseconds from uuid import uuid4 -from .exceptions import AccessDeniedException, ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist +from .exceptions import ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist class StateMachine(): @@ -98,7 +98,11 @@ class StepFunctionBackend(BaseBackend): def start_execution(self, state_machine_arn): state_machine_name = self.describe_state_machine(state_machine_arn).name - execution = Execution(region_name=self.region_name, account_id=self._get_account_id(), state_machine_name=state_machine_name, execution_name=str(uuid4()), state_machine_arn=state_machine_arn) + execution = Execution(region_name=self.region_name, + account_id=self._get_account_id(), + state_machine_name=state_machine_name, + execution_name=str(uuid4()), + state_machine_arn=state_machine_arn) self.executions.append(execution) return execution @@ -134,29 +138,23 @@ class StepFunctionBackend(BaseBackend): def _validate_role_arn(self, role_arn): self._validate_arn(arn=role_arn, regex=self.accepted_role_arn_format, - invalid_msg="Invalid Role Arn: '" + role_arn + "'", - access_denied_msg='Cross-account pass role is not allowed.') + invalid_msg="Invalid Role Arn: '" + role_arn + "'") def _validate_machine_arn(self, machine_arn): self._validate_arn(arn=machine_arn, regex=self.accepted_mchn_arn_format, - invalid_msg="Invalid Role Arn: '" + machine_arn + "'", - access_denied_msg='User moto is not authorized to access this resource') + invalid_msg="Invalid Role Arn: '" + machine_arn + "'") def _validate_execution_arn(self, execution_arn): self._validate_arn(arn=execution_arn, regex=self.accepted_exec_arn_format, - invalid_msg="Execution Does Not Exist: '" + execution_arn + "'", - access_denied_msg='User moto is not authorized to access this resource') + invalid_msg="Execution Does Not Exist: '" + execution_arn + "'") - def _validate_arn(self, arn, regex, invalid_msg, access_denied_msg): + def _validate_arn(self, arn, regex, invalid_msg): match = regex.match(arn) if not arn or not match: raise InvalidArn(invalid_msg) - if self._get_account_id() != match.group('account_id'): - raise AccessDeniedException(access_denied_msg) - def _get_account_id(self): if self._account_id: return self._account_id diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py index bf5c92570..10953ce2d 100644 --- a/tests/test_stepfunctions/test_stepfunctions.py +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -6,7 +6,6 @@ import datetime from datetime import datetime from botocore.exceptions import ClientError -from moto.config.models import DEFAULT_ACCOUNT_ID from nose.tools import assert_raises from moto import mock_sts, mock_stepfunctions @@ -17,7 +16,7 @@ simple_definition = '{"Comment": "An example of the Amazon States Language using '"StartAt": "DefaultState",' \ '"States": ' \ '{"DefaultState": {"Type": "Fail","Error": "DefaultStateError","Cause": "No Matches!"}}}' -default_stepfunction_role = 'arn:aws:iam:' + str(DEFAULT_ACCOUNT_ID) + ':role/unknown_sf_role' +account_id = None @mock_stepfunctions @@ -28,7 +27,7 @@ def test_state_machine_creation_succeeds(): # response = client.create_state_machine(name=name, definition=str(simple_definition), - roleArn=default_stepfunction_role) + roleArn=_get_default_role()) # response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) response['creationDate'].should.be.a(datetime) @@ -68,10 +67,7 @@ def test_state_machine_creation_fails_with_invalid_names(): with assert_raises(ClientError) as exc: client.create_state_machine(name=invalid_name, definition=str(simple_definition), - roleArn=default_stepfunction_role) - exc.exception.response['Error']['Code'].should.equal('InvalidName') - exc.exception.response['Error']['Message'].should.equal("Invalid Name: '" + invalid_name + "'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + roleArn=_get_default_role()) @mock_stepfunctions @@ -83,24 +79,6 @@ def test_state_machine_creation_requires_valid_role_arn(): client.create_state_machine(name=name, definition=str(simple_definition), roleArn='arn:aws:iam:1234:role/unknown_role') - exc.exception.response['Error']['Code'].should.equal('InvalidArn') - exc.exception.response['Error']['Message'].should.equal("Invalid Role Arn: 'arn:aws:iam:1234:role/unknown_role'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - - -@mock_stepfunctions -@mock_sts -def test_state_machine_creation_requires_role_in_same_account(): - client = boto3.client('stepfunctions', region_name=region) - name = 'example_step_function' - # - with assert_raises(ClientError) as exc: - client.create_state_machine(name=name, - definition=str(simple_definition), - roleArn='arn:aws:iam:000000000000:role/unknown_role') - exc.exception.response['Error']['Code'].should.equal('AccessDeniedException') - exc.exception.response['Error']['Message'].should.equal('Cross-account pass role is not allowed.') - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @@ -118,10 +96,10 @@ def test_state_machine_list_returns_created_state_machines(): # machine2 = client.create_state_machine(name='name2', definition=str(simple_definition), - roleArn=default_stepfunction_role) + roleArn=_get_default_role()) machine1 = client.create_state_machine(name='name1', definition=str(simple_definition), - roleArn=default_stepfunction_role, + roleArn=_get_default_role(), tags=[{'key': 'tag_key', 'value': 'tag_value'}]) list = client.list_state_machines() # @@ -142,15 +120,15 @@ def test_state_machine_list_returns_created_state_machines(): def test_state_machine_creation_is_idempotent_by_name(): client = boto3.client('stepfunctions', region_name=region) # - client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) sm_list = client.list_state_machines() sm_list['stateMachines'].should.have.length_of(1) # - client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) sm_list = client.list_state_machines() sm_list['stateMachines'].should.have.length_of(1) # - client.create_state_machine(name='diff_name', definition=str(simple_definition), roleArn=default_stepfunction_role) + client.create_state_machine(name='diff_name', definition=str(simple_definition), roleArn=_get_default_role()) sm_list = client.list_state_machines() sm_list['stateMachines'].should.have.length_of(2) @@ -160,13 +138,13 @@ def test_state_machine_creation_is_idempotent_by_name(): def test_state_machine_creation_can_be_described(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) desc = client.describe_state_machine(stateMachineArn=sm['stateMachineArn']) desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) desc['creationDate'].should.equal(sm['creationDate']) desc['definition'].should.equal(str(simple_definition)) desc['name'].should.equal('name') - desc['roleArn'].should.equal(default_stepfunction_role) + desc['roleArn'].should.equal(_get_default_role()) desc['stateMachineArn'].should.equal(sm['stateMachineArn']) desc['status'].should.equal('ACTIVE') @@ -177,12 +155,8 @@ def test_state_machine_throws_error_when_describing_unknown_machine(): client = boto3.client('stepfunctions', region_name=region) # with assert_raises(ClientError) as exc: - unknown_state_machine = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':stateMachine:unknown' + unknown_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown' client.describe_state_machine(stateMachineArn=unknown_state_machine) - exc.exception.response['Error']['Code'].should.equal('StateMachineDoesNotExist') - exc.exception.response['Error']['Message'].\ - should.equal("State Machine Does Not Exist: '" + unknown_state_machine + "'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @@ -193,16 +167,13 @@ def test_state_machine_throws_error_when_describing_machine_in_different_account with assert_raises(ClientError) as exc: unknown_state_machine = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown' client.describe_state_machine(stateMachineArn=unknown_state_machine) - exc.exception.response['Error']['Code'].should.equal('AccessDeniedException') - exc.exception.response['Error']['Message'].should.contain('is not authorized to access this resource') - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @mock_sts def test_state_machine_can_be_deleted(): client = boto3.client('stepfunctions', region_name=region) - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) # response = client.delete_state_machine(stateMachineArn=sm['stateMachineArn']) response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) @@ -224,19 +195,6 @@ def test_state_machine_can_deleted_nonexisting_machine(): sm_list['stateMachines'].should.have.length_of(0) -@mock_stepfunctions -@mock_sts -def test_state_machine_deletion_validates_arn(): - client = boto3.client('stepfunctions', region_name=region) - # - with assert_raises(ClientError) as exc: - unknown_account_id = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown' - client.delete_state_machine(stateMachineArn=unknown_account_id) - exc.exception.response['Error']['Code'].should.equal('AccessDeniedException') - exc.exception.response['Error']['Message'].should.contain('is not authorized to access this resource') - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - - @mock_stepfunctions @mock_sts def test_state_machine_list_tags_for_created_machine(): @@ -244,7 +202,7 @@ def test_state_machine_list_tags_for_created_machine(): # machine = client.create_state_machine(name='name1', definition=str(simple_definition), - roleArn=default_stepfunction_role, + roleArn=_get_default_role(), tags=[{'key': 'tag_key', 'value': 'tag_value'}]) response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) tags = response['tags'] @@ -259,7 +217,7 @@ def test_state_machine_list_tags_for_machine_without_tags(): # machine = client.create_state_machine(name='name1', definition=str(simple_definition), - roleArn=default_stepfunction_role) + roleArn=_get_default_role()) response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) tags = response['tags'] tags.should.have.length_of(0) @@ -270,7 +228,7 @@ def test_state_machine_list_tags_for_machine_without_tags(): def test_state_machine_list_tags_for_nonexisting_machine(): client = boto3.client('stepfunctions', region_name=region) # - non_existing_state_machine = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':stateMachine:unknown' + non_existing_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown' response = client.list_tags_for_resource(resourceArn=non_existing_state_machine) tags = response['tags'] tags.should.have.length_of(0) @@ -281,11 +239,11 @@ def test_state_machine_list_tags_for_nonexisting_machine(): def test_state_machine_start_execution(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) # execution['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - expected_exec_name = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':execution:name:[a-zA-Z0-9-]+' + expected_exec_name = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:name:[a-zA-Z0-9-]+' execution['executionArn'].should.match(expected_exec_name) execution['startDate'].should.be.a(datetime) @@ -295,7 +253,7 @@ def test_state_machine_start_execution(): def test_state_machine_list_executions(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) execution_arn = execution['executionArn'] execution_name = execution_arn[execution_arn.rindex(':')+1:] @@ -316,7 +274,7 @@ def test_state_machine_list_executions(): def test_state_machine_list_executions_when_none_exist(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) executions = client.list_executions(stateMachineArn=sm['stateMachineArn']) # executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200) @@ -328,7 +286,7 @@ def test_state_machine_list_executions_when_none_exist(): def test_state_machine_describe_execution(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) description = client.describe_execution(executionArn=execution['executionArn']) # @@ -348,11 +306,8 @@ def test_state_machine_throws_error_when_describing_unknown_machine(): client = boto3.client('stepfunctions', region_name=region) # with assert_raises(ClientError) as exc: - unknown_execution = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':execution:unknown' + unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown' client.describe_execution(executionArn=unknown_execution) - exc.exception.response['Error']['Code'].should.equal('ExecutionDoesNotExist') - exc.exception.response['Error']['Message'].should.equal("Execution Does Not Exist: '" + unknown_execution + "'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @@ -360,13 +315,13 @@ def test_state_machine_throws_error_when_describing_unknown_machine(): def test_state_machine_can_be_described_by_execution(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) desc = client.describe_state_machine_for_execution(executionArn=execution['executionArn']) desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) desc['definition'].should.equal(str(simple_definition)) desc['name'].should.equal('name') - desc['roleArn'].should.equal(default_stepfunction_role) + desc['roleArn'].should.equal(_get_default_role()) desc['stateMachineArn'].should.equal(sm['stateMachineArn']) @@ -376,11 +331,8 @@ def test_state_machine_throws_error_when_describing_unknown_execution(): client = boto3.client('stepfunctions', region_name=region) # with assert_raises(ClientError) as exc: - unknown_execution = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':execution:unknown' + unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown' client.describe_state_machine_for_execution(executionArn=unknown_execution) - exc.exception.response['Error']['Code'].should.equal('ExecutionDoesNotExist') - exc.exception.response['Error']['Message'].should.equal("Execution Does Not Exist: '" + unknown_execution + "'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @@ -388,10 +340,9 @@ def test_state_machine_throws_error_when_describing_unknown_execution(): def test_state_machine_stop_execution(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) start = client.start_execution(stateMachineArn=sm['stateMachineArn']) stop = client.stop_execution(executionArn=start['executionArn']) - print(stop) # stop['ResponseMetadata']['HTTPStatusCode'].should.equal(200) stop['stopDate'].should.be.a(datetime) @@ -400,9 +351,10 @@ def test_state_machine_stop_execution(): @mock_stepfunctions @mock_sts def test_state_machine_describe_execution_after_stoppage(): + account_id client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) client.stop_execution(executionArn=execution['executionArn']) description = client.describe_execution(executionArn=execution['executionArn']) @@ -410,3 +362,17 @@ def test_state_machine_describe_execution_after_stoppage(): description['ResponseMetadata']['HTTPStatusCode'].should.equal(200) description['status'].should.equal('SUCCEEDED') description['stopDate'].should.be.a(datetime) + + +def _get_account_id(): + global account_id + if account_id: + return account_id + sts = boto3.client("sts") + identity = sts.get_caller_identity() + account_id = identity['Account'] + return account_id + + +def _get_default_role(): + return 'arn:aws:iam:' + _get_account_id() + ':role/unknown_sf_role'