From c7147b06b119f68483a532daa6fe8a248bc3feb0 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Wed, 10 Jul 2019 21:59:25 -0500 Subject: [PATCH 01/67] Cleanup responses mocking. Closes #1567 This unblocks requests to other websites with requests while Moto is activated. It also adds a wildcard for AWS services to still ensure no accidental requests are made for unmocked services --- moto/apigateway/models.py | 14 +++-- moto/core/models.py | 75 +++++++++---------------- tests/test_core/test_request_mocking.py | 21 +++++++ 3 files changed, 59 insertions(+), 51 deletions(-) create mode 100644 tests/test_core/test_request_mocking.py diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 41a49e361..d8c926811 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -388,10 +388,16 @@ class RestAPI(BaseModel): stage_url_upper = STAGE_URL.format(api_id=self.id.upper(), region_name=self.region_name, stage_name=stage_name) - responses.add_callback(responses.GET, stage_url_lower, - callback=self.resource_callback) - responses.add_callback(responses.GET, stage_url_upper, - callback=self.resource_callback) + for url in [stage_url_lower, stage_url_upper]: + responses._default_mock._matches.insert(0, + responses.CallbackResponse( + url=url, + method=responses.GET, + callback=self.resource_callback, + content_type="text/plain", + match_querystring=False, + ) + ) def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): if variables is None: diff --git a/moto/core/models.py b/moto/core/models.py index 9fe1e96bd..94f75dafb 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -193,53 +193,8 @@ class CallbackResponse(responses.CallbackResponse): botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send') responses_mock = responses._default_mock - - -class ResponsesMockAWS(BaseMockAWS): - def reset(self): - botocore_mock.reset() - responses_mock.reset() - - def enable_patching(self): - if not hasattr(botocore_mock, '_patcher') or not hasattr(botocore_mock._patcher, 'target'): - # Check for unactivated patcher - botocore_mock.start() - - if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'): - responses_mock.start() - - for method in RESPONSES_METHODS: - for backend in self.backends_for_urls.values(): - for key, value in backend.urls.items(): - responses_mock.add( - CallbackResponse( - method=method, - url=re.compile(key), - callback=convert_flask_to_responses_response(value), - stream=True, - match_querystring=False, - ) - ) - botocore_mock.add( - CallbackResponse( - method=method, - url=re.compile(key), - callback=convert_flask_to_responses_response(value), - stream=True, - match_querystring=False, - ) - ) - - def disable_patching(self): - try: - botocore_mock.stop() - except RuntimeError: - pass - - try: - responses_mock.stop() - except RuntimeError: - pass +# Add passthrough to allow any other requests to work +responses_mock.add_passthru("http") BOTOCORE_HTTP_METHODS = [ @@ -306,6 +261,14 @@ botocore_stubber = BotocoreStubber() BUILTIN_HANDLERS.append(('before-send', botocore_stubber)) +def not_implemented_callback(request): + status = 400 + headers = {} + response = "The method is not implemented" + + return status, headers, response + + class BotocoreEventMockAWS(BaseMockAWS): def reset(self): botocore_stubber.reset() @@ -335,6 +298,24 @@ class BotocoreEventMockAWS(BaseMockAWS): match_querystring=False, ) ) + responses_mock.add( + CallbackResponse( + method=method, + url=re.compile("https?://.+.amazonaws.com/.*"), + callback=not_implemented_callback, + stream=True, + match_querystring=False, + ) + ) + botocore_mock.add( + CallbackResponse( + method=method, + url=re.compile("https?://.+.amazonaws.com/.*"), + callback=not_implemented_callback, + stream=True, + match_querystring=False, + ) + ) def disable_patching(self): botocore_stubber.enabled = False diff --git a/tests/test_core/test_request_mocking.py b/tests/test_core/test_request_mocking.py new file mode 100644 index 000000000..fd9f85ab6 --- /dev/null +++ b/tests/test_core/test_request_mocking.py @@ -0,0 +1,21 @@ +import requests +import sure # noqa + +import boto3 +from moto import mock_sqs + + +@mock_sqs +def test_passthrough_requests(): + conn = boto3.client("sqs", region_name='us-west-1') + conn.create_queue(QueueName="queue1") + + res = requests.get("https://httpbin.org/ip") + assert res.status_code == 200 + + +@mock_sqs +def test_requests_to_amazon_subdomains_dont_work(): + res = requests.get("https://fakeservice.amazonaws.com/foo/bar") + assert res.content == b"The method is not implemented" + assert res.status_code == 400 From 11506e21d583f996415ee29d0a78f432e06663dd Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Wed, 10 Jul 2019 22:45:26 -0500 Subject: [PATCH 02/67] Only test passthrough exception in non-server mode. --- tests/test_core/test_request_mocking.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_core/test_request_mocking.py b/tests/test_core/test_request_mocking.py index fd9f85ab6..e07c25123 100644 --- a/tests/test_core/test_request_mocking.py +++ b/tests/test_core/test_request_mocking.py @@ -2,7 +2,7 @@ import requests import sure # noqa import boto3 -from moto import mock_sqs +from moto import mock_sqs, settings @mock_sqs @@ -14,8 +14,9 @@ def test_passthrough_requests(): assert res.status_code == 200 -@mock_sqs -def test_requests_to_amazon_subdomains_dont_work(): - res = requests.get("https://fakeservice.amazonaws.com/foo/bar") - assert res.content == b"The method is not implemented" - assert res.status_code == 400 +if settings.TEST_SERVER_MODE: + @mock_sqs + def test_requests_to_amazon_subdomains_dont_work(): + res = requests.get("https://fakeservice.amazonaws.com/foo/bar") + assert res.content == b"The method is not implemented" + assert res.status_code == 400 From 2a0df1e1a87c3290a4cb7b24bcc2cac6e40f4e33 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Thu, 11 Jul 2019 10:09:01 -0500 Subject: [PATCH 03/67] Flip when we test passthru. --- tests/test_core/test_request_mocking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_core/test_request_mocking.py b/tests/test_core/test_request_mocking.py index e07c25123..ee3ec5f88 100644 --- a/tests/test_core/test_request_mocking.py +++ b/tests/test_core/test_request_mocking.py @@ -14,7 +14,7 @@ def test_passthrough_requests(): assert res.status_code == 200 -if settings.TEST_SERVER_MODE: +if not settings.TEST_SERVER_MODE: @mock_sqs def test_requests_to_amazon_subdomains_dont_work(): res = requests.get("https://fakeservice.amazonaws.com/foo/bar") From 7091be8eaec81a6044649f811341af4691de7184 Mon Sep 17 00:00:00 2001 From: Daniel Guerrero Date: Mon, 29 Jul 2019 21:13:58 -0500 Subject: [PATCH 04/67] Adding support for AT_SEQUENCE_NUMBER and AFTER_SEQUENCE_NUMBER Adding support on DynamoDB Streams for AT_SEQUENCE_NUMBER and AFTER_SEQUENCE_NUMBER ShardIteratorType Change SequenceNumber type to string instead of int to match documentation --- moto/dynamodb2/models.py | 2 +- moto/dynamodbstreams/models.py | 2 +- moto/dynamodbstreams/responses.py | 7 ++- .../test_dynamodbstreams.py | 56 +++++++++++++++++++ 4 files changed, 64 insertions(+), 3 deletions(-) diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index e868caaa8..4ef4461cd 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -363,7 +363,7 @@ class StreamRecord(BaseModel): 'dynamodb': { 'StreamViewType': stream_type, 'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(), - 'SequenceNumber': seq, + 'SequenceNumber': str(seq), 'SizeBytes': 1, 'Keys': keys } diff --git a/moto/dynamodbstreams/models.py b/moto/dynamodbstreams/models.py index 41cc6e280..3e20ae13f 100644 --- a/moto/dynamodbstreams/models.py +++ b/moto/dynamodbstreams/models.py @@ -39,7 +39,7 @@ class ShardIterator(BaseModel): def get(self, limit=1000): items = self.stream_shard.get(self.sequence_number, limit) try: - last_sequence_number = max(i['dynamodb']['SequenceNumber'] for i in items) + last_sequence_number = max(int(i['dynamodb']['SequenceNumber']) for i in items) new_shard_iterator = ShardIterator(self.streams_backend, self.stream_shard, 'AFTER_SEQUENCE_NUMBER', diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index c9c113615..0e2800f55 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -23,8 +23,13 @@ class DynamoDBStreamsHandler(BaseResponse): arn = self._get_param('StreamArn') shard_id = self._get_param('ShardId') shard_iterator_type = self._get_param('ShardIteratorType') + sequence_number = self._get_param('SequenceNumber') + #according to documentation sequence_number param should be string + if isinstance(sequence_number, str): + sequence_number = int(sequence_number) + return self.backend.get_shard_iterator(arn, shard_id, - shard_iterator_type) + shard_iterator_type, sequence_number) def get_records(self): arn = self._get_param('ShardIterator') diff --git a/tests/test_dynamodbstreams/test_dynamodbstreams.py b/tests/test_dynamodbstreams/test_dynamodbstreams.py index b60c21053..f1c59fa29 100644 --- a/tests/test_dynamodbstreams/test_dynamodbstreams.py +++ b/tests/test_dynamodbstreams/test_dynamodbstreams.py @@ -76,6 +76,34 @@ class TestCore(): ShardIteratorType='TRIM_HORIZON' ) assert 'ShardIterator' in resp + + def test_get_shard_iterator_at_sequence_number(self): + conn = boto3.client('dynamodbstreams', region_name='us-east-1') + + resp = conn.describe_stream(StreamArn=self.stream_arn) + shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] + + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType='AT_SEQUENCE_NUMBER', + SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber'] + ) + assert 'ShardIterator' in resp + + def test_get_shard_iterator_after_sequence_number(self): + conn = boto3.client('dynamodbstreams', region_name='us-east-1') + + resp = conn.describe_stream(StreamArn=self.stream_arn) + shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] + + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType='AFTER_SEQUENCE_NUMBER', + SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber'] + ) + assert 'ShardIterator' in resp def test_get_records_empty(self): conn = boto3.client('dynamodbstreams', region_name='us-east-1') @@ -135,11 +163,39 @@ class TestCore(): assert resp['Records'][1]['eventName'] == 'MODIFY' assert resp['Records'][2]['eventName'] == 'DELETE' + sequence_number_modify = resp['Records'][1]['dynamodb']['SequenceNumber'] + # now try fetching from the next shard iterator, it should be # empty resp = conn.get_records(ShardIterator=resp['NextShardIterator']) assert len(resp['Records']) == 0 + #check that if we get the shard iterator AT_SEQUENCE_NUMBER will get the MODIFY event + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType='AT_SEQUENCE_NUMBER', + SequenceNumber=sequence_number_modify + ) + iterator_id = resp['ShardIterator'] + resp = conn.get_records(ShardIterator=iterator_id) + assert len(resp['Records']) == 2 + assert resp['Records'][0]['eventName'] == 'MODIFY' + assert resp['Records'][1]['eventName'] == 'DELETE' + + #check that if we get the shard iterator AFTER_SEQUENCE_NUMBER will get the DELETE event + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType='AFTER_SEQUENCE_NUMBER', + SequenceNumber=sequence_number_modify + ) + iterator_id = resp['ShardIterator'] + resp = conn.get_records(ShardIterator=iterator_id) + assert len(resp['Records']) == 1 + assert resp['Records'][0]['eventName'] == 'DELETE' + + class TestEdges(): mocks = [] From bfc401c520b38f60adb29d5aba09bcead2da13c0 Mon Sep 17 00:00:00 2001 From: Daniel Guerrero Date: Mon, 29 Jul 2019 21:21:02 -0500 Subject: [PATCH 05/67] Fixing comments conventions --- moto/dynamodbstreams/responses.py | 2 +- tests/test_dynamodbstreams/test_dynamodbstreams.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index 0e2800f55..c4e61a750 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -24,7 +24,7 @@ class DynamoDBStreamsHandler(BaseResponse): shard_id = self._get_param('ShardId') shard_iterator_type = self._get_param('ShardIteratorType') sequence_number = self._get_param('SequenceNumber') - #according to documentation sequence_number param should be string + # according to documentation sequence_number param should be string if isinstance(sequence_number, str): sequence_number = int(sequence_number) diff --git a/tests/test_dynamodbstreams/test_dynamodbstreams.py b/tests/test_dynamodbstreams/test_dynamodbstreams.py index f1c59fa29..deb9f9283 100644 --- a/tests/test_dynamodbstreams/test_dynamodbstreams.py +++ b/tests/test_dynamodbstreams/test_dynamodbstreams.py @@ -170,7 +170,7 @@ class TestCore(): resp = conn.get_records(ShardIterator=resp['NextShardIterator']) assert len(resp['Records']) == 0 - #check that if we get the shard iterator AT_SEQUENCE_NUMBER will get the MODIFY event + # check that if we get the shard iterator AT_SEQUENCE_NUMBER will get the MODIFY event resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, @@ -183,7 +183,7 @@ class TestCore(): assert resp['Records'][0]['eventName'] == 'MODIFY' assert resp['Records'][1]['eventName'] == 'DELETE' - #check that if we get the shard iterator AFTER_SEQUENCE_NUMBER will get the DELETE event + # check that if we get the shard iterator AFTER_SEQUENCE_NUMBER will get the DELETE event resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, From 364bd0720d6e3f8b93bc3416927ad37f48911691 Mon Sep 17 00:00:00 2001 From: Daniel Guerrero Date: Tue, 30 Jul 2019 13:54:42 -0500 Subject: [PATCH 06/67] Adding support for python 2.7 Python 2.7 sends unicode type instead string type --- moto/dynamodbstreams/responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index c4e61a750..6ff6ba2f4 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -25,7 +25,7 @@ class DynamoDBStreamsHandler(BaseResponse): shard_iterator_type = self._get_param('ShardIteratorType') sequence_number = self._get_param('SequenceNumber') # according to documentation sequence_number param should be string - if isinstance(sequence_number, str): + if isinstance(sequence_number, str) or isinstance(sequence_number, unicode): sequence_number = int(sequence_number) return self.backend.get_shard_iterator(arn, shard_id, From 1ce162f0561f10e20c924dbcf0ae2cd27ec78285 Mon Sep 17 00:00:00 2001 From: Daniel Guerrero Date: Tue, 30 Jul 2019 14:15:47 -0500 Subject: [PATCH 07/67] Using string class to detect type Using string class instead unicode that has been removed from python 3 --- moto/dynamodbstreams/responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index 6ff6ba2f4..c570483c5 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -25,7 +25,7 @@ class DynamoDBStreamsHandler(BaseResponse): shard_iterator_type = self._get_param('ShardIteratorType') sequence_number = self._get_param('SequenceNumber') # according to documentation sequence_number param should be string - if isinstance(sequence_number, str) or isinstance(sequence_number, unicode): + if isinstance(sequence_number, "".__class__): sequence_number = int(sequence_number) return self.backend.get_shard_iterator(arn, shard_id, From e8d60435fe5157953b8a18e5edc9b5867c4b60dd Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Fri, 23 Aug 2019 10:57:15 +0100 Subject: [PATCH 08/67] #2366 - SecretsManager - put_secret_value should support binary values --- moto/secretsmanager/models.py | 4 +- moto/secretsmanager/responses.py | 7 ++- .../test_secretsmanager.py | 59 ++++++++++++++++++- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/moto/secretsmanager/models.py b/moto/secretsmanager/models.py index 3e0424b6b..63d847c49 100644 --- a/moto/secretsmanager/models.py +++ b/moto/secretsmanager/models.py @@ -154,9 +154,9 @@ class SecretsManagerBackend(BaseBackend): return version_id - def put_secret_value(self, secret_id, secret_string, version_stages): + def put_secret_value(self, secret_id, secret_string, secret_binary, version_stages): - version_id = self._add_secret(secret_id, secret_string, version_stages=version_stages) + version_id = self._add_secret(secret_id, secret_string, secret_binary, version_stages=version_stages) response = json.dumps({ 'ARN': secret_arn(self.region, secret_id), diff --git a/moto/secretsmanager/responses.py b/moto/secretsmanager/responses.py index 090688351..4995c4bc7 100644 --- a/moto/secretsmanager/responses.py +++ b/moto/secretsmanager/responses.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse +from moto.secretsmanager.exceptions import InvalidRequestException from .models import secretsmanager_backends @@ -71,10 +72,14 @@ class SecretsManagerResponse(BaseResponse): def put_secret_value(self): secret_id = self._get_param('SecretId', if_none='') - secret_string = self._get_param('SecretString', if_none='') + secret_string = self._get_param('SecretString') + secret_binary = self._get_param('SecretBinary') + if not secret_binary and not secret_string: + raise InvalidRequestException('You must provide either SecretString or SecretBinary.') version_stages = self._get_param('VersionStages', if_none=['AWSCURRENT']) return secretsmanager_backends[self.region].put_secret_value( secret_id=secret_id, + secret_binary=secret_binary, secret_string=secret_string, version_stages=version_stages, ) diff --git a/tests/test_secretsmanager/test_secretsmanager.py b/tests/test_secretsmanager/test_secretsmanager.py index 78b95ee6a..62de93bab 100644 --- a/tests/test_secretsmanager/test_secretsmanager.py +++ b/tests/test_secretsmanager/test_secretsmanager.py @@ -5,9 +5,9 @@ import boto3 from moto import mock_secretsmanager from botocore.exceptions import ClientError import string -import unittest import pytz from datetime import datetime +import sure # noqa from nose.tools import assert_raises from six import b @@ -23,6 +23,7 @@ def test_get_secret_value(): result = conn.get_secret_value(SecretId='java-util-test-password') assert result['SecretString'] == 'foosecret' + @mock_secretsmanager def test_get_secret_value_binary(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -32,6 +33,7 @@ def test_get_secret_value_binary(): result = conn.get_secret_value(SecretId='java-util-test-password') assert result['SecretBinary'] == b('foosecret') + @mock_secretsmanager def test_get_secret_that_does_not_exist(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -39,6 +41,7 @@ def test_get_secret_that_does_not_exist(): with assert_raises(ClientError): result = conn.get_secret_value(SecretId='i-dont-exist') + @mock_secretsmanager def test_get_secret_that_does_not_match(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -72,6 +75,7 @@ def test_create_secret(): secret = conn.get_secret_value(SecretId='test-secret') assert secret['SecretString'] == 'foosecret' + @mock_secretsmanager def test_create_secret_with_tags(): conn = boto3.client('secretsmanager', region_name='us-east-1') @@ -216,6 +220,7 @@ def test_get_random_exclude_lowercase(): ExcludeLowercase=True) assert any(c.islower() for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_exclude_uppercase(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -224,6 +229,7 @@ def test_get_random_exclude_uppercase(): ExcludeUppercase=True) assert any(c.isupper() for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_exclude_characters_and_symbols(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -232,6 +238,7 @@ def test_get_random_exclude_characters_and_symbols(): ExcludeCharacters='xyzDje@?!.') assert any(c in 'xyzDje@?!.' for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_exclude_numbers(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -240,6 +247,7 @@ def test_get_random_exclude_numbers(): ExcludeNumbers=True) assert any(c.isdigit() for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_exclude_punctuation(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -249,6 +257,7 @@ def test_get_random_exclude_punctuation(): assert any(c in string.punctuation for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_include_space_false(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -256,6 +265,7 @@ def test_get_random_include_space_false(): random_password = conn.get_random_password(PasswordLength=300) assert any(c.isspace() for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_include_space_true(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -264,6 +274,7 @@ def test_get_random_include_space_true(): IncludeSpace=True) assert any(c.isspace() for c in random_password['RandomPassword']) == True + @mock_secretsmanager def test_get_random_require_each_included_type(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -275,6 +286,7 @@ def test_get_random_require_each_included_type(): assert any(c in string.ascii_uppercase for c in random_password['RandomPassword']) == True assert any(c in string.digits for c in random_password['RandomPassword']) == True + @mock_secretsmanager def test_get_random_too_short_password(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -282,6 +294,7 @@ def test_get_random_too_short_password(): with assert_raises(ClientError): random_password = conn.get_random_password(PasswordLength=3) + @mock_secretsmanager def test_get_random_too_long_password(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -289,6 +302,7 @@ def test_get_random_too_long_password(): with assert_raises(Exception): random_password = conn.get_random_password(PasswordLength=5555) + @mock_secretsmanager def test_describe_secret(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -307,6 +321,7 @@ def test_describe_secret(): assert secret_description_2['Name'] == ('test-secret-2') assert secret_description_2['ARN'] != '' # Test arn not empty + @mock_secretsmanager def test_describe_secret_that_does_not_exist(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -314,6 +329,7 @@ def test_describe_secret_that_does_not_exist(): with assert_raises(ClientError): result = conn.get_secret_value(SecretId='i-dont-exist') + @mock_secretsmanager def test_describe_secret_that_does_not_match(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -500,6 +516,7 @@ def test_rotate_secret_rotation_period_zero(): # test_server actually handles this error. assert True + @mock_secretsmanager def test_rotate_secret_rotation_period_too_long(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -511,6 +528,7 @@ def test_rotate_secret_rotation_period_too_long(): result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, RotationRules=rotation_rules) + @mock_secretsmanager def test_put_secret_value_puts_new_secret(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -526,6 +544,45 @@ def test_put_secret_value_puts_new_secret(): assert get_secret_value_dict assert get_secret_value_dict['SecretString'] == 'foosecret' + +@mock_secretsmanager +def test_put_secret_binary_value_puts_new_secret(): + conn = boto3.client('secretsmanager', region_name='us-west-2') + put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, + SecretBinary=b('foosecret'), + VersionStages=['AWSCURRENT']) + version_id = put_secret_value_dict['VersionId'] + + get_secret_value_dict = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME, + VersionId=version_id, + VersionStage='AWSCURRENT') + + assert get_secret_value_dict + assert get_secret_value_dict['SecretBinary'] == b('foosecret') + + +@mock_secretsmanager +def test_create_and_put_secret_binary_value_puts_new_secret(): + conn = boto3.client('secretsmanager', region_name='us-west-2') + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretBinary=b("foosecret")) + conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, SecretBinary=b('foosecret_update')) + + latest_secret = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME) + + assert latest_secret + assert latest_secret['SecretBinary'] == b('foosecret_update') + + +@mock_secretsmanager +def test_put_secret_binary_requires_either_string_or_binary(): + conn = boto3.client('secretsmanager', region_name='us-west-2') + with assert_raises(ClientError) as ire: + conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME) + + ire.exception.response['Error']['Code'].should.equal('InvalidRequestException') + ire.exception.response['Error']['Message'].should.equal('You must provide either SecretString or SecretBinary.') + + @mock_secretsmanager def test_put_secret_value_can_get_first_version_if_put_twice(): conn = boto3.client('secretsmanager', region_name='us-west-2') From 8b90a75aa011ab890cb0951c817844975682f424 Mon Sep 17 00:00:00 2001 From: Vladimir Date: Fri, 23 Aug 2019 17:17:10 +0300 Subject: [PATCH 09/67] issues-2386 make comparing exists and new queues only by static attrs --- moto/sqs/models.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index e774e261c..c7386b52f 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -424,13 +424,26 @@ class SQSBackend(BaseBackend): queue_attributes = queue.attributes new_queue_attributes = new_queue.attributes + static_attributes = ( + 'DelaySeconds', + 'MaximumMessageSize', + 'MessageRetentionPeriod', + 'Policy', + 'QueueArn', + 'ReceiveMessageWaitTimeSeconds', + 'RedrivePolicy', + 'VisibilityTimeout', + 'KmsMasterKeyId', + 'KmsDataKeyReusePeriodSeconds', + 'FifoQueue', + 'ContentBasedDeduplication', + ) - for key in ['CreatedTimestamp', 'LastModifiedTimestamp']: - queue_attributes.pop(key) - new_queue_attributes.pop(key) - - if queue_attributes != new_queue_attributes: - raise QueueAlreadyExists("The specified queue already exists.") + for key in static_attributes: + if queue_attributes.get(key) != new_queue_attributes.get(key): + raise QueueAlreadyExists( + "The specified queue already exists.", + ) else: try: kwargs.pop('region') From 59852eb13aef9869ae8b6b2f80b76ddb98a46239 Mon Sep 17 00:00:00 2001 From: Giulio Date: Sat, 24 Aug 2019 11:19:50 +0100 Subject: [PATCH 10/67] Add tag support to API Gateway keys --- moto/apigateway/models.py | 3 ++- tests/test_apigateway/test_apigateway.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 6be062d7f..25aa23721 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -298,7 +298,7 @@ class Stage(BaseModel, dict): class ApiKey(BaseModel, dict): def __init__(self, name=None, description=None, enabled=True, - generateDistinctId=False, value=None, stageKeys=None, customerId=None): + generateDistinctId=False, value=None, stageKeys=None, tags=None, customerId=None): super(ApiKey, self).__init__() self['id'] = create_id() self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40)) @@ -308,6 +308,7 @@ class ApiKey(BaseModel, dict): self['enabled'] = enabled self['createdDate'] = self['lastUpdatedDate'] = int(time.time()) self['stageKeys'] = stageKeys + self['tags'] = tags def update_operations(self, patch_operations): for op in patch_operations: diff --git a/tests/test_apigateway/test_apigateway.py b/tests/test_apigateway/test_apigateway.py index 0a33f2f9f..20cc078b8 100644 --- a/tests/test_apigateway/test_apigateway.py +++ b/tests/test_apigateway/test_apigateway.py @@ -981,11 +981,13 @@ def test_api_keys(): apikey['value'].should.equal(apikey_value) apikey_name = 'TESTKEY2' - payload = {'name': apikey_name } + payload = {'name': apikey_name, 'tags': {'tag1': 'test_tag1', 'tag2': '1'}} response = client.create_api_key(**payload) apikey_id = response['id'] apikey = client.get_api_key(apiKey=apikey_id) apikey['name'].should.equal(apikey_name) + apikey['tags']['tag1'].should.equal('test_tag1') + apikey['tags']['tag2'].should.equal('1') len(apikey['value']).should.equal(40) apikey_name = 'TESTKEY3' From d8a922811cd518b4620dd7c0802674a197b0bfef Mon Sep 17 00:00:00 2001 From: gruebel Date: Sun, 25 Aug 2019 16:48:14 +0200 Subject: [PATCH 11/67] Add exact Number, exact String.Array and attribute key matching to SNS subscription filter policy and validate filter policy --- moto/sns/exceptions.py | 8 + moto/sns/models.py | 86 ++++- moto/sns/responses.py | 11 +- tests/test_sns/test_publishing_boto3.py | 388 +++++++++++++++++++++ tests/test_sns/test_subscriptions_boto3.py | 79 ++++- 5 files changed, 564 insertions(+), 8 deletions(-) diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py index 0e7a0bdcf..706b3b5cc 100644 --- a/moto/sns/exceptions.py +++ b/moto/sns/exceptions.py @@ -40,3 +40,11 @@ class InvalidParameterValue(RESTError): def __init__(self, message): super(InvalidParameterValue, self).__init__( "InvalidParameterValue", message) + + +class InternalError(RESTError): + code = 500 + + def __init__(self, message): + super(InternalError, self).__init__( + "InternalFailure", message) diff --git a/moto/sns/models.py b/moto/sns/models.py index 18b86cb93..f152046e9 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -18,7 +18,7 @@ from moto.awslambda import lambda_backends from .exceptions import ( SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter, - InvalidParameterValue + InvalidParameterValue, InternalError ) from .utils import make_arn_for_topic, make_arn_for_subscription @@ -131,13 +131,47 @@ class Subscription(BaseModel): message_attributes = {} def _field_match(field, rules, message_attributes): - if field not in message_attributes: - return False for rule in rules: + # TODO: boolean value matching is not supported, SNS behavior unknown if isinstance(rule, six.string_types): - # only string value matching is supported + if field not in message_attributes: + return False if message_attributes[field]['Value'] == rule: return True + try: + json_data = json.loads(message_attributes[field]['Value']) + if rule in json_data: + return True + except (ValueError, TypeError): + pass + if isinstance(rule, (six.integer_types, float)): + if field not in message_attributes: + return False + if message_attributes[field]['Type'] == 'Number': + attribute_values = [message_attributes[field]['Value']] + elif message_attributes[field]['Type'] == 'String.Array': + try: + attribute_values = json.loads(message_attributes[field]['Value']) + if not isinstance(attribute_values, list): + attribute_values = [attribute_values] + except (ValueError, TypeError): + return False + else: + return False + + for attribute_values in attribute_values: + # Even the offical documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6 + # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints + if int(attribute_values * 1000000) == int(rule * 1000000): + return True + if isinstance(rule, dict): + keyword = list(rule.keys())[0] + attributes = list(rule.values())[0] + if keyword == 'exists': + if attributes and field in message_attributes: + return True + elif not attributes and field not in message_attributes: + return True return False return all(_field_match(field, rules, message_attributes) @@ -421,7 +455,49 @@ class SNSBackend(BaseBackend): subscription.attributes[name] = value if name == 'FilterPolicy': - subscription._filter_policy = json.loads(value) + filter_policy = json.loads(value) + self._validate_filter_policy(filter_policy) + subscription._filter_policy = filter_policy + + def _validate_filter_policy(self, value): + # TODO: extend validation checks + combinations = 1 + for rules in six.itervalues(value): + combinations *= len(rules) + # Even the offical documentation states the total combination of values must not exceed 100, in reality it is 150 + # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints + if combinations > 150: + raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Filter policy is too complex") + + for field, rules in six.iteritems(value): + for rule in rules: + if rule is None: + continue + if isinstance(rule, six.string_types): + continue + if isinstance(rule, bool): + continue + if isinstance(rule, (six.integer_types, float)): + if rule <= -1000000000 or rule >= 1000000000: + raise InternalError("Unknown") + continue + if isinstance(rule, dict): + keyword = list(rule.keys())[0] + attributes = list(rule.values())[0] + if keyword == 'anything-but': + continue + elif keyword == 'exists': + if not isinstance(attributes, bool): + raise SNSInvalidParameter("Invalid parameter: FilterPolicy: exists match pattern must be either true or false.") + continue + elif keyword == 'numeric': + continue + elif keyword == 'prefix': + continue + else: + raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Unrecognized match type {type}".format(type=keyword)) + + raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null") sns_backends = {} diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 440115429..578c5ea65 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -57,7 +57,16 @@ class SNSResponse(BaseResponse): transform_value = None if 'StringValue' in value: - transform_value = value['StringValue'] + if data_type == 'Number': + try: + transform_value = float(value['StringValue']) + except ValueError: + raise InvalidParameterValue( + "An error occurred (ParameterValueInvalid) " + "when calling the Publish operation: " + "Could not cast message attribute '{0}' value to number.".format(name)) + else: + transform_value = value['StringValue'] elif 'BinaryValue' in value: transform_value = value['BinaryValue'] if not transform_value: diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 3d598d406..d7bf32e51 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -109,6 +109,17 @@ def test_publish_to_sqs_bad(): }}) except ClientError as err: err.response['Error']['Code'].should.equal('InvalidParameterValue') + try: + # Test Number DataType, with a non numeric value + conn.publish( + TopicArn=topic_arn, Message=message, + MessageAttributes={'price': { + 'DataType': 'Number', + 'StringValue': 'error' + }}) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response['Error']['Message'].should.equal("An error occurred (ParameterValueInvalid) when calling the Publish operation: Could not cast message attribute 'price' value to number.") @mock_sqs @@ -487,3 +498,380 @@ def test_filtering_exact_string_no_attributes_no_match(): message_attributes = [ json.loads(m.body)['MessageAttributes'] for m in messages] message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_int(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'Number', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'Number', 'Value': 100}}]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_float(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100.1]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'Number', + 'StringValue': '100.1'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'Number', 'Value': 100.1}}]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_float_accuracy(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100.123456789]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'Number', + 'StringValue': '100.1234561'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'Number', 'Value': 100.1234561}}]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100]}) + + topic.publish( + Message='no match', + MessageAttributes={'price': {'DataType': 'Number', + 'StringValue': '101'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_with_string_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100]}) + + topic.publish( + Message='no match', + MessageAttributes={'price': {'DataType': 'String', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'customer_interests': ['basketball', 'baseball']}) + + topic.publish( + Message='match', + MessageAttributes={'customer_interests': {'DataType': 'String.Array', + 'StringValue': json.dumps(['basketball', 'rugby'])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])}}]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'customer_interests': ['baseball']}) + + topic.publish( + Message='no_match', + MessageAttributes={'customer_interests': {'DataType': 'String.Array', + 'StringValue': json.dumps(['basketball', 'rugby'])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100, 500]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': json.dumps([100, 50])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'String.Array', 'Value': json.dumps([100, 50])}}]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_float_accuracy_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100.123456789, 500]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': json.dumps([100.1234561, 50])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'String.Array', 'Value': json.dumps([100.1234561, 50])}}]) + + +@mock_sqs +@mock_sns +# this is the correct behavior from SNS +def test_filtering_string_array_with_number_no_array_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100, 500]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'String.Array', 'Value': '100'}}]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [500]}) + + topic.publish( + Message='no_match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': json.dumps([100, 50])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +# this is the correct behavior from SNS +def test_filtering_string_array_with_string_no_array_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100]}) + + topic.publish( + Message='no_match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': 'one hundread'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_exists_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': True}]}) + + topic.publish( + Message='match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'store': {'Type': 'String', 'Value': 'example_corp'}}]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_exists_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': True}]}) + + topic.publish( + Message='no match', + MessageAttributes={'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_not_exists_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': False}]}) + + topic.publish( + Message='match', + MessageAttributes={'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'event': {'Type': 'String', 'Value': 'order_cancelled'}}]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_not_exists_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': False}]}) + + topic.publish( + Message='no match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_all_AND_matching_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': True}], + 'event': ['order_cancelled'], + 'customer_interests': ['basketball', 'baseball'], + 'price': [100]}) + + topic.publish( + Message='match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}, + 'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}, + 'customer_interests': {'DataType': 'String.Array', + 'StringValue': json.dumps(['basketball', 'rugby'])}, + 'price': {'DataType': 'Number', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal( + ['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([{ + 'store': {'Type': 'String', 'Value': 'example_corp'}, + 'event': {'Type': 'String', 'Value': 'order_cancelled'}, + 'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])}, + 'price': {'Type': 'Number', 'Value': 100}}]) + + +@mock_sqs +@mock_sns +def test_filtering_all_AND_matching_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': True}], + 'event': ['order_cancelled'], + 'customer_interests': ['basketball', 'baseball'], + 'price': [100], + "encrypted": [False]}) + + topic.publish( + Message='no match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}, + 'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}, + 'customer_interests': {'DataType': 'String.Array', + 'StringValue': json.dumps(['basketball', 'rugby'])}, + 'price': {'DataType': 'Number', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index 2a56c8213..012cd6470 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -201,7 +201,9 @@ def test_creating_subscription_with_attributes(): "store": ["example_corp"], "event": ["order_cancelled"], "encrypted": [False], - "customer_interests": ["basketball", "baseball"] + "customer_interests": ["basketball", "baseball"], + "price": [100, 100.12], + "error": [None] }) conn.subscribe(TopicArn=topic_arn, @@ -294,7 +296,9 @@ def test_set_subscription_attributes(): "store": ["example_corp"], "event": ["order_cancelled"], "encrypted": [False], - "customer_interests": ["basketball", "baseball"] + "customer_interests": ["basketball", "baseball"], + "price": [100, 100.12], + "error": [None] }) conn.set_subscription_attributes( SubscriptionArn=subscription_arn, @@ -332,6 +336,77 @@ def test_set_subscription_attributes(): ) +@mock_sns +def test_subscribe_invalid_filter_policy(): + conn = boto3.client('sns', region_name = 'us-east-1') + conn.create_topic(Name = 'some-topic') + response = conn.list_topics() + topic_arn = response['Topics'][0]['TopicArn'] + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [str(i) for i in range(151)] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameter') + err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Filter policy is too complex') + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [['example_corp']] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameter') + err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null') + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [{'exists': None}] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameter') + err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: exists match pattern must be either true or false.') + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [{'error': True}] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameter') + err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Unrecognized match type error') + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [1000000001] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InternalFailure') + @mock_sns def test_check_not_opted_out(): conn = boto3.client('sns', region_name='us-east-1') From 778fc47c216825dfb2efe40ef026d7afab2b3d3e Mon Sep 17 00:00:00 2001 From: Tomoya Iwata Date: Mon, 26 Aug 2019 13:28:56 +0900 Subject: [PATCH 12/67] fix #2392 Add validation for shadow version,when update_thing_shadow() has called --- moto/iotdata/exceptions.py | 8 ++++++++ moto/iotdata/models.py | 3 +++ tests/test_iotdata/test_iotdata.py | 6 ++++++ 3 files changed, 17 insertions(+) diff --git a/moto/iotdata/exceptions.py b/moto/iotdata/exceptions.py index ddc6b37fd..f2c209eed 100644 --- a/moto/iotdata/exceptions.py +++ b/moto/iotdata/exceptions.py @@ -21,3 +21,11 @@ class InvalidRequestException(IoTDataPlaneClientError): super(InvalidRequestException, self).__init__( "InvalidRequestException", message ) + + +class ConflictException(IoTDataPlaneClientError): + def __init__(self, message): + self.code = 409 + super(ConflictException, self).__init__( + "ConflictException", message + ) diff --git a/moto/iotdata/models.py b/moto/iotdata/models.py index ad4caa89e..fec066f07 100644 --- a/moto/iotdata/models.py +++ b/moto/iotdata/models.py @@ -6,6 +6,7 @@ import jsondiff from moto.core import BaseBackend, BaseModel from moto.iot import iot_backends from .exceptions import ( + ConflictException, ResourceNotFoundException, InvalidRequestException ) @@ -161,6 +162,8 @@ class IoTDataPlaneBackend(BaseBackend): if any(_ for _ in payload['state'].keys() if _ not in ['desired', 'reported']): raise InvalidRequestException('State contains an invalid node') + if 'version' in payload and thing.thing_shadow.version != payload['version']: + raise ConflictException('Version conflict') new_shadow = FakeShadow.create_from_previous_version(thing.thing_shadow, payload) thing.thing_shadow = new_shadow return thing.thing_shadow diff --git a/tests/test_iotdata/test_iotdata.py b/tests/test_iotdata/test_iotdata.py index 09c1ada4c..1cedcaa72 100644 --- a/tests/test_iotdata/test_iotdata.py +++ b/tests/test_iotdata/test_iotdata.py @@ -86,6 +86,12 @@ def test_update(): payload.should.have.key('version').which.should.equal(2) payload.should.have.key('timestamp') + raw_payload = b'{"state": {"desired": {"led": "on"}}, "version": 1}' + with assert_raises(ClientError) as ex: + client.update_thing_shadow(thingName=name, payload=raw_payload) + ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(409) + ex.exception.response['Error']['Message'].should.equal('Version conflict') + @mock_iotdata def test_publish(): From b3a5e0fe3b5d2b1d0b6c643afb4e5b1888146bad Mon Sep 17 00:00:00 2001 From: dezkoat Date: Mon, 26 Aug 2019 17:11:08 +0700 Subject: [PATCH 13/67] Use long in creationTime and lastIngestionTime for LogStream and LogGroup model --- moto/logs/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/moto/logs/models.py b/moto/logs/models.py index 2b8dcfeb4..2fc4b0d8b 100644 --- a/moto/logs/models.py +++ b/moto/logs/models.py @@ -41,7 +41,7 @@ class LogStream: self.region = region self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format( region=region, id=self.__class__._log_ids, log_group=log_group, log_stream=name) - self.creationTime = unix_time_millis() + self.creationTime = int(unix_time_millis()) self.firstEventTimestamp = None self.lastEventTimestamp = None self.lastIngestionTime = None @@ -80,7 +80,7 @@ class LogStream: def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token): # TODO: ensure sequence_token # TODO: to be thread safe this would need a lock - self.lastIngestionTime = unix_time_millis() + self.lastIngestionTime = int(unix_time_millis()) # TODO: make this match AWS if possible self.storedBytes += sum([len(log_event["message"]) for log_event in log_events]) self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events] @@ -146,7 +146,7 @@ class LogGroup: self.region = region self.arn = "arn:aws:logs:{region}:1:log-group:{log_group}".format( region=region, log_group=name) - self.creationTime = unix_time_millis() + self.creationTime = int(unix_time_millis()) self.tags = tags self.streams = dict() # {name: LogStream} self.retentionInDays = None # AWS defaults to Never Expire for log group retention From c1618943243c25d2bdeea22923ae718d98223633 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 26 Aug 2019 22:38:49 -0700 Subject: [PATCH 14/67] add parameterize to dev requirements to simplify things --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index f87ab3db6..1dd8ef1f8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,6 +10,7 @@ boto>=2.45.0 boto3>=1.4.4 botocore>=1.12.13 six>=1.9 +parameterized>=0.7.0 prompt-toolkit==1.0.14 click==6.7 inflection==0.3.1 From 7eeead8a37768956db50a81ef7073f4b9fde1c18 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 26 Aug 2019 23:24:31 -0700 Subject: [PATCH 15/67] add encrypt/decrypt utility functions with appropriate exceptions and tests --- moto/kms/exceptions.py | 20 +++++ moto/kms/models.py | 3 +- moto/kms/utils.py | 135 ++++++++++++++++++++++++++++++++ tests/test_kms/test_utils.py | 146 +++++++++++++++++++++++++++++++++++ 4 files changed, 303 insertions(+), 1 deletion(-) create mode 100644 tests/test_kms/test_utils.py diff --git a/moto/kms/exceptions.py b/moto/kms/exceptions.py index 70edd3dcd..c9094e8f8 100644 --- a/moto/kms/exceptions.py +++ b/moto/kms/exceptions.py @@ -34,3 +34,23 @@ class NotAuthorizedException(JsonRESTError): "NotAuthorizedException", None) self.description = '{"__type":"NotAuthorizedException"}' + + +class AccessDeniedException(JsonRESTError): + code = 400 + + def __init__(self, message): + super(AccessDeniedException, self).__init__( + "AccessDeniedException", message) + + self.description = '{"__type":"AccessDeniedException"}' + + +class InvalidCiphertextException(JsonRESTError): + code = 400 + + def __init__(self): + super(InvalidCiphertextException, self).__init__( + "InvalidCiphertextException", None) + + self.description = '{"__type":"InvalidCiphertextException"}' diff --git a/moto/kms/models.py b/moto/kms/models.py index 577840b06..d1b61d86c 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -4,7 +4,7 @@ import os import boto.kms from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_without_milliseconds -from .utils import generate_key_id +from .utils import generate_key_id, generate_master_key from collections import defaultdict from datetime import datetime, timedelta @@ -23,6 +23,7 @@ class Key(BaseModel): self.key_rotation_status = False self.deletion_date = None self.tags = tags or {} + self.key_material = generate_master_key() @property def physical_resource_id(self): diff --git a/moto/kms/utils.py b/moto/kms/utils.py index fad38150f..96d3f25cc 100644 --- a/moto/kms/utils.py +++ b/moto/kms/utils.py @@ -1,7 +1,142 @@ from __future__ import unicode_literals +from collections import namedtuple +import io +import os +import struct import uuid +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes + +from .exceptions import InvalidCiphertextException, AccessDeniedException, NotFoundException + + +MASTER_KEY_LEN = 32 +KEY_ID_LEN = 36 +IV_LEN = 12 +TAG_LEN = 16 +HEADER_LEN = KEY_ID_LEN + IV_LEN + TAG_LEN +# NOTE: This is just a simple binary format. It is not what KMS actually does. +CIPHERTEXT_HEADER_FORMAT = ">{key_id_len}s{iv_len}s{tag_len}s".format( + key_id_len=KEY_ID_LEN, iv_len=IV_LEN, tag_len=TAG_LEN +) +Ciphertext = namedtuple("Ciphertext", ("key_id", "iv", "ciphertext", "tag")) + def generate_key_id(): return str(uuid.uuid4()) + + +def generate_data_key(number_of_bytes): + """Generate a data key.""" + return os.urandom(number_of_bytes) + + +def generate_master_key(): + """Generate a master key.""" + return generate_data_key(MASTER_KEY_LEN) + + +def _serialize_ciphertext_blob(ciphertext): + """Serialize Ciphertext object into a ciphertext blob. + + NOTE: This is just a simple binary format. It is not what KMS actually does. + """ + header = struct.pack(CIPHERTEXT_HEADER_FORMAT, ciphertext.key_id.encode("utf-8"), ciphertext.iv, ciphertext.tag) + return header + ciphertext.ciphertext + + +def _deserialize_ciphertext_blob(ciphertext_blob): + """Deserialize ciphertext blob into a Ciphertext object. + + NOTE: This is just a simple binary format. It is not what KMS actually does. + """ + header = ciphertext_blob[:HEADER_LEN] + ciphertext = ciphertext_blob[HEADER_LEN:] + key_id, iv, tag = struct.unpack(CIPHERTEXT_HEADER_FORMAT, header) + return Ciphertext(key_id=key_id.decode("utf-8"), iv=iv, ciphertext=ciphertext, tag=tag) + + +def _serialize_encryption_context(encryption_context): + """Serialize encryption context for use a AAD. + + NOTE: This is not necessarily what KMS does, but it retains the same properties. + """ + aad = io.BytesIO() + for key, value in sorted(encryption_context.items(), key=lambda x: x[0]): + aad.write(key.encode("utf-8")) + aad.write(value.encode("utf-8")) + return aad.getvalue() + + +def encrypt(master_keys, key_id, plaintext, encryption_context): + """Encrypt data using a master key material. + + NOTE: This is not necessarily what KMS does, but it retains the same properties. + + NOTE: This function is NOT compatible with KMS APIs. + :param dict master_keys: Mapping of a KmsBackend's known master keys + :param str key_id: Key ID of moto master key + :param bytes plaintext: Plaintext data to encrypt + :param dict[str, str] encryption_context: KMS-style encryption context + :returns: Moto-structured ciphertext blob encrypted under a moto master key in master_keys + :rtype: bytes + """ + try: + key = master_keys[key_id] + except KeyError: + is_alias = key_id.startswith("alias/") or ":alias/" in key_id + raise NotFoundException( + "{id_type} {key_id} is not found.".format(id_type="Alias" if is_alias else "keyId", key_id=key_id) + ) + + iv = os.urandom(IV_LEN) + aad = _serialize_encryption_context(encryption_context=encryption_context) + + encryptor = Cipher(algorithms.AES(key.key_material), modes.GCM(iv), backend=default_backend()).encryptor() + encryptor.authenticate_additional_data(aad) + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + return _serialize_ciphertext_blob( + ciphertext=Ciphertext(key_id=key_id, iv=iv, ciphertext=ciphertext, tag=encryptor.tag) + ) + + +def decrypt(master_keys, ciphertext_blob, encryption_context): + """Decrypt a ciphertext blob using a master key material. + + NOTE: This is not necessarily what KMS does, but it retains the same properties. + + NOTE: This function is NOT compatible with KMS APIs. + + :param dict master_keys: Mapping of a KmsBackend's known master keys + :param bytes ciphertext_blob: moto-structured ciphertext blob encrypted under a moto master key in master_keys + :param dict[str, str] encryption_context: KMS-style encryption context + :returns: plaintext bytes and moto key ID + :rtype: bytes and str + """ + try: + ciphertext = _deserialize_ciphertext_blob(ciphertext_blob=ciphertext_blob) + except Exception: + raise InvalidCiphertextException() + + aad = _serialize_encryption_context(encryption_context=encryption_context) + + try: + key = master_keys[ciphertext.key_id] + except KeyError: + raise AccessDeniedException( + "The ciphertext refers to a customer master key that does not exist, " + "does not exist in this region, or you are not allowed to access." + ) + + try: + decryptor = Cipher( + algorithms.AES(key.key_material), modes.GCM(ciphertext.iv, ciphertext.tag), backend=default_backend() + ).decryptor() + decryptor.authenticate_additional_data(aad) + plaintext = decryptor.update(ciphertext.ciphertext) + decryptor.finalize() + except Exception: + raise InvalidCiphertextException() + + return plaintext, ciphertext.key_id diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py new file mode 100644 index 000000000..9d540d50a --- /dev/null +++ b/tests/test_kms/test_utils.py @@ -0,0 +1,146 @@ +from __future__ import unicode_literals + +from nose.tools import assert_raises +from parameterized import parameterized + +from moto.kms.exceptions import AccessDeniedException, InvalidCiphertextException, NotFoundException +from moto.kms.models import Key +from moto.kms.utils import ( + generate_data_key, + generate_master_key, + _serialize_ciphertext_blob, + _deserialize_ciphertext_blob, + _serialize_encryption_context, + encrypt, + decrypt, + Ciphertext, +) + +ENCRYPTION_CONTEXT_VECTORS = ( + ({"this": "is", "an": "encryption", "context": "example"}, b"an" b"encryption" b"context" b"example" b"this" b"is"), + ({"a_this": "one", "b_is": "actually", "c_in": "order"}, b"a_this" b"one" b"b_is" b"actually" b"c_in" b"order"), +) +CIPHERTEXT_BLOB_VECTORS = ( + ( + Ciphertext( + key_id="d25652e4-d2d2-49f7-929a-671ccda580c6", + iv=b"123456789012", + ciphertext=b"some ciphertext", + tag=b"1234567890123456", + ), + b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext", + ), + ( + Ciphertext( + key_id="d25652e4-d2d2-49f7-929a-671ccda580c6", + iv=b"123456789012", + ciphertext=b"some ciphertext that is much longer now", + tag=b"1234567890123456", + ), + b"d25652e4-d2d2-49f7-929a-671ccda580c6" + b"123456789012" + b"1234567890123456" + b"some ciphertext that is much longer now", + ), +) + + +@parameterized(ENCRYPTION_CONTEXT_VECTORS) +def test_serialize_encryption_context(raw, serialized): + test = _serialize_encryption_context(raw) + test.should.equal(serialized) + + +@parameterized(CIPHERTEXT_BLOB_VECTORS) +def test_cycle_ciphertext_blob(raw, _serialized): + test_serialized = _serialize_ciphertext_blob(raw) + test_deserialized = _deserialize_ciphertext_blob(test_serialized) + test_deserialized.should.equal(raw) + + +@parameterized(CIPHERTEXT_BLOB_VECTORS) +def test_serialize_ciphertext_blob(raw, serialized): + test = _serialize_ciphertext_blob(raw) + test.should.equal(serialized) + + +@parameterized(CIPHERTEXT_BLOB_VECTORS) +def test_deserialize_ciphertext_blob(raw, serialized): + test = _deserialize_ciphertext_blob(serialized) + test.should.equal(raw) + + +@parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS)) +def test_encrypt_decrypt_cycle(encryption_context): + plaintext = b"some secret plaintext" + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + + ciphertext_blob = encrypt( + master_keys=master_key_map, key_id=master_key.id, plaintext=plaintext, encryption_context=encryption_context + ) + ciphertext_blob.should_not.equal(plaintext) + + decrypted, decrypting_key_id = decrypt( + master_keys=master_key_map, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context + ) + decrypted.should.equal(plaintext) + decrypting_key_id.should.equal(master_key.id) + + +def test_encrypt_unknown_key_id(): + assert_raises( + NotFoundException, encrypt, master_keys={}, key_id="anything", plaintext=b"secrets", encryption_context={} + ) + + +def test_decrypt_invalid_ciphertext_format(): + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + + assert_raises( + InvalidCiphertextException, decrypt, master_keys=master_key_map, ciphertext_blob=b"", encryption_context={} + ) + + +def test_decrypt_unknwown_key_id(): + ciphertext_blob = b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext" + + assert_raises( + AccessDeniedException, decrypt, master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={} + ) + + +def test_decrypt_invalid_ciphertext(): + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + ciphertext_blob = master_key.id.encode("utf-8") + b"123456789012" b"1234567890123456" b"some ciphertext" + + assert_raises( + InvalidCiphertextException, + decrypt, + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context={}, + ) + + +def test_decrypt_invalid_encryption_context(): + plaintext = b"some secret plaintext" + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + + ciphertext_blob = encrypt( + master_keys=master_key_map, + key_id=master_key.id, + plaintext=plaintext, + encryption_context={"some": "encryption", "context": "here"}, + ) + + assert_raises( + InvalidCiphertextException, + decrypt, + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context={}, + ) From 3fe8afaa605f79c769e9b1204b3faee11d94d761 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Mon, 26 Aug 2019 23:29:30 -0700 Subject: [PATCH 16/67] add tests for generate_data_key and generate_master_key --- tests/test_kms/test_utils.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py index 9d540d50a..466c72ea9 100644 --- a/tests/test_kms/test_utils.py +++ b/tests/test_kms/test_utils.py @@ -6,11 +6,12 @@ from parameterized import parameterized from moto.kms.exceptions import AccessDeniedException, InvalidCiphertextException, NotFoundException from moto.kms.models import Key from moto.kms.utils import ( + _deserialize_ciphertext_blob, + _serialize_ciphertext_blob, + _serialize_encryption_context, generate_data_key, generate_master_key, - _serialize_ciphertext_blob, - _deserialize_ciphertext_blob, - _serialize_encryption_context, + MASTER_KEY_LEN, encrypt, decrypt, Ciphertext, @@ -45,6 +46,20 @@ CIPHERTEXT_BLOB_VECTORS = ( ) +def test_generate_data_key(): + test = generate_data_key(123) + + test.should.be.a(bytes) + len(test).should.equal(123) + + +def test_generate_master_key(): + test = generate_master_key() + + test.should.be.a(bytes) + len(test).should.equal(MASTER_KEY_LEN) + + @parameterized(ENCRYPTION_CONTEXT_VECTORS) def test_serialize_encryption_context(raw, serialized): test = _serialize_encryption_context(raw) From 98581b9196768ad8d5eaa1e02ca744c0c3b2098e Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 27 Aug 2019 13:42:36 -0700 Subject: [PATCH 17/67] add proper KMS encrypt, decrypt, and generate_data_key functionality and tests --- moto/kms/models.py | 91 ++-- moto/kms/responses.py | 85 ++-- tests/test_kms/test_kms.py | 894 ++++++++++++++++------------------- tests/test_kms/test_utils.py | 42 +- 4 files changed, 552 insertions(+), 560 deletions(-) diff --git a/moto/kms/models.py b/moto/kms/models.py index d1b61d86c..9fb28bb5f 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -4,13 +4,12 @@ import os import boto.kms from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_without_milliseconds -from .utils import generate_key_id, generate_master_key +from .utils import decrypt, encrypt, generate_key_id, generate_master_key from collections import defaultdict from datetime import datetime, timedelta class Key(BaseModel): - def __init__(self, policy, key_usage, description, tags, region): self.id = generate_key_id() self.policy = policy @@ -46,8 +45,8 @@ class Key(BaseModel): "KeyState": self.key_state, } } - if self.key_state == 'PendingDeletion': - key_dict['KeyMetadata']['DeletionDate'] = iso_8601_datetime_without_milliseconds(self.deletion_date) + if self.key_state == "PendingDeletion": + key_dict["KeyMetadata"]["DeletionDate"] = iso_8601_datetime_without_milliseconds(self.deletion_date) return key_dict def delete(self, region_name): @@ -56,28 +55,28 @@ class Key(BaseModel): @classmethod def create_from_cloudformation_json(self, resource_name, cloudformation_json, region_name): kms_backend = kms_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] key = kms_backend.create_key( - policy=properties['KeyPolicy'], - key_usage='ENCRYPT_DECRYPT', - description=properties['Description'], - tags=properties.get('Tags'), + policy=properties["KeyPolicy"], + key_usage="ENCRYPT_DECRYPT", + description=properties["Description"], + tags=properties.get("Tags"), region=region_name, ) - key.key_rotation_status = properties['EnableKeyRotation'] - key.enabled = properties['Enabled'] + key.key_rotation_status = properties["EnableKeyRotation"] + key.enabled = properties["Enabled"] return key def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() class KmsBackend(BaseBackend): - def __init__(self): self.keys = {} self.key_to_aliases = defaultdict(set) @@ -110,8 +109,8 @@ class KmsBackend(BaseBackend): # allow the different methods (alias, ARN :key/, keyId, ARN alias) to # describe key not just KeyId key_id = self.get_key_id(key_id) - if r'alias/' in str(key_id).lower(): - key_id = self.get_key_id_from_alias(key_id.split('alias/')[1]) + if r"alias/" in str(key_id).lower(): + key_id = self.get_key_id_from_alias(key_id.split("alias/")[1]) return self.keys[self.get_key_id(key_id)] def list_keys(self): @@ -119,7 +118,26 @@ class KmsBackend(BaseBackend): def get_key_id(self, key_id): # Allow use of ARN as well as pure KeyId - return str(key_id).split(r':key/')[1] if r':key/' in str(key_id).lower() else key_id + return str(key_id).split(r":key/")[1] if r":key/" in str(key_id).lower() else key_id + + def get_alias_name(self, alias_name): + # Allow use of ARN as well as alias name + return str(alias_name).split(r":alias/")[1] if r":alias/" in str(alias_name).lower() else alias_name + + def any_id_to_key_id(self, key_id): + """Go from any valid key ID to the raw key ID. + + Acceptable inputs: + - raw key ID + - key ARN + - alias name + - alias ARN + """ + key_id = self.get_alias_name(key_id) + key_id = self.get_key_id(key_id) + if key_id.startswith("alias/"): + key_id = self.get_key_id_from_alias(key_id) + return key_id def alias_exists(self, alias_name): for aliases in self.key_to_aliases.values(): @@ -163,37 +181,56 @@ class KmsBackend(BaseBackend): def disable_key(self, key_id): self.keys[key_id].enabled = False - self.keys[key_id].key_state = 'Disabled' + self.keys[key_id].key_state = "Disabled" def enable_key(self, key_id): self.keys[key_id].enabled = True - self.keys[key_id].key_state = 'Enabled' + self.keys[key_id].key_state = "Enabled" def cancel_key_deletion(self, key_id): - self.keys[key_id].key_state = 'Disabled' + self.keys[key_id].key_state = "Disabled" self.keys[key_id].deletion_date = None def schedule_key_deletion(self, key_id, pending_window_in_days): if 7 <= pending_window_in_days <= 30: self.keys[key_id].enabled = False - self.keys[key_id].key_state = 'PendingDeletion' + self.keys[key_id].key_state = "PendingDeletion" self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days) return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date) + def encrypt(self, key_id, plaintext, encryption_context): + key_id = self.any_id_to_key_id(key_id) + + ciphertext_blob = encrypt( + master_keys=self.keys, key_id=key_id, plaintext=plaintext, encryption_context=encryption_context + ) + arn = self.keys[key_id].arn + return ciphertext_blob, arn + + def decrypt(self, ciphertext_blob, encryption_context): + plaintext, key_id = decrypt( + master_keys=self.keys, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context + ) + arn = self.keys[key_id].arn + return plaintext, arn + def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens): - key = self.keys[self.get_key_id(key_id)] + key_id = self.any_id_to_key_id(key_id) if key_spec: - if key_spec == 'AES_128': - bytes = 16 + # Note: Actual validation of key_spec is done in kms.responses + if key_spec == "AES_128": + plaintext_len = 16 else: - bytes = 32 + plaintext_len = 32 else: - bytes = number_of_bytes + plaintext_len = number_of_bytes - plaintext = os.urandom(bytes) + plaintext = os.urandom(plaintext_len) - return plaintext, key.arn + ciphertext_blob, arn = self.encrypt(key_id=key_id, plaintext=plaintext, encryption_context=encryption_context) + + return plaintext, ciphertext_blob, arn kms_backends = {} diff --git a/moto/kms/responses.py b/moto/kms/responses.py index 53012b7f8..0b8684019 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -8,6 +8,7 @@ import six from moto.core.responses import BaseResponse from .models import kms_backends from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException +from .utils import decrypt, encrypt reserved_aliases = [ 'alias/aws/ebs', @@ -21,7 +22,13 @@ class KmsResponse(BaseResponse): @property def parameters(self): - return json.loads(self.body) + params = json.loads(self.body) + + for key in ("Plaintext", "CiphertextBlob"): + if key in params: + params[key] = base64.b64decode(params[key].encode("utf-8")) + + return params @property def kms_backend(self): @@ -224,24 +231,34 @@ class KmsResponse(BaseResponse): return json.dumps({'Truncated': False, 'PolicyNames': ['default']}) def encrypt(self): - """ - We perform no encryption, we just encode the value as base64 and then - decode it in decrypt(). - """ - value = self.parameters.get("Plaintext") - if isinstance(value, six.text_type): - value = value.encode('utf-8') - return json.dumps({"CiphertextBlob": base64.b64encode(value).decode("utf-8"), 'KeyId': 'key_id'}) + key_id = self.parameters.get("KeyId") + encryption_context = self.parameters.get('EncryptionContext', {}) + plaintext = self.parameters.get("Plaintext") + + if isinstance(plaintext, six.text_type): + plaintext = plaintext.encode('utf-8') + + ciphertext_blob, arn = self.kms_backend.encrypt( + key_id=key_id, + plaintext=plaintext, + encryption_context=encryption_context, + ) + ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8") + + return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn}) def decrypt(self): - # TODO refuse decode if EncryptionContext is not the same as when it was encrypted / generated + ciphertext_blob = self.parameters.get("CiphertextBlob") + encryption_context = self.parameters.get('EncryptionContext', {}) - value = self.parameters.get("CiphertextBlob") - try: - return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8"), 'KeyId': 'key_id'}) - except UnicodeDecodeError: - # Generate data key will produce random bytes which when decrypted is still returned as base64 - return json.dumps({"Plaintext": value}) + plaintext, arn = self.kms_backend.decrypt( + ciphertext_blob=ciphertext_blob, + encryption_context=encryption_context, + ) + + plaintext_response = base64.b64encode(plaintext).decode("utf-8") + + return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn}) def disable_key(self): key_id = self.parameters.get('KeyId') @@ -291,7 +308,7 @@ class KmsResponse(BaseResponse): def generate_data_key(self): key_id = self.parameters.get('KeyId') - encryption_context = self.parameters.get('EncryptionContext') + encryption_context = self.parameters.get('EncryptionContext', {}) number_of_bytes = self.parameters.get('NumberOfBytes') key_spec = self.parameters.get('KeySpec') grant_tokens = self.parameters.get('GrantTokens') @@ -306,27 +323,39 @@ class KmsResponse(BaseResponse): raise NotFoundException('Invalid keyId') if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 0): - raise ValidationException("1 validation error detected: Value '2048' at 'numberOfBytes' failed " - "to satisfy constraint: Member must have value less than or " - "equal to 1024") + raise ValidationException(( + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes) + ) if key_spec and key_spec not in ('AES_256', 'AES_128'): - raise ValidationException("1 validation error detected: Value 'AES_257' at 'keySpec' failed " - "to satisfy constraint: Member must satisfy enum value set: " - "[AES_256, AES_128]") + raise ValidationException(( + "1 validation error detected: Value '{key_spec}' at 'keySpec' failed " + "to satisfy constraint: Member must satisfy enum value set: " + "[AES_256, AES_128]" + ).format(key_spec=key_spec) + ) if not key_spec and not number_of_bytes: raise ValidationException("Please specify either number of bytes or key spec.") if key_spec and number_of_bytes: raise ValidationException("Please specify either number of bytes or key spec.") - plaintext, key_arn = self.kms_backend.generate_data_key(key_id, encryption_context, - number_of_bytes, key_spec, grant_tokens) + plaintext, ciphertext_blob, key_arn = self.kms_backend.generate_data_key( + key_id=key_id, + encryption_context=encryption_context, + number_of_bytes=number_of_bytes, + key_spec=key_spec, + grant_tokens=grant_tokens + ) - plaintext = base64.b64encode(plaintext).decode() + plaintext_response = base64.b64encode(plaintext).decode("utf-8") + ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8") return json.dumps({ - 'CiphertextBlob': plaintext, - 'Plaintext': plaintext, + 'CiphertextBlob': ciphertext_blob_response, + 'Plaintext': plaintext_response, 'KeyId': key_arn # not alias }) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index f189fbe41..4e1f39540 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -1,333 +1,368 @@ from __future__ import unicode_literals -import os, re -import boto3 -import boto.kms -import botocore.exceptions -from boto.exception import JSONResponseError -from boto.kms.exceptions import AlreadyExistsException, NotFoundException - -from moto.kms.exceptions import NotFoundException as MotoNotFoundException -import sure # noqa -from moto import mock_kms, mock_kms_deprecated -from nose.tools import assert_raises -from freezegun import freeze_time from datetime import date from datetime import datetime from dateutil.tz import tzutc +import base64 +import binascii +import os +import re + +import boto3 +import boto.kms +import botocore.exceptions +import sure # noqa +from boto.exception import JSONResponseError +from boto.kms.exceptions import AlreadyExistsException, NotFoundException +from freezegun import freeze_time +from nose.tools import assert_raises +from parameterized import parameterized + +from moto.kms.exceptions import NotFoundException as MotoNotFoundException +from moto import mock_kms, mock_kms_deprecated + +PLAINTEXT_VECTORS = ( + (b"some encodeable plaintext",), + (b"some unencodeable plaintext \xec\x8a\xcf\xb6r\xe9\xb5\xeb\xff\xa23\x16",), +) @mock_kms def test_create_key(): - conn = boto3.client('kms', region_name='us-east-1') + conn = boto3.client("kms", region_name="us-east-1") with freeze_time("2015-01-01 00:00:00"): - key = conn.create_key(Policy="my policy", - Description="my key", - KeyUsage='ENCRYPT_DECRYPT', - Tags=[ - { - 'TagKey': 'project', - 'TagValue': 'moto', - }, - ]) + key = conn.create_key( + Policy="my policy", + Description="my key", + KeyUsage="ENCRYPT_DECRYPT", + Tags=[{"TagKey": "project", "TagValue": "moto"}], + ) - key['KeyMetadata']['Description'].should.equal("my key") - key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") - key['KeyMetadata']['Enabled'].should.equal(True) - key['KeyMetadata']['CreationDate'].should.be.a(date) + key["KeyMetadata"]["Description"].should.equal("my key") + key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") + key["KeyMetadata"]["Enabled"].should.equal(True) + key["KeyMetadata"]["CreationDate"].should.be.a(date) @mock_kms_deprecated def test_describe_key(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] key = conn.describe_key(key_id) - key['KeyMetadata']['Description'].should.equal("my key") - key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") + key["KeyMetadata"]["Description"].should.equal("my key") + key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") @mock_kms_deprecated def test_describe_key_via_alias(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) - alias_key = conn.describe_key('alias/my-key-alias') - alias_key['KeyMetadata']['Description'].should.equal("my key") - alias_key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") - alias_key['KeyMetadata']['Arn'].should.equal(key['KeyMetadata']['Arn']) + alias_key = conn.describe_key("alias/my-key-alias") + alias_key["KeyMetadata"]["Description"].should.equal("my key") + alias_key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") + alias_key["KeyMetadata"]["Arn"].should.equal(key["KeyMetadata"]["Arn"]) @mock_kms_deprecated def test_describe_key_via_alias_not_found(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) - conn.describe_key.when.called_with( - 'alias/not-found-alias').should.throw(JSONResponseError) + conn.describe_key.when.called_with("alias/not-found-alias").should.throw(JSONResponseError) @mock_kms_deprecated def test_describe_key_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - arn = key['KeyMetadata']['Arn'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + arn = key["KeyMetadata"]["Arn"] the_key = conn.describe_key(arn) - the_key['KeyMetadata']['Description'].should.equal("my key") - the_key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") - the_key['KeyMetadata']['KeyId'].should.equal(key['KeyMetadata']['KeyId']) + the_key["KeyMetadata"]["Description"].should.equal("my key") + the_key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") + the_key["KeyMetadata"]["KeyId"].should.equal(key["KeyMetadata"]["KeyId"]) @mock_kms_deprecated def test_describe_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.describe_key.when.called_with( - "not-a-key").should.throw(JSONResponseError) + conn.describe_key.when.called_with("not-a-key").should.throw(JSONResponseError) @mock_kms_deprecated def test_list_keys(): conn = boto.kms.connect_to_region("us-west-2") - conn.create_key(policy="my policy", description="my key1", - key_usage='ENCRYPT_DECRYPT') - conn.create_key(policy="my policy", description="my key2", - key_usage='ENCRYPT_DECRYPT') + conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + conn.create_key(policy="my policy", description="my key2", key_usage="ENCRYPT_DECRYPT") keys = conn.list_keys() - keys['Keys'].should.have.length_of(2) + keys["Keys"].should.have.length_of(2) @mock_kms_deprecated def test_enable_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] conn.enable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(True) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(True) @mock_kms_deprecated def test_enable_key_rotation_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['Arn'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["Arn"] conn.enable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(True) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(True) @mock_kms_deprecated def test_enable_key_rotation_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.enable_key_rotation.when.called_with( - "not-a-key").should.throw(NotFoundException) + conn.enable_key_rotation.when.called_with("not-a-key").should.throw(NotFoundException) @mock_kms_deprecated def test_enable_key_rotation_with_alias_name_should_fail(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) - alias_key = conn.describe_key('alias/my-key-alias') - alias_key['KeyMetadata']['Arn'].should.equal(key['KeyMetadata']['Arn']) + alias_key = conn.describe_key("alias/my-key-alias") + alias_key["KeyMetadata"]["Arn"].should.equal(key["KeyMetadata"]["Arn"]) - conn.enable_key_rotation.when.called_with( - 'alias/my-alias').should.throw(NotFoundException) + conn.enable_key_rotation.when.called_with("alias/my-alias").should.throw(NotFoundException) @mock_kms_deprecated def test_disable_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] conn.enable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(True) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(True) conn.disable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(False) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @mock_kms_deprecated -def test_encrypt(): - """ - test_encrypt - Using base64 encoding to merely test that the endpoint was called - """ +def test_generate_data_key(): conn = boto.kms.connect_to_region("us-west-2") - response = conn.encrypt('key_id', 'encryptme'.encode('utf-8')) - response['CiphertextBlob'].should.equal(b'ZW5jcnlwdG1l') - response['KeyId'].should.equal('key_id') + + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + response = conn.generate_data_key(key_id=key_id, number_of_bytes=32) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["CiphertextBlob"], validate=True) + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["Plaintext"], validate=True) + + response["KeyId"].should.equal(key_arn) +@mock_kms +def test_boto3_generate_data_key(): + kms = boto3.client("kms", region_name="us-west-2") + + key = kms.create_key() + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + response = kms.generate_data_key(KeyId=key_id, NumberOfBytes=32) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["CiphertextBlob"], validate=True) + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["Plaintext"], validate=True) + + response["KeyId"].should.equal(key_arn) + + +@parameterized(PLAINTEXT_VECTORS) @mock_kms_deprecated -def test_decrypt(): - conn = boto.kms.connect_to_region('us-west-2') - response = conn.decrypt('ZW5jcnlwdG1l'.encode('utf-8')) - response['Plaintext'].should.equal(b'encryptme') - response['KeyId'].should.equal('key_id') +def test_encrypt(plaintext): + conn = boto.kms.connect_to_region("us-west-2") + + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + response = conn.encrypt(key_id, plaintext) + response["CiphertextBlob"].should_not.equal(plaintext) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["CiphertextBlob"], validate=True) + + response["KeyId"].should.equal(key_arn) + + +@parameterized(PLAINTEXT_VECTORS) +@mock_kms_deprecated +def test_decrypt(plaintext): + conn = boto.kms.connect_to_region("us-west-2") + + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + encrypt_response = conn.encrypt(key_id, plaintext) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(encrypt_response["CiphertextBlob"], validate=True) + + decrypt_response = conn.decrypt(encrypt_response["CiphertextBlob"]) + + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(decrypt_response["Plaintext"], validate=True) + + decrypt_response["Plaintext"].should.equal(plaintext) + decrypt_response["KeyId"].should.equal(key_arn) @mock_kms_deprecated def test_disable_key_rotation_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.disable_key_rotation.when.called_with( - "not-a-key").should.throw(NotFoundException) + conn.disable_key_rotation.when.called_with("not-a-key").should.throw(NotFoundException) @mock_kms_deprecated def test_get_key_rotation_status_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.get_key_rotation_status.when.called_with( - "not-a-key").should.throw(NotFoundException) + conn.get_key_rotation_status.when.called_with("not-a-key").should.throw(NotFoundException) @mock_kms_deprecated def test_get_key_rotation_status(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(False) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @mock_kms_deprecated def test_create_key_defaults_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(False) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @mock_kms_deprecated def test_get_key_policy(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] - policy = conn.get_key_policy(key_id, 'default') - policy['Policy'].should.equal('my policy') + policy = conn.get_key_policy(key_id, "default") + policy["Policy"].should.equal("my policy") @mock_kms_deprecated def test_get_key_policy_via_arn(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - policy = conn.get_key_policy(key['KeyMetadata']['Arn'], 'default') + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + policy = conn.get_key_policy(key["KeyMetadata"]["Arn"], "default") - policy['Policy'].should.equal('my policy') + policy["Policy"].should.equal("my policy") @mock_kms_deprecated def test_put_key_policy(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] - conn.put_key_policy(key_id, 'default', 'new policy') - policy = conn.get_key_policy(key_id, 'default') - policy['Policy'].should.equal('new policy') + conn.put_key_policy(key_id, "default", "new policy") + policy = conn.get_key_policy(key_id, "default") + policy["Policy"].should.equal("new policy") @mock_kms_deprecated def test_put_key_policy_via_arn(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['Arn'] + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["Arn"] - conn.put_key_policy(key_id, 'default', 'new policy') - policy = conn.get_key_policy(key_id, 'default') - policy['Policy'].should.equal('new policy') + conn.put_key_policy(key_id, "default", "new policy") + policy = conn.get_key_policy(key_id, "default") + policy["Policy"].should.equal("new policy") @mock_kms_deprecated def test_put_key_policy_via_alias_should_not_update(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) - conn.put_key_policy.when.called_with( - 'alias/my-key-alias', 'default', 'new policy').should.throw(NotFoundException) + conn.put_key_policy.when.called_with("alias/my-key-alias", "default", "new policy").should.throw(NotFoundException) - policy = conn.get_key_policy(key['KeyMetadata']['KeyId'], 'default') - policy['Policy'].should.equal('my policy') + policy = conn.get_key_policy(key["KeyMetadata"]["KeyId"], "default") + policy["Policy"].should.equal("my policy") @mock_kms_deprecated def test_put_key_policy(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - conn.put_key_policy(key['KeyMetadata']['Arn'], 'default', 'new policy') + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + conn.put_key_policy(key["KeyMetadata"]["Arn"], "default", "new policy") - policy = conn.get_key_policy(key['KeyMetadata']['KeyId'], 'default') - policy['Policy'].should.equal('new policy') + policy = conn.get_key_policy(key["KeyMetadata"]["KeyId"], "default") + policy["Policy"].should.equal("new policy") @mock_kms_deprecated def test_list_key_policies(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] policies = conn.list_key_policies(key_id) - policies['PolicyNames'].should.equal(['default']) + policies["PolicyNames"].should.equal(["default"]) @mock_kms_deprecated def test__create_alias__returns_none_if_correct(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - resp = kms.create_alias('alias/my-alias', key_id) + resp = kms.create_alias("alias/my-alias", key_id) resp.should.be.none @@ -336,14 +371,9 @@ def test__create_alias__returns_none_if_correct(): def test__create_alias__raises_if_reserved_alias(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - reserved_aliases = [ - 'alias/aws/ebs', - 'alias/aws/s3', - 'alias/aws/redshift', - 'alias/aws/rds', - ] + reserved_aliases = ["alias/aws/ebs", "alias/aws/s3", "alias/aws/redshift", "alias/aws/rds"] for alias_name in reserved_aliases: with assert_raises(JSONResponseError) as err: @@ -351,9 +381,9 @@ def test__create_alias__raises_if_reserved_alias(): ex = err.exception ex.error_message.should.be.none - ex.error_code.should.equal('NotAuthorizedException') - ex.body.should.equal({'__type': 'NotAuthorizedException'}) - ex.reason.should.equal('Bad Request') + ex.error_code.should.equal("NotAuthorizedException") + ex.body.should.equal({"__type": "NotAuthorizedException"}) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -361,38 +391,37 @@ def test__create_alias__raises_if_reserved_alias(): def test__create_alias__can_create_multiple_aliases_for_same_key_id(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - kms.create_alias('alias/my-alias3', key_id).should.be.none - kms.create_alias('alias/my-alias4', key_id).should.be.none - kms.create_alias('alias/my-alias5', key_id).should.be.none + kms.create_alias("alias/my-alias3", key_id).should.be.none + kms.create_alias("alias/my-alias4", key_id).should.be.none + kms.create_alias("alias/my-alias5", key_id).should.be.none @mock_kms_deprecated def test__create_alias__raises_if_wrong_prefix(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] with assert_raises(JSONResponseError) as err: - kms.create_alias('wrongprefix/my-alias', key_id) + kms.create_alias("wrongprefix/my-alias", key_id) ex = err.exception - ex.error_message.should.equal('Invalid identifier') - ex.error_code.should.equal('ValidationException') - ex.body.should.equal({'message': 'Invalid identifier', - '__type': 'ValidationException'}) - ex.reason.should.equal('Bad Request') + ex.error_message.should.equal("Invalid identifier") + ex.error_code.should.equal("ValidationException") + ex.body.should.equal({"message": "Invalid identifier", "__type": "ValidationException"}) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @mock_kms_deprecated def test__create_alias__raises_if_duplicate(): - region = 'us-west-2' + region = "us-west-2" kms = boto.kms.connect_to_region(region) create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - alias = 'alias/my-alias' + key_id = create_resp["KeyMetadata"]["KeyId"] + alias = "alias/my-alias" kms.create_alias(alias, key_id) @@ -400,15 +429,17 @@ def test__create_alias__raises_if_duplicate(): kms.create_alias(alias, key_id) ex = err.exception - ex.error_message.should.match(r'An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists' - .format(**locals())) + ex.error_message.should.match( + r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format(**locals()) + ) ex.error_code.should.be.none ex.box_usage.should.be.none ex.request_id.should.be.none - ex.body['message'].should.match(r'An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists' - .format(**locals())) - ex.body['__type'].should.equal('AlreadyExistsException') - ex.reason.should.equal('Bad Request') + ex.body["message"].should.match( + r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format(**locals()) + ) + ex.body["__type"].should.equal("AlreadyExistsException") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -416,25 +447,27 @@ def test__create_alias__raises_if_duplicate(): def test__create_alias__raises_if_alias_has_restricted_characters(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_restricted_characters = [ - 'alias/my-alias!', - 'alias/my-alias$', - 'alias/my-alias@', - ] + alias_names_with_restricted_characters = ["alias/my-alias!", "alias/my-alias$", "alias/my-alias@"] for alias_name in alias_names_with_restricted_characters: with assert_raises(JSONResponseError) as err: kms.create_alias(alias_name, key_id) ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal( - "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format(**locals())) - ex.error_code.should.equal('ValidationException') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal( + "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format( + **locals() + ) + ) + ex.error_code.should.equal("ValidationException") ex.message.should.equal( - "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format(**locals())) - ex.reason.should.equal('Bad Request') + "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format( + **locals() + ) + ) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -444,23 +477,19 @@ def test__create_alias__raises_if_alias_has_colon_character(): # are accepted by regex ^[a-zA-Z0-9:/_-]+$ kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_restricted_characters = [ - 'alias/my:alias', - ] + alias_names_with_restricted_characters = ["alias/my:alias"] for alias_name in alias_names_with_restricted_characters: with assert_raises(JSONResponseError) as err: kms.create_alias(alias_name, key_id) ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal( - "{alias_name} contains invalid characters for an alias".format(**locals())) - ex.error_code.should.equal('ValidationException') - ex.message.should.equal( - "{alias_name} contains invalid characters for an alias".format(**locals())) - ex.reason.should.equal('Bad Request') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal("{alias_name} contains invalid characters for an alias".format(**locals())) + ex.error_code.should.equal("ValidationException") + ex.message.should.equal("{alias_name} contains invalid characters for an alias".format(**locals())) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -468,12 +497,9 @@ def test__create_alias__raises_if_alias_has_colon_character(): def test__create_alias__accepted_characters(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_accepted_characters = [ - 'alias/my-alias_/', - 'alias/my_alias-/', - ] + alias_names_with_accepted_characters = ["alias/my-alias_/", "alias/my_alias-/"] for alias_name in alias_names_with_accepted_characters: kms.create_alias(alias_name, key_id) @@ -483,8 +509,8 @@ def test__create_alias__accepted_characters(): def test__create_alias__raises_if_target_key_id_is_existing_alias(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - alias = 'alias/my-alias' + key_id = create_resp["KeyMetadata"]["KeyId"] + alias = "alias/my-alias" kms.create_alias(alias, key_id) @@ -492,11 +518,11 @@ def test__create_alias__raises_if_target_key_id_is_existing_alias(): kms.create_alias(alias, alias) ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal('Aliases must refer to keys. Not aliases') - ex.error_code.should.equal('ValidationException') - ex.message.should.equal('Aliases must refer to keys. Not aliases') - ex.reason.should.equal('Bad Request') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal("Aliases must refer to keys. Not aliases") + ex.error_code.should.equal("ValidationException") + ex.message.should.equal("Aliases must refer to keys. Not aliases") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -504,14 +530,14 @@ def test__create_alias__raises_if_target_key_id_is_existing_alias(): def test__delete_alias(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - alias = 'alias/my-alias' + key_id = create_resp["KeyMetadata"]["KeyId"] + alias = "alias/my-alias" # added another alias here to make sure that the deletion of the alias can # be done when there are multiple existing aliases. another_create_resp = kms.create_key() - another_key_id = create_resp['KeyMetadata']['KeyId'] - another_alias = 'alias/another-alias' + another_key_id = create_resp["KeyMetadata"]["KeyId"] + another_alias = "alias/another-alias" kms.create_alias(alias, key_id) kms.create_alias(another_alias, another_key_id) @@ -529,35 +555,35 @@ def test__delete_alias__raises_if_wrong_prefix(): kms = boto.connect_kms() with assert_raises(JSONResponseError) as err: - kms.delete_alias('wrongprefix/my-alias') + kms.delete_alias("wrongprefix/my-alias") ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal('Invalid identifier') - ex.error_code.should.equal('ValidationException') - ex.message.should.equal('Invalid identifier') - ex.reason.should.equal('Bad Request') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal("Invalid identifier") + ex.error_code.should.equal("ValidationException") + ex.message.should.equal("Invalid identifier") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @mock_kms_deprecated def test__delete_alias__raises_if_alias_is_not_found(): - region = 'us-west-2' + region = "us-west-2" kms = boto.kms.connect_to_region(region) - alias_name = 'alias/unexisting-alias' + alias_name = "alias/unexisting-alias" with assert_raises(NotFoundException) as err: kms.delete_alias(alias_name) ex = err.exception - ex.body['__type'].should.equal('NotFoundException') - ex.body['message'].should.match( - r'Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.'.format(**locals())) + ex.body["__type"].should.equal("NotFoundException") + ex.body["message"].should.match( + r"Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.".format(**locals()) + ) ex.box_usage.should.be.none ex.error_code.should.be.none - ex.message.should.match( - r'Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.'.format(**locals())) - ex.reason.should.equal('Bad Request') + ex.message.should.match(r"Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.".format(**locals())) + ex.reason.should.equal("Bad Request") ex.request_id.should.be.none ex.status.should.equal(400) @@ -568,39 +594,43 @@ def test__list_aliases(): kms = boto.kms.connect_to_region(region) create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - kms.create_alias('alias/my-alias1', key_id) - kms.create_alias('alias/my-alias2', key_id) - kms.create_alias('alias/my-alias3', key_id) + key_id = create_resp["KeyMetadata"]["KeyId"] + kms.create_alias("alias/my-alias1", key_id) + kms.create_alias("alias/my-alias2", key_id) + kms.create_alias("alias/my-alias3", key_id) resp = kms.list_aliases() - resp['Truncated'].should.be.false + resp["Truncated"].should.be.false - aliases = resp['Aliases'] + aliases = resp["Aliases"] def has_correct_arn(alias_obj): - alias_name = alias_obj['AliasName'] - alias_arn = alias_obj['AliasArn'] - return re.match(r'arn:aws:kms:{region}:\d{{12}}:{alias_name}'.format(region=region, alias_name=alias_name), - alias_arn) + alias_name = alias_obj["AliasName"] + alias_arn = alias_obj["AliasArn"] + return re.match( + r"arn:aws:kms:{region}:\d{{12}}:{alias_name}".format(region=region, alias_name=alias_name), alias_arn + ) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/ebs' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/rds' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/redshift' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/s3' == alias['AliasName']]).should.equal(1) + len([alias for alias in aliases if has_correct_arn(alias) and "alias/aws/ebs" == alias["AliasName"]]).should.equal( + 1 + ) + len([alias for alias in aliases if has_correct_arn(alias) and "alias/aws/rds" == alias["AliasName"]]).should.equal( + 1 + ) + len( + [alias for alias in aliases if has_correct_arn(alias) and "alias/aws/redshift" == alias["AliasName"]] + ).should.equal(1) + len([alias for alias in aliases if has_correct_arn(alias) and "alias/aws/s3" == alias["AliasName"]]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/my-alias1' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/my-alias2' == alias['AliasName']]).should.equal(1) + len( + [alias for alias in aliases if has_correct_arn(alias) and "alias/my-alias1" == alias["AliasName"]] + ).should.equal(1) + len( + [alias for alias in aliases if has_correct_arn(alias) and "alias/my-alias2" == alias["AliasName"]] + ).should.equal(1) - len([alias for alias in aliases if 'TargetKeyId' in alias and key_id == - alias['TargetKeyId']]).should.equal(3) + len([alias for alias in aliases if "TargetKeyId" in alias and key_id == alias["TargetKeyId"]]).should.equal(3) len(aliases).should.equal(7) @@ -610,156 +640,124 @@ def test__assert_valid_key_id(): from moto.kms.responses import _assert_valid_key_id import uuid - _assert_valid_key_id.when.called_with( - "not-a-key").should.throw(MotoNotFoundException) - _assert_valid_key_id.when.called_with( - str(uuid.uuid4())).should_not.throw(MotoNotFoundException) + _assert_valid_key_id.when.called_with("not-a-key").should.throw(MotoNotFoundException) + _assert_valid_key_id.when.called_with(str(uuid.uuid4())).should_not.throw(MotoNotFoundException) @mock_kms_deprecated def test__assert_default_policy(): from moto.kms.responses import _assert_default_policy - _assert_default_policy.when.called_with( - "not-default").should.throw(MotoNotFoundException) - _assert_default_policy.when.called_with( - "default").should_not.throw(MotoNotFoundException) + _assert_default_policy.when.called_with("not-default").should.throw(MotoNotFoundException) + _assert_default_policy.when.called_with("default").should_not.throw(MotoNotFoundException) +@parameterized(PLAINTEXT_VECTORS) @mock_kms -def test_kms_encrypt_boto3(): - client = boto3.client('kms', region_name='us-east-1') - response = client.encrypt(KeyId='foo', Plaintext=b'bar') +def test_kms_encrypt_boto3(plaintext): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="key") + response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext) - response = client.decrypt(CiphertextBlob=response['CiphertextBlob']) - response['Plaintext'].should.equal(b'bar') + response = client.decrypt(CiphertextBlob=response["CiphertextBlob"]) + response["Plaintext"].should.equal(plaintext) @mock_kms def test_disable_key(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='disable-key') - client.disable_key( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="disable-key") + client.disable_key(KeyId=key["KeyMetadata"]["KeyId"]) - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'Disabled' + assert result["KeyMetadata"]["KeyState"] == "Disabled" @mock_kms def test_enable_key(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='enable-key') - client.disable_key( - KeyId=key['KeyMetadata']['KeyId'] - ) - client.enable_key( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="enable-key") + client.disable_key(KeyId=key["KeyMetadata"]["KeyId"]) + client.enable_key(KeyId=key["KeyMetadata"]["KeyId"]) - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == True - assert result["KeyMetadata"]["KeyState"] == 'Enabled' + assert result["KeyMetadata"]["KeyState"] == "Enabled" @mock_kms def test_schedule_key_deletion(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='schedule-key-deletion') - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false': + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="schedule-key-deletion") + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": with freeze_time("2015-01-01 12:00:00"): - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] - assert response['DeletionDate'] == datetime(2015, 1, 31, 12, 0, tzinfo=tzutc()) + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] + assert response["DeletionDate"] == datetime(2015, 1, 31, 12, 0, tzinfo=tzutc()) else: # Can't manipulate time in server mode - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'PendingDeletion' - assert 'DeletionDate' in result["KeyMetadata"] + assert result["KeyMetadata"]["KeyState"] == "PendingDeletion" + assert "DeletionDate" in result["KeyMetadata"] @mock_kms def test_schedule_key_deletion_custom(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='schedule-key-deletion') - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false': + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="schedule-key-deletion") + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": with freeze_time("2015-01-01 12:00:00"): - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'], - PendingWindowInDays=7 - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] - assert response['DeletionDate'] == datetime(2015, 1, 8, 12, 0, tzinfo=tzutc()) + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] + assert response["DeletionDate"] == datetime(2015, 1, 8, 12, 0, tzinfo=tzutc()) else: # Can't manipulate time in server mode - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'], - PendingWindowInDays=7 - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'PendingDeletion' - assert 'DeletionDate' in result["KeyMetadata"] + assert result["KeyMetadata"]["KeyState"] == "PendingDeletion" + assert "DeletionDate" in result["KeyMetadata"] @mock_kms def test_cancel_key_deletion(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='cancel-key-deletion') - client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - response = client.cancel_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="cancel-key-deletion") + client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + response = client.cancel_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'Disabled' - assert 'DeletionDate' not in result["KeyMetadata"] + assert result["KeyMetadata"]["KeyState"] == "Disabled" + assert "DeletionDate" not in result["KeyMetadata"] @mock_kms def test_update_key_description(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='old_description') - key_id = key['KeyMetadata']['KeyId'] + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="old_description") + key_id = key["KeyMetadata"]["KeyId"] - result = client.update_key_description(KeyId=key_id, Description='new_description') - assert 'ResponseMetadata' in result + result = client.update_key_description(KeyId=key_id, Description="new_description") + assert "ResponseMetadata" in result @mock_kms def test_tag_resource(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='cancel-key-deletion') - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="cancel-key-deletion") + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) - keyid = response['KeyId'] - response = client.tag_resource( - KeyId=keyid, - Tags=[ - { - 'TagKey': 'string', - 'TagValue': 'string' - }, - ] - ) + keyid = response["KeyId"] + response = client.tag_resource(KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]) # Shouldn't have any data, just header assert len(response.keys()) == 1 @@ -767,226 +765,158 @@ def test_tag_resource(): @mock_kms def test_list_resource_tags(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='cancel-key-deletion') - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="cancel-key-deletion") + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) - keyid = response['KeyId'] - response = client.tag_resource( - KeyId=keyid, - Tags=[ - { - 'TagKey': 'string', - 'TagValue': 'string' - }, - ] - ) + keyid = response["KeyId"] + response = client.tag_resource(KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]) response = client.list_resource_tags(KeyId=keyid) - assert response['Tags'][0]['TagKey'] == 'string' - assert response['Tags'][0]['TagValue'] == 'string' + assert response["Tags"][0]["TagKey"] == "string" + assert response["Tags"][0]["TagValue"] == "string" @mock_kms def test_generate_data_key_sizes(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-size') + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-size") - resp1 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_256' - ) - resp2 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_128' - ) - resp3 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - NumberOfBytes=64 - ) + resp1 = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256") + resp2 = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_128") + resp3 = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], NumberOfBytes=64) - assert len(resp1['Plaintext']) == 32 - assert len(resp2['Plaintext']) == 16 - assert len(resp3['Plaintext']) == 64 + assert len(resp1["Plaintext"]) == 32 + assert len(resp2["Plaintext"]) == 16 + assert len(resp3["Plaintext"]) == 64 @mock_kms def test_generate_data_key_decrypt(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-decrypt') + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-decrypt") - resp1 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_256' - ) - resp2 = client.decrypt( - CiphertextBlob=resp1['CiphertextBlob'] - ) + resp1 = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256") + resp2 = client.decrypt(CiphertextBlob=resp1["CiphertextBlob"]) - assert resp1['Plaintext'] == resp2['Plaintext'] + assert resp1["Plaintext"] == resp2["Plaintext"] @mock_kms def test_generate_data_key_invalid_size_params(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-size') + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-size") with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_257' - ) + client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_257") with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_128', - NumberOfBytes=16 - ) + client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_128", NumberOfBytes=16) with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - NumberOfBytes=2048 - ) + client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], NumberOfBytes=2048) with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'] - ) + client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"]) @mock_kms def test_generate_data_key_invalid_key(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-size') + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-size") with assert_raises(client.exceptions.NotFoundException): - client.generate_data_key( - KeyId='alias/randomnonexistantkey', - KeySpec='AES_256' - ) + client.generate_data_key(KeyId="alias/randomnonexistantkey", KeySpec="AES_256") with assert_raises(client.exceptions.NotFoundException): - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'] + '4', - KeySpec='AES_256' - ) + client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"] + "4", KeySpec="AES_256") @mock_kms def test_generate_data_key_without_plaintext_decrypt(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-decrypt') + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-decrypt") - resp1 = client.generate_data_key_without_plaintext( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_256' - ) + resp1 = client.generate_data_key_without_plaintext(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256") - assert 'Plaintext' not in resp1 + assert "Plaintext" not in resp1 @mock_kms def test_enable_key_rotation_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.enable_key_rotation( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.enable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_disable_key_rotation_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.disable_key_rotation( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.disable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_enable_key_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.enable_key( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.enable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_disable_key_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.disable_key( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.disable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_cancel_key_deletion_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.cancel_key_deletion( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.cancel_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_schedule_key_deletion_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.schedule_key_deletion( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.schedule_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_get_key_rotation_status_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.get_key_rotation_status( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.get_key_rotation_status(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_get_key_policy_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.get_key_policy( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02', - PolicyName='default' - ) + client.get_key_policy(KeyId="12366f9b-1230-123d-123e-123e6ae60c02", PolicyName="default") @mock_kms def test_list_key_policies_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.list_key_policies( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.list_key_policies(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_put_key_policy_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.put_key_policy( - KeyId='00000000-0000-0000-0000-000000000000', - PolicyName='default', - Policy='new policy' - ) + client.put_key_policy(KeyId="00000000-0000-0000-0000-000000000000", PolicyName="default", Policy="new policy") diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py index 466c72ea9..73d7d3580 100644 --- a/tests/test_kms/test_utils.py +++ b/tests/test_kms/test_utils.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import sure # noqa from nose.tools import assert_raises from parameterized import parameterized @@ -104,26 +105,23 @@ def test_encrypt_decrypt_cycle(encryption_context): def test_encrypt_unknown_key_id(): - assert_raises( - NotFoundException, encrypt, master_keys={}, key_id="anything", plaintext=b"secrets", encryption_context={} - ) + with assert_raises(NotFoundException): + encrypt(master_keys={}, key_id="anything", plaintext=b"secrets", encryption_context={}) def test_decrypt_invalid_ciphertext_format(): master_key = Key("nop", "nop", "nop", [], "nop") master_key_map = {master_key.id: master_key} - assert_raises( - InvalidCiphertextException, decrypt, master_keys=master_key_map, ciphertext_blob=b"", encryption_context={} - ) + with assert_raises(InvalidCiphertextException): + decrypt(master_keys=master_key_map, ciphertext_blob=b"", encryption_context={}) def test_decrypt_unknwown_key_id(): ciphertext_blob = b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext" - assert_raises( - AccessDeniedException, decrypt, master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={} - ) + with assert_raises(AccessDeniedException): + decrypt(master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={}) def test_decrypt_invalid_ciphertext(): @@ -131,13 +129,12 @@ def test_decrypt_invalid_ciphertext(): master_key_map = {master_key.id: master_key} ciphertext_blob = master_key.id.encode("utf-8") + b"123456789012" b"1234567890123456" b"some ciphertext" - assert_raises( - InvalidCiphertextException, - decrypt, - master_keys=master_key_map, - ciphertext_blob=ciphertext_blob, - encryption_context={}, - ) + with assert_raises(InvalidCiphertextException): + decrypt( + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context={}, + ) def test_decrypt_invalid_encryption_context(): @@ -152,10 +149,9 @@ def test_decrypt_invalid_encryption_context(): encryption_context={"some": "encryption", "context": "here"}, ) - assert_raises( - InvalidCiphertextException, - decrypt, - master_keys=master_key_map, - ciphertext_blob=ciphertext_blob, - encryption_context={}, - ) + with assert_raises(InvalidCiphertextException): + decrypt( + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context={}, + ) From d5ac5453b30691b1cbc19d7f3e35d439d5eabc2e Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 27 Aug 2019 13:57:50 -0700 Subject: [PATCH 18/67] streamline KMS tests --- tests/test_kms/test_kms.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 4e1f39540..dfb558dde 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -777,18 +777,19 @@ def test_list_resource_tags(): assert response["Tags"][0]["TagValue"] == "string" +@parameterized(( + (dict(KeySpec="AES_256"), 32), + (dict(KeySpec="AES_128"), 16), + (dict(NumberOfBytes=64), 64), +)) @mock_kms -def test_generate_data_key_sizes(): +def test_generate_data_key_sizes(kwargs, expected_key_length): client = boto3.client("kms", region_name="us-east-1") key = client.create_key(Description="generate-data-key-size") - resp1 = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256") - resp2 = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_128") - resp3 = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], NumberOfBytes=64) + response = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs) - assert len(resp1["Plaintext"]) == 32 - assert len(resp2["Plaintext"]) == 16 - assert len(resp3["Plaintext"]) == 64 + assert len(response["Plaintext"]) == expected_key_length @mock_kms @@ -802,22 +803,19 @@ def test_generate_data_key_decrypt(): assert resp1["Plaintext"] == resp2["Plaintext"] +@parameterized(( + (dict(KeySpec="AES_257"),), + (dict(KeySpec="AES_128", NumberOfBytes=16),), + (dict(NumberOfBytes=2048),), + (dict(),), +)) @mock_kms -def test_generate_data_key_invalid_size_params(): +def test_generate_data_key_invalid_size_params(kwargs): client = boto3.client("kms", region_name="us-east-1") key = client.create_key(Description="generate-data-key-size") with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_257") - - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_128", NumberOfBytes=16) - - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], NumberOfBytes=2048) - - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"]) + client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs) @mock_kms From 4d2b12f40da355e55fba91df14a9d11d6570c27f Mon Sep 17 00:00:00 2001 From: Daniel Guerrero Date: Tue, 27 Aug 2019 19:59:43 -0500 Subject: [PATCH 19/67] Adding six.string_types checking --- moto/dynamodbstreams/responses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index c570483c5..7774f3239 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse from .models import dynamodbstreams_backends +from six import string_types class DynamoDBStreamsHandler(BaseResponse): @@ -25,7 +26,7 @@ class DynamoDBStreamsHandler(BaseResponse): shard_iterator_type = self._get_param('ShardIteratorType') sequence_number = self._get_param('SequenceNumber') # according to documentation sequence_number param should be string - if isinstance(sequence_number, "".__class__): + if isinstance(sequence_number, string_types): sequence_number = int(sequence_number) return self.backend.get_shard_iterator(arn, shard_id, From 9ffb9d3d0a9032ad4a9cd507637e18885dfe1b2e Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 27 Aug 2019 20:24:57 -0700 Subject: [PATCH 20/67] add kms:ReEncrypt and tests --- moto/kms/models.py | 11 +++++++++ moto/kms/responses.py | 19 +++++++++++++++ tests/test_kms/test_kms.py | 49 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+) diff --git a/moto/kms/models.py b/moto/kms/models.py index 9fb28bb5f..e5dc1cd76 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -214,6 +214,17 @@ class KmsBackend(BaseBackend): arn = self.keys[key_id].arn return plaintext, arn + def re_encrypt( + self, ciphertext_blob, source_encryption_context, destination_key_id, destination_encryption_context + ): + plaintext, decrypting_arn = self.decrypt( + ciphertext_blob=ciphertext_blob, encryption_context=source_encryption_context + ) + new_ciphertext_blob, encrypting_arn = self.encrypt( + key_id=destination_key_id, plaintext=plaintext, encryption_context=destination_encryption_context + ) + return new_ciphertext_blob, decrypting_arn, encrypting_arn + def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens): key_id = self.any_id_to_key_id(key_id) diff --git a/moto/kms/responses.py b/moto/kms/responses.py index 0b8684019..aa500ca5c 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -260,6 +260,25 @@ class KmsResponse(BaseResponse): return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn}) + def re_encrypt(self): + ciphertext_blob = self.parameters.get("CiphertextBlob") + source_encryption_context = self.parameters.get("SourceEncryptionContext", {}) + destination_key_id = self.parameters.get("DestinationKeyId") + destination_encryption_context = self.parameters.get("DestinationEncryptionContext", {}) + + new_ciphertext_blob, decrypting_arn, encrypting_arn = self.kms_backend.re_encrypt( + ciphertext_blob=ciphertext_blob, + source_encryption_context=source_encryption_context, + destination_key_id=destination_key_id, + destination_encryption_context=destination_encryption_context, + ) + + response_ciphertext_blob = base64.b64encode(new_ciphertext_blob).decode("utf-8") + + return json.dumps( + {"CiphertextBlob": response_ciphertext_blob, "KeyId": encrypting_arn, "SourceKeyId": decrypting_arn} + ) + def disable_key(self): key_id = self.parameters.get('KeyId') _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index dfb558dde..c132608c9 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -840,6 +840,55 @@ def test_generate_data_key_without_plaintext_decrypt(): assert "Plaintext" not in resp1 +@parameterized(PLAINTEXT_VECTORS) +@mock_kms +def test_re_encrypt_decrypt(plaintext): + client = boto3.client("kms", region_name="us-west-2") + + key_1 = client.create_key(Description="key 1") + key_1_id = key_1["KeyMetadata"]["KeyId"] + key_1_arn = key_1["KeyMetadata"]["Arn"] + key_2 = client.create_key(Description="key 2") + key_2_id = key_2["KeyMetadata"]["KeyId"] + key_2_arn = key_2["KeyMetadata"]["Arn"] + + encrypt_response = client.encrypt( + KeyId=key_1_id, + Plaintext=plaintext, + EncryptionContext={"encryption": "context"}, + ) + + re_encrypt_response = client.re_encrypt( + CiphertextBlob=encrypt_response["CiphertextBlob"], + SourceEncryptionContext={"encryption": "context"}, + DestinationKeyId=key_2_id, + DestinationEncryptionContext={"another": "context"}, + ) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(re_encrypt_response["CiphertextBlob"], validate=True) + + re_encrypt_response["SourceKeyId"].should.equal(key_1_arn) + re_encrypt_response["KeyId"].should.equal(key_2_arn) + + decrypt_response_1 = client.decrypt( + CiphertextBlob=encrypt_response["CiphertextBlob"], + EncryptionContext={"encryption": "context"}, + ) + decrypt_response_1["Plaintext"].should.equal(plaintext) + decrypt_response_1["KeyId"].should.equal(key_1_arn) + + decrypt_response_2 = client.decrypt( + CiphertextBlob=re_encrypt_response["CiphertextBlob"], + EncryptionContext={"another": "context"}, + ) + decrypt_response_2["Plaintext"].should.equal(plaintext) + decrypt_response_2["KeyId"].should.equal(key_2_arn) + + decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"]) + + @mock_kms def test_enable_key_rotation_key_not_found(): client = boto3.client("kms", region_name="us-east-1") From dd63cebf8174a48da57543fc43af7eb708846eeb Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 27 Aug 2019 20:49:47 -0700 Subject: [PATCH 21/67] add kms:ReEncrypt invalid destination key test --- moto/kms/models.py | 2 ++ tests/test_kms/test_kms.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/moto/kms/models.py b/moto/kms/models.py index e5dc1cd76..5f89407f5 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -217,6 +217,8 @@ class KmsBackend(BaseBackend): def re_encrypt( self, ciphertext_blob, source_encryption_context, destination_key_id, destination_encryption_context ): + destination_key_id = self.any_id_to_key_id(destination_key_id) + plaintext, decrypting_arn = self.decrypt( ciphertext_blob=ciphertext_blob, encryption_context=source_encryption_context ) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index c132608c9..1c5aa39ea 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -889,6 +889,25 @@ def test_re_encrypt_decrypt(plaintext): decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"]) +@mock_kms +def test_re_encrypt_to_invalid_destination(): + client = boto3.client("kms", region_name="us-west-2") + + key = client.create_key(Description="key 1") + key_id = key["KeyMetadata"]["KeyId"] + + encrypt_response = client.encrypt( + KeyId=key_id, + Plaintext=b"some plaintext", + ) + + with assert_raises(client.exceptions.NotFoundException): + client.re_encrypt( + CiphertextBlob=encrypt_response["CiphertextBlob"], + DestinationKeyId="8327948729348", + ) + + @mock_kms def test_enable_key_rotation_key_not_found(): client = boto3.client("kms", region_name="us-east-1") From f7043e1eaf2eb7bd23b440cf74ce862dcf4a75f7 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 27 Aug 2019 20:55:58 -0700 Subject: [PATCH 22/67] add kms:GenerateRandom and tests --- moto/kms/responses.py | 11 +++++++++++ tests/test_kms/test_kms.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/moto/kms/responses.py b/moto/kms/responses.py index aa500ca5c..6a67614a0 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -2,7 +2,9 @@ from __future__ import unicode_literals import base64 import json +import os import re + import six from moto.core.responses import BaseResponse @@ -384,6 +386,15 @@ class KmsResponse(BaseResponse): return json.dumps(result) + def generate_random(self): + number_of_bytes = self.parameters.get("NumberOfBytes") + + entropy = os.urandom(number_of_bytes) + + response_entropy = base64.b64encode(entropy).decode("utf-8") + + return json.dumps({"Plaintext": response_entropy}) + def _assert_valid_key_id(key_id): if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE): diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 1c5aa39ea..4daeaa7cf 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -908,6 +908,21 @@ def test_re_encrypt_to_invalid_destination(): ) +@parameterized(((12,), (44,), (91,))) +@mock_kms +def test_generate_random(number_of_bytes): + client = boto3.client("kms", region_name="us-west-2") + + response = client.generate_random(NumberOfBytes=number_of_bytes) + + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["Plaintext"], validate=True) + + response["Plaintext"].should.be.a(bytes) + len(response["Plaintext"]).should.equal(number_of_bytes) + + @mock_kms def test_enable_key_rotation_key_not_found(): client = boto3.client("kms", region_name="us-east-1") From 776a54b89f5fdfcd79ed718971cfd52a4ecec203 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 27 Aug 2019 23:59:46 -0700 Subject: [PATCH 23/67] update KMS implementation coverage --- IMPLEMENTATION_COVERAGE.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index d149b0dd8..9fbfe455e 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -3801,14 +3801,14 @@ - [ ] update_stream ## kms -41% implemented +54% implemented - [X] cancel_key_deletion - [ ] connect_custom_key_store - [ ] create_alias - [ ] create_custom_key_store - [ ] create_grant - [X] create_key -- [ ] decrypt +- [X] decrypt - [X] delete_alias - [ ] delete_custom_key_store - [ ] delete_imported_key_material @@ -3819,10 +3819,10 @@ - [ ] disconnect_custom_key_store - [X] enable_key - [X] enable_key_rotation -- [ ] encrypt +- [X] encrypt - [X] generate_data_key -- [ ] generate_data_key_without_plaintext -- [ ] generate_random +- [X] generate_data_key_without_plaintext +- [X] generate_random - [X] get_key_policy - [X] get_key_rotation_status - [ ] get_parameters_for_import @@ -3834,7 +3834,7 @@ - [X] list_resource_tags - [ ] list_retirable_grants - [X] put_key_policy -- [ ] re_encrypt +- [X] re_encrypt - [ ] retire_grant - [ ] revoke_grant - [X] schedule_key_deletion From 819d354af3bc51c6d522e085c79f27129d0712a1 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Wed, 28 Aug 2019 00:48:53 -0700 Subject: [PATCH 24/67] fix linting issues --- moto/kms/responses.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/moto/kms/responses.py b/moto/kms/responses.py index 6a67614a0..fecb391d3 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -10,7 +10,6 @@ import six from moto.core.responses import BaseResponse from .models import kms_backends from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException -from .utils import decrypt, encrypt reserved_aliases = [ 'alias/aws/ebs', @@ -345,19 +344,17 @@ class KmsResponse(BaseResponse): if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 0): raise ValidationException(( - "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " - "to satisfy constraint: Member must have value less than or " - "equal to 1024" - ).format(number_of_bytes=number_of_bytes) - ) + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes)) if key_spec and key_spec not in ('AES_256', 'AES_128'): raise ValidationException(( - "1 validation error detected: Value '{key_spec}' at 'keySpec' failed " - "to satisfy constraint: Member must satisfy enum value set: " - "[AES_256, AES_128]" - ).format(key_spec=key_spec) - ) + "1 validation error detected: Value '{key_spec}' at 'keySpec' failed " + "to satisfy constraint: Member must satisfy enum value set: " + "[AES_256, AES_128]" + ).format(key_spec=key_spec)) if not key_spec and not number_of_bytes: raise ValidationException("Please specify either number of bytes or key spec.") if key_spec and number_of_bytes: From e0304bc5000620262746dcdc15422b51740cd7f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Sznuk?= Date: Wed, 28 Aug 2019 16:17:45 +0200 Subject: [PATCH 25/67] Allows leading // for mocked s3 paths (#1637). --- moto/s3/urls.py | 2 ++ moto/server.py | 5 +++-- tests/test_s3/test_s3.py | 27 +++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/moto/s3/urls.py b/moto/s3/urls.py index fa81568a4..1388c81e5 100644 --- a/moto/s3/urls.py +++ b/moto/s3/urls.py @@ -15,4 +15,6 @@ url_paths = { '{0}/(?P[^/]+)/?$': S3ResponseInstance.ambiguous_response, # path-based bucket + key '{0}/(?P[^/]+)/(?P.+)': S3ResponseInstance.key_response, + # subdomain bucket + key with empty first part of path + '{0}//(?P.*)$': S3ResponseInstance.key_response, } diff --git a/moto/server.py b/moto/server.py index 89be47093..b245f6e6f 100644 --- a/moto/server.py +++ b/moto/server.py @@ -174,10 +174,11 @@ def create_backend_app(service): backend_app.url_map.converters['regex'] = RegexConverter backend = list(BACKENDS[service].values())[0] for url_path, handler in backend.flask_paths.items(): + view_func = convert_flask_to_httpretty_response(handler) if handler.__name__ == 'dispatch': endpoint = '{0}.dispatch'.format(handler.__self__.__name__) else: - endpoint = None + endpoint = view_func.__name__ original_endpoint = endpoint index = 2 @@ -191,7 +192,7 @@ def create_backend_app(service): url_path, endpoint=endpoint, methods=HTTP_METHODS, - view_func=convert_flask_to_httpretty_response(handler), + view_func=view_func, strict_slashes=False, ) diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index cd57fc92b..0c0721f01 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import datetime +import os from six.moves.urllib.request import urlopen from six.moves.urllib.error import HTTPError from functools import wraps @@ -23,6 +24,7 @@ from freezegun import freeze_time import six import requests import tests.backport_assert_raises # noqa +from nose import SkipTest from nose.tools import assert_raises import sure # noqa @@ -2991,3 +2993,28 @@ def test_accelerate_configuration_is_not_supported_when_bucket_name_has_dots(): AccelerateConfiguration={'Status': 'Enabled'}, ) exc.exception.response['Error']['Code'].should.equal('InvalidRequest') + +def store_and_read_back_a_key(key): + s3 = boto3.client('s3', region_name='us-east-1') + bucket_name = 'mybucket' + body = b'Some body' + + s3.create_bucket(Bucket=bucket_name) + s3.put_object( + Bucket=bucket_name, + Key=key, + Body=body + ) + + response = s3.get_object(Bucket=bucket_name, Key=key) + response['Body'].read().should.equal(body) + +@mock_s3 +def test_paths_with_leading_slashes_work(): + store_and_read_back_a_key('/a-key') + +@mock_s3 +def test_root_dir_with_empty_name_works(): + if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': + raise SkipTest('Does not work in server mode due to error in Workzeug') + store_and_read_back_a_key('/') From 35507f33dfe161456f241f3096d9bf8278b5eb38 Mon Sep 17 00:00:00 2001 From: Don Kuntz Date: Wed, 28 Aug 2019 13:55:19 -0500 Subject: [PATCH 26/67] Don't error out on route53.list_tags_for_resource when resource has no tags Without the added `return {}`, calling route53.list_tags_for_resource when called with a ResourceId of a resource without any tags would result in the error: jinja2.exceptions.UndefinedError: 'None' has no attribute 'items' Because the LIST_TAGS_FOR_RESOURCE_RESPONSE was given None instead of empty dict. This now allows list_tags_for_resource to be called without issue on tag-less resources. --- moto/route53/models.py | 1 + tests/test_route53/test_route53.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/moto/route53/models.py b/moto/route53/models.py index 61a6609aa..77a0e59e6 100644 --- a/moto/route53/models.py +++ b/moto/route53/models.py @@ -305,6 +305,7 @@ class Route53Backend(BaseBackend): def list_tags_for_resource(self, resource_id): if resource_id in self.resource_tags: return self.resource_tags[resource_id] + return {} def get_all_hosted_zones(self): return self.zones.values() diff --git a/tests/test_route53/test_route53.py b/tests/test_route53/test_route53.py index de9465d6d..e70137156 100644 --- a/tests/test_route53/test_route53.py +++ b/tests/test_route53/test_route53.py @@ -404,6 +404,13 @@ def test_list_or_change_tags_for_resource_request(): ) healthcheck_id = health_check['HealthCheck']['Id'] + # confirm this works for resources with zero tags + response = conn.list_tags_for_resource( + ResourceType="healthcheck", ResourceId=healthcheck_id) + response["ResourceTagSet"]["Tags"].should.be.empty + + print(response) + tag1 = {"Key": "Deploy", "Value": "True"} tag2 = {"Key": "Name", "Value": "UnitTest"} From cae0b5bc45de692580a07fdee16c9a83628f52d1 Mon Sep 17 00:00:00 2001 From: Don Kuntz Date: Wed, 28 Aug 2019 13:59:49 -0500 Subject: [PATCH 27/67] Remove extraneous print statement from test --- tests/test_route53/test_route53.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_route53/test_route53.py b/tests/test_route53/test_route53.py index e70137156..babd54d26 100644 --- a/tests/test_route53/test_route53.py +++ b/tests/test_route53/test_route53.py @@ -409,8 +409,6 @@ def test_list_or_change_tags_for_resource_request(): ResourceType="healthcheck", ResourceId=healthcheck_id) response["ResourceTagSet"]["Tags"].should.be.empty - print(response) - tag1 = {"Key": "Deploy", "Value": "True"} tag2 = {"Key": "Name", "Value": "UnitTest"} From 675db17ace22d6554292d51afbed1c0dac445de7 Mon Sep 17 00:00:00 2001 From: acsbendi Date: Fri, 30 Aug 2019 18:21:11 +0200 Subject: [PATCH 28/67] Implemented deregistering terminated instances from ELB target groups. --- moto/ec2/responses/instances.py | 2 + moto/elbv2/models.py | 9 ++++ tests/test_elbv2/test_elbv2.py | 77 +++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) diff --git a/moto/ec2/responses/instances.py b/moto/ec2/responses/instances.py index 82c2b1997..28123b995 100644 --- a/moto/ec2/responses/instances.py +++ b/moto/ec2/responses/instances.py @@ -6,6 +6,7 @@ from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores from moto.ec2.utils import filters_from_querystring, \ dict_from_querystring +from moto.elbv2 import elbv2_backends class InstanceResponse(BaseResponse): @@ -68,6 +69,7 @@ class InstanceResponse(BaseResponse): if self.is_not_dryrun('TerminateInstance'): instances = self.ec2_backend.terminate_instances(instance_ids) autoscaling_backends[self.region].notify_terminate_instances(instance_ids) + elbv2_backends[self.region].notify_terminate_instances(instance_ids) template = self.response_template(EC2_TERMINATE_INSTANCES) return template.render(instances=instances) diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 7e73c7042..726799fe5 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -110,6 +110,11 @@ class FakeTargetGroup(BaseModel): if not t: raise InvalidTargetError() + def deregister_terminated_instances(self, instance_ids): + for target_id in list(self.targets.keys()): + if target_id in instance_ids: + del self.targets[target_id] + def add_tag(self, key, value): if len(self.tags) >= 10 and key not in self.tags: raise TooManyTagsError() @@ -936,6 +941,10 @@ class ELBv2Backend(BaseBackend): return True return False + def notify_terminate_instances(self, instance_ids): + for target_group in self.target_groups.values(): + target_group.deregister_terminated_instances(instance_ids) + elbv2_backends = {} for region in ec2_backends.keys(): diff --git a/tests/test_elbv2/test_elbv2.py b/tests/test_elbv2/test_elbv2.py index 36772c02e..b2512a3f1 100644 --- a/tests/test_elbv2/test_elbv2.py +++ b/tests/test_elbv2/test_elbv2.py @@ -752,6 +752,83 @@ def test_stopped_instance_target(): }) +@mock_ec2 +@mock_elbv2 +def test_terminated_instance_target(): + target_group_port = 8080 + + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.0/26', + AvailabilityZone='us-east-1b') + + conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + + response = conn.create_target_group( + Name='a-target', + Protocol='HTTP', + Port=target_group_port, + VpcId=vpc.id, + HealthCheckProtocol='HTTP', + HealthCheckPath='/', + HealthCheckIntervalSeconds=5, + HealthCheckTimeoutSeconds=5, + HealthyThresholdCount=5, + UnhealthyThresholdCount=2, + Matcher={'HttpCode': '200'}) + target_group = response.get('TargetGroups')[0] + + # No targets registered yet + response = conn.describe_target_health( + TargetGroupArn=target_group.get('TargetGroupArn')) + response.get('TargetHealthDescriptions').should.have.length_of(0) + + response = ec2.create_instances( + ImageId='ami-1234abcd', MinCount=1, MaxCount=1) + instance = response[0] + + target_dict = { + 'Id': instance.id, + 'Port': 500 + } + + response = conn.register_targets( + TargetGroupArn=target_group.get('TargetGroupArn'), + Targets=[target_dict]) + + response = conn.describe_target_health( + TargetGroupArn=target_group.get('TargetGroupArn')) + response.get('TargetHealthDescriptions').should.have.length_of(1) + target_health_description = response.get('TargetHealthDescriptions')[0] + + target_health_description['Target'].should.equal(target_dict) + target_health_description['HealthCheckPort'].should.equal(str(target_group_port)) + target_health_description['TargetHealth'].should.equal({ + 'State': 'healthy' + }) + + instance.terminate() + + response = conn.describe_target_health( + TargetGroupArn=target_group.get('TargetGroupArn')) + response.get('TargetHealthDescriptions').should.have.length_of(0) + + @mock_ec2 @mock_elbv2 def test_target_group_attributes(): From 1ae641fab893aa05375b29dd7829477c981a0500 Mon Sep 17 00:00:00 2001 From: Wessel van der Veen Date: Sat, 31 Aug 2019 09:08:12 +0200 Subject: [PATCH 29/67] adds basic implementation for describe-identity-pool --- moto/cognitoidentity/exceptions.py | 15 +++++++++ moto/cognitoidentity/models.py | 51 ++++++++++++++++++++---------- moto/cognitoidentity/responses.py | 5 ++- 3 files changed, 54 insertions(+), 17 deletions(-) create mode 100644 moto/cognitoidentity/exceptions.py diff --git a/moto/cognitoidentity/exceptions.py b/moto/cognitoidentity/exceptions.py new file mode 100644 index 000000000..ec22f3b42 --- /dev/null +++ b/moto/cognitoidentity/exceptions.py @@ -0,0 +1,15 @@ +from __future__ import unicode_literals + +import json + +from werkzeug.exceptions import BadRequest + + +class ResourceNotFoundError(BadRequest): + + def __init__(self, message): + super(ResourceNotFoundError, self).__init__() + self.description = json.dumps({ + "message": message, + '__type': 'ResourceNotFoundException', + }) diff --git a/moto/cognitoidentity/models.py b/moto/cognitoidentity/models.py index c916b7f62..7193b551f 100644 --- a/moto/cognitoidentity/models.py +++ b/moto/cognitoidentity/models.py @@ -8,7 +8,7 @@ import boto.cognito.identity from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds - +from .exceptions import ResourceNotFoundError from .utils import get_random_identity_id @@ -39,17 +39,36 @@ class CognitoIdentityBackend(BaseBackend): self.__dict__ = {} self.__init__(region) - def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities, - supported_login_providers, developer_provider_name, open_id_connect_provider_arns, - cognito_identity_providers, saml_provider_arns): + def describe_identity_pool(self, identity_pool_id): + identity_pool = self.identity_pools.get(identity_pool_id, None) + if not identity_pool: + raise ResourceNotFoundError(identity_pool) + + response = json.dumps({ + 'AllowUnauthenticatedIdentities': identity_pool.allow_unauthenticated_identities, + 'CognitoIdentityProviders': identity_pool.cognito_identity_providers, + 'DeveloperProviderName': identity_pool.developer_provider_name, + 'IdentityPoolId': identity_pool.identity_pool_id, + 'IdentityPoolName': identity_pool.identity_pool_name, + 'IdentityPoolTags': {}, + 'OpenIdConnectProviderARNs': identity_pool.open_id_connect_provider_arns, + 'SamlProviderARNs': identity_pool.saml_provider_arns, + 'SupportedLoginProviders': identity_pool.supported_login_providers + }) + + return response + + def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities, + supported_login_providers, developer_provider_name, open_id_connect_provider_arns, + cognito_identity_providers, saml_provider_arns): new_identity = CognitoIdentity(self.region, identity_pool_name, - allow_unauthenticated_identities=allow_unauthenticated_identities, - supported_login_providers=supported_login_providers, - developer_provider_name=developer_provider_name, - open_id_connect_provider_arns=open_id_connect_provider_arns, - cognito_identity_providers=cognito_identity_providers, - saml_provider_arns=saml_provider_arns) + allow_unauthenticated_identities=allow_unauthenticated_identities, + supported_login_providers=supported_login_providers, + developer_provider_name=developer_provider_name, + open_id_connect_provider_arns=open_id_connect_provider_arns, + cognito_identity_providers=cognito_identity_providers, + saml_provider_arns=saml_provider_arns) self.identity_pools[new_identity.identity_pool_id] = new_identity response = json.dumps({ @@ -77,12 +96,12 @@ class CognitoIdentityBackend(BaseBackend): response = json.dumps( { "Credentials": - { - "AccessKeyId": "TESTACCESSKEY12345", - "Expiration": expiration_str, - "SecretKey": "ABCSECRETKEY", - "SessionToken": "ABC12345" - }, + { + "AccessKeyId": "TESTACCESSKEY12345", + "Expiration": expiration_str, + "SecretKey": "ABCSECRETKEY", + "SessionToken": "ABC12345" + }, "IdentityId": identity_id }) return response diff --git a/moto/cognitoidentity/responses.py b/moto/cognitoidentity/responses.py index 33faaa300..709fdb40a 100644 --- a/moto/cognitoidentity/responses.py +++ b/moto/cognitoidentity/responses.py @@ -1,7 +1,6 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse - from .models import cognitoidentity_backends from .utils import get_random_identity_id @@ -16,6 +15,7 @@ class CognitoIdentityResponse(BaseResponse): open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs') cognito_identity_providers = self._get_param('CognitoIdentityProviders') saml_provider_arns = self._get_param('SamlProviderARNs') + return cognitoidentity_backends[self.region].create_identity_pool( identity_pool_name=identity_pool_name, allow_unauthenticated_identities=allow_unauthenticated_identities, @@ -28,6 +28,9 @@ class CognitoIdentityResponse(BaseResponse): def get_id(self): return cognitoidentity_backends[self.region].get_id() + def describe_identity_pool(self): + return cognitoidentity_backends[self.region].describe_identity_pool(self._get_param('IdentityPoolId')) + def get_credentials_for_identity(self): return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId')) From ff27e021bca644e76e86701118814dd14fcd0a7d Mon Sep 17 00:00:00 2001 From: rocky4570 Date: Wed, 21 Aug 2019 23:55:28 +1000 Subject: [PATCH 30/67] add enhanced vpc routing option to redshift moto EnhancedVpcRouting is only available when mock_redshift not mock_redshift_deprecated --- moto/redshift/models.py | 9 ++- moto/redshift/responses.py | 5 ++ tests/test_redshift/test_redshift.py | 90 ++++++++++++++++++++++++++-- 3 files changed, 97 insertions(+), 7 deletions(-) diff --git a/moto/redshift/models.py b/moto/redshift/models.py index c0b783bde..8a2b7e6b6 100644 --- a/moto/redshift/models.py +++ b/moto/redshift/models.py @@ -74,7 +74,7 @@ class Cluster(TaggableResourceMixin, BaseModel): automated_snapshot_retention_period, port, cluster_version, allow_version_upgrade, number_of_nodes, publicly_accessible, encrypted, region_name, tags=None, iam_roles_arn=None, - restored_from_snapshot=False): + enhanced_vpc_routing=None, restored_from_snapshot=False): super(Cluster, self).__init__(region_name, tags) self.redshift_backend = redshift_backend self.cluster_identifier = cluster_identifier @@ -85,6 +85,7 @@ class Cluster(TaggableResourceMixin, BaseModel): self.master_user_password = master_user_password self.db_name = db_name if db_name else "dev" self.vpc_security_group_ids = vpc_security_group_ids + self.enhanced_vpc_routing = enhanced_vpc_routing if enhanced_vpc_routing is not None else False self.cluster_subnet_group_name = cluster_subnet_group_name self.publicly_accessible = publicly_accessible self.encrypted = encrypted @@ -154,6 +155,7 @@ class Cluster(TaggableResourceMixin, BaseModel): port=properties.get('Port'), cluster_version=properties.get('ClusterVersion'), allow_version_upgrade=properties.get('AllowVersionUpgrade'), + enhanced_vpc_routing=properties.get('EnhancedVpcRouting'), number_of_nodes=properties.get('NumberOfNodes'), publicly_accessible=properties.get("PubliclyAccessible"), encrypted=properties.get("Encrypted"), @@ -241,6 +243,7 @@ class Cluster(TaggableResourceMixin, BaseModel): 'ClusterCreateTime': self.create_time, "PendingModifiedValues": [], "Tags": self.tags, + "EnhancedVpcRouting": self.enhanced_vpc_routing, "IamRoles": [{ "ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn @@ -427,6 +430,7 @@ class Snapshot(TaggableResourceMixin, BaseModel): 'NumberOfNodes': self.cluster.number_of_nodes, 'DBName': self.cluster.db_name, 'Tags': self.tags, + 'EnhancedVpcRouting': self.cluster.enhanced_vpc_routing, "IamRoles": [{ "ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn @@ -678,7 +682,8 @@ class RedshiftBackend(BaseBackend): "number_of_nodes": snapshot.cluster.number_of_nodes, "encrypted": snapshot.cluster.encrypted, "tags": snapshot.cluster.tags, - "restored_from_snapshot": True + "restored_from_snapshot": True, + "enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing } create_kwargs.update(kwargs) return self.create_cluster(**create_kwargs) diff --git a/moto/redshift/responses.py b/moto/redshift/responses.py index a7758febb..7ac73d470 100644 --- a/moto/redshift/responses.py +++ b/moto/redshift/responses.py @@ -135,6 +135,7 @@ class RedshiftResponse(BaseResponse): "region_name": self.region, "tags": self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')), "iam_roles_arn": self._get_iam_roles(), + "enhanced_vpc_routing": self._get_param('EnhancedVpcRouting'), } cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json() cluster['ClusterStatus'] = 'creating' @@ -150,6 +151,7 @@ class RedshiftResponse(BaseResponse): }) def restore_from_cluster_snapshot(self): + enhanced_vpc_routing = self._get_bool_param('EnhancedVpcRouting') restore_kwargs = { "snapshot_identifier": self._get_param('SnapshotIdentifier'), "cluster_identifier": self._get_param('ClusterIdentifier'), @@ -171,6 +173,8 @@ class RedshiftResponse(BaseResponse): "region_name": self.region, "iam_roles_arn": self._get_iam_roles(), } + if enhanced_vpc_routing is not None: + restore_kwargs['enhanced_vpc_routing'] = enhanced_vpc_routing cluster = self.redshift_backend.restore_from_cluster_snapshot(**restore_kwargs).to_json() cluster['ClusterStatus'] = 'creating' return self.get_response({ @@ -218,6 +222,7 @@ class RedshiftResponse(BaseResponse): "publicly_accessible": self._get_param("PubliclyAccessible"), "encrypted": self._get_param("Encrypted"), "iam_roles_arn": self._get_iam_roles(), + "enhanced_vpc_routing": self._get_param("EnhancedVpcRouting") } cluster_kwargs = {} # We only want parameters that were actually passed in, otherwise diff --git a/tests/test_redshift/test_redshift.py b/tests/test_redshift/test_redshift.py index 2c9b42a1d..79e283e5b 100644 --- a/tests/test_redshift/test_redshift.py +++ b/tests/test_redshift/test_redshift.py @@ -37,6 +37,25 @@ def test_create_cluster_boto3(): create_time = response['Cluster']['ClusterCreateTime'] create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo)) create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1)) + response['Cluster']['EnhancedVpcRouting'].should.equal(False) + +@mock_redshift +def test_create_cluster_boto3(): + client = boto3.client('redshift', region_name='us-east-1') + response = client.create_cluster( + DBName='test', + ClusterIdentifier='test', + ClusterType='single-node', + NodeType='ds2.xlarge', + MasterUsername='user', + MasterUserPassword='password', + EnhancedVpcRouting=True + ) + response['Cluster']['NodeType'].should.equal('ds2.xlarge') + create_time = response['Cluster']['ClusterCreateTime'] + create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo)) + create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1)) + response['Cluster']['EnhancedVpcRouting'].should.equal(True) @mock_redshift @@ -425,6 +444,58 @@ def test_delete_cluster(): "not-a-cluster").should.throw(ClusterNotFound) +@mock_redshift +def test_modify_cluster_vpc_routing(): + iam_roles_arn = ['arn:aws:iam:::role/my-iam-role', ] + client = boto3.client('redshift', region_name='us-east-1') + cluster_identifier = 'my_cluster' + + client.create_cluster( + ClusterIdentifier=cluster_identifier, + NodeType="single-node", + MasterUsername="username", + MasterUserPassword="password", + IamRoles=iam_roles_arn + ) + + cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier) + cluster = cluster_response['Clusters'][0] + cluster['EnhancedVpcRouting'].should.equal(False) + + client.create_cluster_security_group(ClusterSecurityGroupName='security_group', + Description='security_group') + + client.create_cluster_parameter_group(ParameterGroupName='my_parameter_group', + ParameterGroupFamily='redshift-1.0', + Description='my_parameter_group') + + client.modify_cluster( + ClusterIdentifier=cluster_identifier, + ClusterType='multi-node', + NodeType="ds2.8xlarge", + NumberOfNodes=3, + ClusterSecurityGroups=["security_group"], + MasterUserPassword="new_password", + ClusterParameterGroupName="my_parameter_group", + AutomatedSnapshotRetentionPeriod=7, + PreferredMaintenanceWindow="Tue:03:00-Tue:11:00", + AllowVersionUpgrade=False, + NewClusterIdentifier=cluster_identifier, + EnhancedVpcRouting=True + ) + + cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier) + cluster = cluster_response['Clusters'][0] + cluster['ClusterIdentifier'].should.equal(cluster_identifier) + cluster['NodeType'].should.equal("ds2.8xlarge") + cluster['PreferredMaintenanceWindow'].should.equal("Tue:03:00-Tue:11:00") + cluster['AutomatedSnapshotRetentionPeriod'].should.equal(7) + cluster['AllowVersionUpgrade'].should.equal(False) + # This one should remain unmodified. + cluster['NumberOfNodes'].should.equal(3) + cluster['EnhancedVpcRouting'].should.equal(True) + + @mock_redshift_deprecated def test_modify_cluster(): conn = boto.connect_redshift() @@ -446,6 +517,10 @@ def test_modify_cluster(): master_user_password="password", ) + cluster_response = conn.describe_clusters(cluster_identifier) + cluster = cluster_response['DescribeClustersResponse']['DescribeClustersResult']['Clusters'][0] + cluster['EnhancedVpcRouting'].should.equal(False) + conn.modify_cluster( cluster_identifier, cluster_type="multi-node", @@ -456,14 +531,13 @@ def test_modify_cluster(): automated_snapshot_retention_period=7, preferred_maintenance_window="Tue:03:00-Tue:11:00", allow_version_upgrade=False, - new_cluster_identifier="new_identifier", + new_cluster_identifier=cluster_identifier, ) - cluster_response = conn.describe_clusters("new_identifier") + cluster_response = conn.describe_clusters(cluster_identifier) cluster = cluster_response['DescribeClustersResponse'][ 'DescribeClustersResult']['Clusters'][0] - - cluster['ClusterIdentifier'].should.equal("new_identifier") + cluster['ClusterIdentifier'].should.equal(cluster_identifier) cluster['NodeType'].should.equal("dw.hs1.xlarge") cluster['ClusterSecurityGroups'][0][ 'ClusterSecurityGroupName'].should.equal("security_group") @@ -674,6 +748,7 @@ def test_create_cluster_snapshot(): NodeType='ds2.xlarge', MasterUsername='username', MasterUserPassword='password', + EnhancedVpcRouting=True ) cluster_response['Cluster']['NodeType'].should.equal('ds2.xlarge') @@ -823,11 +898,14 @@ def test_create_cluster_from_snapshot(): NodeType='ds2.xlarge', MasterUsername='username', MasterUserPassword='password', + EnhancedVpcRouting=True, ) + client.create_cluster_snapshot( SnapshotIdentifier=original_snapshot_identifier, ClusterIdentifier=original_cluster_identifier ) + response = client.restore_from_cluster_snapshot( ClusterIdentifier=new_cluster_identifier, SnapshotIdentifier=original_snapshot_identifier, @@ -842,7 +920,7 @@ def test_create_cluster_from_snapshot(): new_cluster['NodeType'].should.equal('ds2.xlarge') new_cluster['MasterUsername'].should.equal('username') new_cluster['Endpoint']['Port'].should.equal(1234) - + new_cluster['EnhancedVpcRouting'].should.equal(True) @mock_redshift def test_create_cluster_from_snapshot_with_waiter(): @@ -857,6 +935,7 @@ def test_create_cluster_from_snapshot_with_waiter(): NodeType='ds2.xlarge', MasterUsername='username', MasterUserPassword='password', + EnhancedVpcRouting=True ) client.create_cluster_snapshot( SnapshotIdentifier=original_snapshot_identifier, @@ -883,6 +962,7 @@ def test_create_cluster_from_snapshot_with_waiter(): new_cluster = response['Clusters'][0] new_cluster['NodeType'].should.equal('ds2.xlarge') new_cluster['MasterUsername'].should.equal('username') + new_cluster['EnhancedVpcRouting'].should.equal(True) new_cluster['Endpoint']['Port'].should.equal(1234) From 9ac20ad5f1ea4d8d4fb02dadf441ada451653b6e Mon Sep 17 00:00:00 2001 From: gruebel Date: Sat, 31 Aug 2019 19:21:06 +0200 Subject: [PATCH 31/67] store SQS RedrivePolicy maxReceiveCount value as int --- moto/sqs/models.py | 3 +++ tests/test_sqs/test_sqs.py | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index f2e3ed400..6779bc2b5 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -263,6 +263,9 @@ class Queue(BaseModel): if 'maxReceiveCount' not in self.redrive_policy: raise RESTError('InvalidParameterValue', 'Redrive policy does not contain maxReceiveCount') + # 'maxReceiveCount' is stored as int + self.redrive_policy['maxReceiveCount'] = int(self.redrive_policy['maxReceiveCount']) + for queue in sqs_backends[self.region].queues.values(): if queue.queue_arn == self.redrive_policy['deadLetterTargetArn']: self.dead_letter_queue = queue diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index d53ae50f7..56d82ea61 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -1117,6 +1117,28 @@ def test_redrive_policy_set_attributes(): assert copy_policy == redrive_policy +@mock_sqs +def test_redrive_policy_set_attributes_with_string_value(): + sqs = boto3.resource('sqs', region_name='us-east-1') + + queue = sqs.create_queue(QueueName='test-queue') + deadletter_queue = sqs.create_queue(QueueName='test-deadletter') + + queue.set_attributes(Attributes={ + 'RedrivePolicy': json.dumps({ + 'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'], + 'maxReceiveCount': '1', + })}) + + copy = sqs.get_queue_by_name(QueueName='test-queue') + assert 'RedrivePolicy' in copy.attributes + copy_policy = json.loads(copy.attributes['RedrivePolicy']) + assert copy_policy == { + 'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'], + 'maxReceiveCount': 1, + } + + @mock_sqs def test_receive_messages_with_message_group_id(): sqs = boto3.resource('sqs', region_name='us-east-1') From 0c11daf62345dec8c7d23b16effc5c2f54864a49 Mon Sep 17 00:00:00 2001 From: Wessel van der Veen Date: Sun, 1 Sep 2019 17:38:33 +0200 Subject: [PATCH 32/67] adds test cases, and fixes formatting. --- moto/cognitoidentity/models.py | 12 ++--- .../test_cognitoidentity.py | 50 +++++++++++++++++-- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/moto/cognitoidentity/models.py b/moto/cognitoidentity/models.py index 7193b551f..6f752ab69 100644 --- a/moto/cognitoidentity/models.py +++ b/moto/cognitoidentity/models.py @@ -63,12 +63,12 @@ class CognitoIdentityBackend(BaseBackend): supported_login_providers, developer_provider_name, open_id_connect_provider_arns, cognito_identity_providers, saml_provider_arns): new_identity = CognitoIdentity(self.region, identity_pool_name, - allow_unauthenticated_identities=allow_unauthenticated_identities, - supported_login_providers=supported_login_providers, - developer_provider_name=developer_provider_name, - open_id_connect_provider_arns=open_id_connect_provider_arns, - cognito_identity_providers=cognito_identity_providers, - saml_provider_arns=saml_provider_arns) + allow_unauthenticated_identities=allow_unauthenticated_identities, + supported_login_providers=supported_login_providers, + developer_provider_name=developer_provider_name, + open_id_connect_provider_arns=open_id_connect_provider_arns, + cognito_identity_providers=cognito_identity_providers, + saml_provider_arns=saml_provider_arns) self.identity_pools[new_identity.identity_pool_id] = new_identity response = json.dumps({ diff --git a/tests/test_cognitoidentity/test_cognitoidentity.py b/tests/test_cognitoidentity/test_cognitoidentity.py index ea9ccbc78..67679e896 100644 --- a/tests/test_cognitoidentity/test_cognitoidentity.py +++ b/tests/test_cognitoidentity/test_cognitoidentity.py @@ -1,10 +1,10 @@ from __future__ import unicode_literals import boto3 +from botocore.exceptions import ClientError +from nose.tools import assert_raises from moto import mock_cognitoidentity -import sure # noqa - from moto.cognitoidentity.utils import get_random_identity_id @@ -28,6 +28,47 @@ def test_create_identity_pool(): assert result['IdentityPoolId'] != '' +@mock_cognitoidentity +def test_describe_identity_pool(): + conn = boto3.client('cognito-identity', 'us-west-2') + + res = conn.create_identity_pool(IdentityPoolName='TestPool', + AllowUnauthenticatedIdentities=False, + SupportedLoginProviders={'graph.facebook.com': '123456789012345'}, + DeveloperProviderName='devname', + OpenIdConnectProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db'], + CognitoIdentityProviders=[ + { + 'ProviderName': 'testprovider', + 'ClientId': 'CLIENT12345', + 'ServerSideTokenCheck': True + }, + ], + SamlProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db']) + + result = conn.describe_identity_pool(IdentityPoolId=res['IdentityPoolId']) + + assert result['IdentityPoolId'] == res['IdentityPoolId'] + assert result['AllowUnauthenticatedIdentities'] == res['AllowUnauthenticatedIdentities'] + assert result['SupportedLoginProviders'] == res['SupportedLoginProviders'] + assert result['DeveloperProviderName'] == res['DeveloperProviderName'] + assert result['OpenIdConnectProviderARNs'] == res['OpenIdConnectProviderARNs'] + assert result['CognitoIdentityProviders'] == res['CognitoIdentityProviders'] + assert result['SamlProviderARNs'] == res['SamlProviderARNs'] + + +@mock_cognitoidentity +def test_describe_identity_pool_with_invalid_id_raises_error(): + conn = boto3.client('cognito-identity', 'us-west-2') + + with assert_raises(ClientError) as cm: + conn.describe_identity_pool(IdentityPoolId='us-west-2_non-existent') + + cm.exception.operation_name.should.equal('DescribeIdentityPool') + cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') + cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + # testing a helper function def test_get_random_identity_id(): assert len(get_random_identity_id('us-west-2')) > 0 @@ -44,7 +85,8 @@ def test_get_id(): 'someurl': '12345' }) print(result) - assert result.get('IdentityId', "").startswith('us-west-2') or result.get('ResponseMetadata').get('HTTPStatusCode') == 200 + assert result.get('IdentityId', "").startswith('us-west-2') or result.get('ResponseMetadata').get( + 'HTTPStatusCode') == 200 @mock_cognitoidentity @@ -71,6 +113,7 @@ def test_get_open_id_token_for_developer_identity(): assert len(result['Token']) > 0 assert result['IdentityId'] == '12345' + @mock_cognitoidentity def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id(): conn = boto3.client('cognito-identity', 'us-west-2') @@ -84,6 +127,7 @@ def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id() assert len(result['Token']) > 0 assert len(result['IdentityId']) > 0 + @mock_cognitoidentity def test_get_open_id_token(): conn = boto3.client('cognito-identity', 'us-west-2') From a4c79c19abeaad655255d53a73f57fe526395510 Mon Sep 17 00:00:00 2001 From: Wessel van der Veen Date: Mon, 2 Sep 2019 12:37:23 +0200 Subject: [PATCH 33/67] forgot to update implementation coverage. --- IMPLEMENTATION_COVERAGE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index d149b0dd8..bca6fef63 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -1237,7 +1237,7 @@ - [ ] delete_identities - [ ] delete_identity_pool - [ ] describe_identity -- [ ] describe_identity_pool +- [X] describe_identity_pool - [X] get_credentials_for_identity - [X] get_id - [ ] get_identity_pool_roles From af4082f38edbd78d0f56a7e8bd38340f46a96e74 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Mon, 2 Sep 2019 16:26:40 +0100 Subject: [PATCH 34/67] Step Functions - State Machines methods --- IMPLEMENTATION_COVERAGE.md | 10 +- docs/index.rst | 2 + moto/__init__.py | 1 + moto/backends.py | 2 + moto/stepfunctions/__init__.py | 6 + moto/stepfunctions/exceptions.py | 35 +++ moto/stepfunctions/models.py | 121 ++++++++ moto/stepfunctions/responses.py | 80 +++++ moto/stepfunctions/urls.py | 10 + .../test_stepfunctions/test_stepfunctions.py | 276 ++++++++++++++++++ 10 files changed, 538 insertions(+), 5 deletions(-) create mode 100644 moto/stepfunctions/__init__.py create mode 100644 moto/stepfunctions/exceptions.py create mode 100644 moto/stepfunctions/models.py create mode 100644 moto/stepfunctions/responses.py create mode 100644 moto/stepfunctions/urls.py create mode 100644 tests/test_stepfunctions/test_stepfunctions.py diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index d149b0dd8..7a839fb96 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -6050,19 +6050,19 @@ ## stepfunctions 0% implemented - [ ] create_activity -- [ ] create_state_machine +- [X] create_state_machine - [ ] delete_activity -- [ ] delete_state_machine +- [X] delete_state_machine - [ ] describe_activity - [ ] describe_execution -- [ ] describe_state_machine +- [X] describe_state_machine - [ ] describe_state_machine_for_execution - [ ] get_activity_task - [ ] get_execution_history - [ ] list_activities - [ ] list_executions -- [ ] list_state_machines -- [ ] list_tags_for_resource +- [X] list_state_machines +- [X] list_tags_for_resource - [ ] send_task_failure - [ ] send_task_heartbeat - [ ] send_task_success diff --git a/docs/index.rst b/docs/index.rst index 4811fb797..6311597fe 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -94,6 +94,8 @@ Currently implemented Services: +---------------------------+-----------------------+------------------------------------+ | SES | @mock_ses | all endpoints done | +---------------------------+-----------------------+------------------------------------+ +| SFN | @mock_stepfunctions | basic endpoints done | ++---------------------------+-----------------------+------------------------------------+ | SNS | @mock_sns | all endpoints done | +---------------------------+-----------------------+------------------------------------+ | SQS | @mock_sqs | core endpoints done | diff --git a/moto/__init__.py b/moto/__init__.py index 8594cedd2..f82a411cf 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -42,6 +42,7 @@ from .ses import mock_ses, mock_ses_deprecated # flake8: noqa from .secretsmanager import mock_secretsmanager # flake8: noqa from .sns import mock_sns, mock_sns_deprecated # flake8: noqa from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa +from .stepfunctions import mock_stepfunctions # flake8: noqa from .sts import mock_sts, mock_sts_deprecated # flake8: noqa from .ssm import mock_ssm # flake8: noqa from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa diff --git a/moto/backends.py b/moto/backends.py index 6ea85093d..8a20697c2 100644 --- a/moto/backends.py +++ b/moto/backends.py @@ -40,6 +40,7 @@ from moto.secretsmanager import secretsmanager_backends from moto.sns import sns_backends from moto.sqs import sqs_backends from moto.ssm import ssm_backends +from moto.stepfunctions import stepfunction_backends from moto.sts import sts_backends from moto.swf import swf_backends from moto.xray import xray_backends @@ -91,6 +92,7 @@ BACKENDS = { 'sns': sns_backends, 'sqs': sqs_backends, 'ssm': ssm_backends, + 'stepfunctions': stepfunction_backends, 'sts': sts_backends, 'swf': swf_backends, 'route53': route53_backends, diff --git a/moto/stepfunctions/__init__.py b/moto/stepfunctions/__init__.py new file mode 100644 index 000000000..dc2b0ba13 --- /dev/null +++ b/moto/stepfunctions/__init__.py @@ -0,0 +1,6 @@ +from __future__ import unicode_literals +from .models import stepfunction_backends +from ..core.models import base_decorator + +stepfunction_backend = stepfunction_backends['us-east-1'] +mock_stepfunctions = base_decorator(stepfunction_backends) diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py new file mode 100644 index 000000000..a7c0897a5 --- /dev/null +++ b/moto/stepfunctions/exceptions.py @@ -0,0 +1,35 @@ +from __future__ import unicode_literals +import json + + +class AWSError(Exception): + CODE = None + STATUS = 400 + + def __init__(self, message, code=None, status=None): + self.message = message + self.code = code if code is not None else self.CODE + self.status = status if status is not None else self.STATUS + + def response(self): + return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) + + +class AccessDeniedException(AWSError): + CODE = 'AccessDeniedException' + STATUS = 400 + + +class InvalidArn(AWSError): + CODE = 'InvalidArn' + STATUS = 400 + + +class InvalidName(AWSError): + CODE = 'InvalidName' + STATUS = 400 + + +class StateMachineDoesNotExist(AWSError): + CODE = 'StateMachineDoesNotExist' + STATUS = 400 diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py new file mode 100644 index 000000000..8571fbe9b --- /dev/null +++ b/moto/stepfunctions/models.py @@ -0,0 +1,121 @@ +import boto +import boto3 +import re +from datetime import datetime +from moto.core import BaseBackend +from moto.core.utils import iso_8601_datetime_without_milliseconds +from .exceptions import AccessDeniedException, InvalidArn, InvalidName, StateMachineDoesNotExist + + +class StateMachine(): + def __init__(self, arn, name, definition, roleArn, tags=None): + self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.arn = arn + self.name = name + self.definition = definition + self.roleArn = roleArn + self.tags = tags + + +class StepFunctionBackend(BaseBackend): + + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.create_state_machine + # A name must not contain: + # whitespace + # brackets < > { } [ ] + # wildcard characters ? * + # special characters " # % \ ^ | ~ ` $ & , ; : / + invalid_chars_for_name = [' ', '{', '}', '[', ']', '<', '>', + '?', '*', + '"', '#', '%', '\\', '^', '|', '~', '`', '$', '&', ',', ';', ':', '/'] + # control characters (U+0000-001F , U+007F-009F ) + invalid_unicodes_for_name = [u'\u0000', u'\u0001', u'\u0002', u'\u0003', u'\u0004', + u'\u0005', u'\u0006', u'\u0007', u'\u0008', u'\u0009', + u'\u000A', u'\u000B', u'\u000C', u'\u000D', u'\u000E', u'\u000F', + u'\u0010', u'\u0011', u'\u0012', u'\u0013', u'\u0014', + u'\u0015', u'\u0016', u'\u0017', u'\u0018', u'\u0019', + u'\u001A', u'\u001B', u'\u001C', u'\u001D', u'\u001E', u'\u001F', + u'\u007F', + u'\u0080', u'\u0081', u'\u0082', u'\u0083', u'\u0084', u'\u0085', + u'\u0086', u'\u0087', u'\u0088', u'\u0089', + u'\u008A', u'\u008B', u'\u008C', u'\u008D', u'\u008E', u'\u008F', + u'\u0090', u'\u0091', u'\u0092', u'\u0093', u'\u0094', u'\u0095', + u'\u0096', u'\u0097', u'\u0098', u'\u0099', + u'\u009A', u'\u009B', u'\u009C', u'\u009D', u'\u009E', u'\u009F'] + accepted_role_arn_format = re.compile('arn:aws:iam:(?P[0-9]{12}):role/.+') + accepted_mchn_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):stateMachine:.+') + + def __init__(self, region_name): + self.state_machines = [] + self.region_name = region_name + self._account_id = None + + def create_state_machine(self, name, definition, roleArn, tags=None): + self._validate_name(name) + self._validate_role_arn(roleArn) + arn = 'arn:aws:states:' + self.region_name + ':' + str(self._get_account_id()) + ':stateMachine:' + name + try: + return self.describe_state_machine(arn) + except StateMachineDoesNotExist: + state_machine = StateMachine(arn, name, definition, roleArn, tags) + self.state_machines.append(state_machine) + return state_machine + + def list_state_machines(self): + return self.state_machines + + def describe_state_machine(self, arn): + self._validate_machine_arn(arn) + sm = next((x for x in self.state_machines if x.arn == arn), None) + if not sm: + raise StateMachineDoesNotExist("State Machine Does Not Exist: '" + arn + "'") + return sm + + def delete_state_machine(self, arn): + self._validate_machine_arn(arn) + sm = next((x for x in self.state_machines if x.arn == arn), None) + if sm: + self.state_machines.remove(sm) + + def reset(self): + region_name = self.region_name + self.__dict__ = {} + self.__init__(region_name) + + def _validate_name(self, name): + if any(invalid_char in name for invalid_char in self.invalid_chars_for_name): + raise InvalidName("Invalid Name: '" + name + "'") + + if any(name.find(char) >= 0 for char in self.invalid_unicodes_for_name): + raise InvalidName("Invalid Name: '" + name + "'") + + def _validate_role_arn(self, role_arn): + self._validate_arn(arn=role_arn, + regex=self.accepted_role_arn_format, + invalid_msg="Invalid Role Arn: '" + role_arn + "'", + access_denied_msg='Cross-account pass role is not allowed.') + + def _validate_machine_arn(self, machine_arn): + self._validate_arn(arn=machine_arn, + regex=self.accepted_mchn_arn_format, + invalid_msg="Invalid Role Arn: '" + machine_arn + "'", + access_denied_msg='User moto is not authorized to access this resource') + + def _validate_arn(self, arn, regex, invalid_msg, access_denied_msg): + match = regex.match(arn) + if not arn or not match: + raise InvalidArn(invalid_msg) + + if self._get_account_id() != match.group('account_id'): + raise AccessDeniedException(access_denied_msg) + + def _get_account_id(self): + if self._account_id: + return self._account_id + sts = boto3.client("sts") + identity = sts.get_caller_identity() + self._account_id = identity['Account'] + return self._account_id + + +stepfunction_backends = {_region.name: StepFunctionBackend(_region.name) for _region in boto.awslambda.regions()} diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py new file mode 100644 index 000000000..d729a5a38 --- /dev/null +++ b/moto/stepfunctions/responses.py @@ -0,0 +1,80 @@ +from __future__ import unicode_literals + +import json + +from moto.core.responses import BaseResponse +from moto.core.utils import amzn_request_id +from .exceptions import AWSError +from .models import stepfunction_backends + + +class StepFunctionResponse(BaseResponse): + + @property + def stepfunction_backend(self): + return stepfunction_backends[self.region] + + @amzn_request_id + def create_state_machine(self): + name = self._get_param('name') + definition = self._get_param('definition') + roleArn = self._get_param('roleArn') + tags = self._get_param('tags') + try: + state_machine = self.stepfunction_backend.create_state_machine(name=name, definition=definition, + roleArn=roleArn, + tags=tags) + response = { + 'creationDate': state_machine.creation_date, + 'stateMachineArn': state_machine.arn + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def list_state_machines(self): + list_all = self.stepfunction_backend.list_state_machines() + list_all = sorted([{'creationDate': sm.creation_date, + 'name': sm.name, + 'stateMachineArn': sm.arn} for sm in list_all], + key=lambda x: x['name']) + response = {'stateMachines': list_all} + return 200, {}, json.dumps(response) + + @amzn_request_id + def describe_state_machine(self): + arn = self._get_param('stateMachineArn') + try: + state_machine = self.stepfunction_backend.describe_state_machine(arn) + response = { + 'creationDate': state_machine.creation_date, + 'stateMachineArn': state_machine.arn, + 'definition': state_machine.definition, + 'name': state_machine.name, + 'roleArn': state_machine.roleArn, + 'status': 'ACTIVE' + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def delete_state_machine(self): + arn = self._get_param('stateMachineArn') + try: + self.stepfunction_backend.delete_state_machine(arn) + return 200, {}, json.dumps('{}') + except AWSError as err: + return err.response() + + @amzn_request_id + def list_tags_for_resource(self): + arn = self._get_param('resourceArn') + try: + state_machine = self.stepfunction_backend.describe_state_machine(arn) + tags = state_machine.tags or [] + except AWSError: + tags = [] + response = {'tags': tags} + return 200, {}, json.dumps(response) diff --git a/moto/stepfunctions/urls.py b/moto/stepfunctions/urls.py new file mode 100644 index 000000000..f8d5fb1e8 --- /dev/null +++ b/moto/stepfunctions/urls.py @@ -0,0 +1,10 @@ +from __future__ import unicode_literals +from .responses import StepFunctionResponse + +url_bases = [ + "https?://states.(.+).amazonaws.com", +] + +url_paths = { + '{0}/$': StepFunctionResponse.dispatch, +} diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py new file mode 100644 index 000000000..0b9df50a9 --- /dev/null +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -0,0 +1,276 @@ +from __future__ import unicode_literals + +import boto3 +import sure # noqa +import datetime + +from datetime import datetime +from botocore.exceptions import ClientError +from moto.config.models import DEFAULT_ACCOUNT_ID +from nose.tools import assert_raises + +from moto import mock_sts, mock_stepfunctions + + +region = 'us-east-1' +simple_definition = '{"Comment": "An example of the Amazon States Language using a choice state.",' \ + '"StartAt": "DefaultState",' \ + '"States": ' \ + '{"DefaultState": {"Type": "Fail","Error": "DefaultStateError","Cause": "No Matches!"}}}' +default_stepfunction_role = 'arn:aws:iam:' + str(DEFAULT_ACCOUNT_ID) + ':role/unknown_sf_role' + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_succeeds(): + client = boto3.client('stepfunctions', region_name=region) + name = 'example_step_function' + # + response = client.create_state_machine(name=name, + definition=str(simple_definition), + roleArn=default_stepfunction_role) + # + response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response['creationDate'].should.be.a(datetime) + response['stateMachineArn'].should.equal('arn:aws:states:' + region + ':123456789012:stateMachine:' + name) + + +@mock_stepfunctions +def test_state_machine_creation_fails_with_invalid_names(): + client = boto3.client('stepfunctions', region_name=region) + invalid_names = [ + 'with space', + 'withbracket', 'with{bracket', 'with}bracket', 'with[bracket', 'with]bracket', + 'with?wildcard', 'with*wildcard', + 'special"char', 'special#char', 'special%char', 'special\\char', 'special^char', 'special|char', + 'special~char', 'special`char', 'special$char', 'special&char', 'special,char', 'special;char', + 'special:char', 'special/char', + u'uni\u0000code', u'uni\u0001code', u'uni\u0002code', u'uni\u0003code', u'uni\u0004code', + u'uni\u0005code', u'uni\u0006code', u'uni\u0007code', u'uni\u0008code', u'uni\u0009code', + u'uni\u000Acode', u'uni\u000Bcode', u'uni\u000Ccode', + u'uni\u000Dcode', u'uni\u000Ecode', u'uni\u000Fcode', + u'uni\u0010code', u'uni\u0011code', u'uni\u0012code', u'uni\u0013code', u'uni\u0014code', + u'uni\u0015code', u'uni\u0016code', u'uni\u0017code', u'uni\u0018code', u'uni\u0019code', + u'uni\u001Acode', u'uni\u001Bcode', u'uni\u001Ccode', + u'uni\u001Dcode', u'uni\u001Ecode', u'uni\u001Fcode', + u'uni\u007Fcode', + u'uni\u0080code', u'uni\u0081code', u'uni\u0082code', u'uni\u0083code', u'uni\u0084code', + u'uni\u0085code', u'uni\u0086code', u'uni\u0087code', u'uni\u0088code', u'uni\u0089code', + u'uni\u008Acode', u'uni\u008Bcode', u'uni\u008Ccode', + u'uni\u008Dcode', u'uni\u008Ecode', u'uni\u008Fcode', + u'uni\u0090code', u'uni\u0091code', u'uni\u0092code', u'uni\u0093code', u'uni\u0094code', + u'uni\u0095code', u'uni\u0096code', u'uni\u0097code', u'uni\u0098code', u'uni\u0099code', + u'uni\u009Acode', u'uni\u009Bcode', u'uni\u009Ccode', + u'uni\u009Dcode', u'uni\u009Ecode', u'uni\u009Fcode'] + # + + for invalid_name in invalid_names: + with assert_raises(ClientError) as exc: + client.create_state_machine(name=invalid_name, + definition=str(simple_definition), + roleArn=default_stepfunction_role) + exc.exception.response['Error']['Code'].should.equal('InvalidName') + exc.exception.response['Error']['Message'].should.equal("Invalid Name: '" + invalid_name + "'") + exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + +@mock_stepfunctions +def test_state_machine_creation_requires_valid_role_arn(): + client = boto3.client('stepfunctions', region_name=region) + name = 'example_step_function' + # + with assert_raises(ClientError) as exc: + client.create_state_machine(name=name, + definition=str(simple_definition), + roleArn='arn:aws:iam:1234:role/unknown_role') + exc.exception.response['Error']['Code'].should.equal('InvalidArn') + exc.exception.response['Error']['Message'].should.equal("Invalid Role Arn: 'arn:aws:iam:1234:role/unknown_role'") + exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_requires_role_in_same_account(): + client = boto3.client('stepfunctions', region_name=region) + name = 'example_step_function' + # + with assert_raises(ClientError) as exc: + client.create_state_machine(name=name, + definition=str(simple_definition), + roleArn='arn:aws:iam:000000000000:role/unknown_role') + exc.exception.response['Error']['Code'].should.equal('AccessDeniedException') + exc.exception.response['Error']['Message'].should.equal('Cross-account pass role is not allowed.') + exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + +@mock_stepfunctions +def test_state_machine_list_returns_empty_list_by_default(): + client = boto3.client('stepfunctions', region_name=region) + # + list = client.list_state_machines() + list['stateMachines'].should.be.empty + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_returns_created_state_machines(): + client = boto3.client('stepfunctions', region_name=region) + # + machine2 = client.create_state_machine(name='name2', + definition=str(simple_definition), + roleArn=default_stepfunction_role) + machine1 = client.create_state_machine(name='name1', + definition=str(simple_definition), + roleArn=default_stepfunction_role, + tags=[{'key': 'tag_key', 'value': 'tag_value'}]) + list = client.list_state_machines() + # + list['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + list['stateMachines'].should.have.length_of(2) + list['stateMachines'][0]['creationDate'].should.be.a(datetime) + list['stateMachines'][0]['creationDate'].should.equal(machine1['creationDate']) + list['stateMachines'][0]['name'].should.equal('name1') + list['stateMachines'][0]['stateMachineArn'].should.equal(machine1['stateMachineArn']) + list['stateMachines'][1]['creationDate'].should.be.a(datetime) + list['stateMachines'][1]['creationDate'].should.equal(machine2['creationDate']) + list['stateMachines'][1]['name'].should.equal('name2') + list['stateMachines'][1]['stateMachineArn'].should.equal(machine2['stateMachineArn']) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_is_idempotent_by_name(): + client = boto3.client('stepfunctions', region_name=region) + # + client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(1) + # + client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(1) + # + client.create_state_machine(name='diff_name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(2) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_can_be_described_by_name(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + desc = client.describe_state_machine(stateMachineArn=sm['stateMachineArn']) + desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + desc['creationDate'].should.equal(sm['creationDate']) + desc['definition'].should.equal(str(simple_definition)) + desc['name'].should.equal('name') + desc['roleArn'].should.equal(default_stepfunction_role) + desc['stateMachineArn'].should.equal(sm['stateMachineArn']) + desc['status'].should.equal('ACTIVE') + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_unknown_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_state_machine = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':stateMachine:unknown' + client.describe_state_machine(stateMachineArn=unknown_state_machine) + exc.exception.response['Error']['Code'].should.equal('StateMachineDoesNotExist') + exc.exception.response['Error']['Message'].\ + should.equal("State Machine Does Not Exist: '" + unknown_state_machine + "'") + exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_machine_in_different_account(): + client = boto3.client('stepfunctions', region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_state_machine = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown' + client.describe_state_machine(stateMachineArn=unknown_state_machine) + exc.exception.response['Error']['Code'].should.equal('AccessDeniedException') + exc.exception.response['Error']['Message'].should.contain('is not authorized to access this resource') + exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_can_be_deleted(): + client = boto3.client('stepfunctions', region_name=region) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + # + response = client.delete_state_machine(stateMachineArn=sm['stateMachineArn']) + response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + # + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_can_deleted_nonexisting_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + unknown_state_machine = 'arn:aws:states:' + region + ':123456789012:stateMachine:unknown' + response = client.delete_state_machine(stateMachineArn=unknown_state_machine) + response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + # + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_deletion_validates_arn(): + client = boto3.client('stepfunctions', region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_account_id = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown' + client.delete_state_machine(stateMachineArn=unknown_account_id) + exc.exception.response['Error']['Code'].should.equal('AccessDeniedException') + exc.exception.response['Error']['Message'].should.contain('is not authorized to access this resource') + exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_tags_for_created_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + machine = client.create_state_machine(name='name1', + definition=str(simple_definition), + roleArn=default_stepfunction_role, + tags=[{'key': 'tag_key', 'value': 'tag_value'}]) + response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) + tags = response['tags'] + tags.should.have.length_of(1) + tags[0].should.equal({'key': 'tag_key', 'value': 'tag_value'}) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_tags_for_machine_without_tags(): + client = boto3.client('stepfunctions', region_name=region) + # + machine = client.create_state_machine(name='name1', + definition=str(simple_definition), + roleArn=default_stepfunction_role) + response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) + tags = response['tags'] + tags.should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_tags_for_nonexisting_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + non_existing_state_machine = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':stateMachine:unknown' + response = client.list_tags_for_resource(resourceArn=non_existing_state_machine) + tags = response['tags'] + tags.should.have.length_of(0) From 78254cc4f2bf5774d37d79edf9255a35d9e23c49 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 4 Sep 2019 15:42:42 +0100 Subject: [PATCH 35/67] Step Functions - Execution methods --- IMPLEMENTATION_COVERAGE.md | 10 +- moto/stepfunctions/exceptions.py | 5 + moto/stepfunctions/models.py | 50 ++++++- moto/stepfunctions/responses.py | 60 +++++++- .../test_stepfunctions/test_stepfunctions.py | 138 +++++++++++++++++- 5 files changed, 255 insertions(+), 8 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index 7a839fb96..fd5ad3f1e 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -6054,20 +6054,20 @@ - [ ] delete_activity - [X] delete_state_machine - [ ] describe_activity -- [ ] describe_execution +- [X] describe_execution - [X] describe_state_machine -- [ ] describe_state_machine_for_execution +- [x] describe_state_machine_for_execution - [ ] get_activity_task - [ ] get_execution_history - [ ] list_activities -- [ ] list_executions +- [X] list_executions - [X] list_state_machines - [X] list_tags_for_resource - [ ] send_task_failure - [ ] send_task_heartbeat - [ ] send_task_success -- [ ] start_execution -- [ ] stop_execution +- [X] start_execution +- [X] stop_execution - [ ] tag_resource - [ ] untag_resource - [ ] update_state_machine diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py index a7c0897a5..133d0cc83 100644 --- a/moto/stepfunctions/exceptions.py +++ b/moto/stepfunctions/exceptions.py @@ -20,6 +20,11 @@ class AccessDeniedException(AWSError): STATUS = 400 +class ExecutionDoesNotExist(AWSError): + CODE = 'ExecutionDoesNotExist' + STATUS = 400 + + class InvalidArn(AWSError): CODE = 'InvalidArn' STATUS = 400 diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index 8571fbe9b..fd272624f 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -4,7 +4,8 @@ import re from datetime import datetime from moto.core import BaseBackend from moto.core.utils import iso_8601_datetime_without_milliseconds -from .exceptions import AccessDeniedException, InvalidArn, InvalidName, StateMachineDoesNotExist +from uuid import uuid4 +from .exceptions import AccessDeniedException, ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist class StateMachine(): @@ -17,6 +18,22 @@ class StateMachine(): self.tags = tags +class Execution(): + def __init__(self, region_name, account_id, state_machine_name, execution_name, state_machine_arn): + execution_arn = 'arn:aws:states:{}:{}:execution:{}:{}' + execution_arn = execution_arn.format(region_name, account_id, state_machine_name, execution_name) + self.execution_arn = execution_arn + self.name = execution_name + self.start_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.state_machine_arn = state_machine_arn + self.status = 'RUNNING' + self.stop_date = None + + def stop(self): + self.status = 'SUCCEEDED' + self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now()) + + class StepFunctionBackend(BaseBackend): # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.create_state_machine @@ -44,9 +61,11 @@ class StepFunctionBackend(BaseBackend): u'\u009A', u'\u009B', u'\u009C', u'\u009D', u'\u009E', u'\u009F'] accepted_role_arn_format = re.compile('arn:aws:iam:(?P[0-9]{12}):role/.+') accepted_mchn_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):stateMachine:.+') + accepted_exec_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):execution:.+') def __init__(self, region_name): self.state_machines = [] + self.executions = [] self.region_name = region_name self._account_id = None @@ -77,6 +96,29 @@ class StepFunctionBackend(BaseBackend): if sm: self.state_machines.remove(sm) + def start_execution(self, state_machine_arn): + state_machine_name = self.describe_state_machine(state_machine_arn).name + execution = Execution(region_name=self.region_name, account_id=self._get_account_id(), state_machine_name=state_machine_name, execution_name=str(uuid4()), state_machine_arn=state_machine_arn) + self.executions.append(execution) + return execution + + def stop_execution(self, execution_arn): + execution = next((x for x in self.executions if x.execution_arn == execution_arn), None) + if not execution: + raise ExecutionDoesNotExist("Execution Does Not Exist: '" + execution_arn + "'") + execution.stop() + return execution + + def list_executions(self, state_machine_arn): + return [execution for execution in self.executions if execution.state_machine_arn == state_machine_arn] + + def describe_execution(self, arn): + self._validate_execution_arn(arn) + exctn = next((x for x in self.executions if x.execution_arn == arn), None) + if not exctn: + raise ExecutionDoesNotExist("Execution Does Not Exist: '" + arn + "'") + return exctn + def reset(self): region_name = self.region_name self.__dict__ = {} @@ -101,6 +143,12 @@ class StepFunctionBackend(BaseBackend): invalid_msg="Invalid Role Arn: '" + machine_arn + "'", access_denied_msg='User moto is not authorized to access this resource') + def _validate_execution_arn(self, execution_arn): + self._validate_arn(arn=execution_arn, + regex=self.accepted_exec_arn_format, + invalid_msg="Execution Does Not Exist: '" + execution_arn + "'", + access_denied_msg='User moto is not authorized to access this resource') + def _validate_arn(self, arn, regex, invalid_msg, access_denied_msg): match = regex.match(arn) if not arn or not match: diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py index d729a5a38..0a170aa57 100644 --- a/moto/stepfunctions/responses.py +++ b/moto/stepfunctions/responses.py @@ -45,8 +45,12 @@ class StepFunctionResponse(BaseResponse): @amzn_request_id def describe_state_machine(self): arn = self._get_param('stateMachineArn') + return self._describe_state_machine(arn) + + @amzn_request_id + def _describe_state_machine(self, state_machine_arn): try: - state_machine = self.stepfunction_backend.describe_state_machine(arn) + state_machine = self.stepfunction_backend.describe_state_machine(state_machine_arn) response = { 'creationDate': state_machine.creation_date, 'stateMachineArn': state_machine.arn, @@ -78,3 +82,57 @@ class StepFunctionResponse(BaseResponse): tags = [] response = {'tags': tags} return 200, {}, json.dumps(response) + + @amzn_request_id + def start_execution(self): + arn = self._get_param('stateMachineArn') + execution = self.stepfunction_backend.start_execution(arn) + response = {'executionArn': execution.execution_arn, + 'startDate': execution.start_date} + return 200, {}, json.dumps(response) + + @amzn_request_id + def list_executions(self): + arn = self._get_param('stateMachineArn') + state_machine = self.stepfunction_backend.describe_state_machine(arn) + executions = self.stepfunction_backend.list_executions(arn) + executions = [{'executionArn': execution.execution_arn, + 'name': execution.name, + 'startDate': execution.start_date, + 'stateMachineArn': state_machine.arn, + 'status': execution.status} for execution in executions] + return 200, {}, json.dumps({'executions': executions}) + + @amzn_request_id + def describe_execution(self): + arn = self._get_param('executionArn') + try: + execution = self.stepfunction_backend.describe_execution(arn) + response = { + 'executionArn': arn, + 'input': '{}', + 'name': execution.name, + 'startDate': execution.start_date, + 'stateMachineArn': execution.state_machine_arn, + 'status': execution.status, + 'stopDate': execution.stop_date + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_state_machine_for_execution(self): + arn = self._get_param('executionArn') + try: + execution = self.stepfunction_backend.describe_execution(arn) + return self._describe_state_machine(execution.state_machine_arn) + except AWSError as err: + return err.response() + + @amzn_request_id + def stop_execution(self): + arn = self._get_param('executionArn') + execution = self.stepfunction_backend.stop_execution(arn) + response = {'stopDate': execution.stop_date} + return 200, {}, json.dumps(response) diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py index 0b9df50a9..bf5c92570 100644 --- a/tests/test_stepfunctions/test_stepfunctions.py +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -157,7 +157,7 @@ def test_state_machine_creation_is_idempotent_by_name(): @mock_stepfunctions @mock_sts -def test_state_machine_creation_can_be_described_by_name(): +def test_state_machine_creation_can_be_described(): client = boto3.client('stepfunctions', region_name=region) # sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) @@ -274,3 +274,139 @@ def test_state_machine_list_tags_for_nonexisting_machine(): response = client.list_tags_for_resource(resourceArn=non_existing_state_machine) tags = response['tags'] tags.should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_start_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + # + execution['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + expected_exec_name = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':execution:name:[a-zA-Z0-9-]+' + execution['executionArn'].should.match(expected_exec_name) + execution['startDate'].should.be.a(datetime) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_executions(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + execution_arn = execution['executionArn'] + execution_name = execution_arn[execution_arn.rindex(':')+1:] + executions = client.list_executions(stateMachineArn=sm['stateMachineArn']) + # + executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + executions['executions'].should.have.length_of(1) + executions['executions'][0]['executionArn'].should.equal(execution_arn) + executions['executions'][0]['name'].should.equal(execution_name) + executions['executions'][0]['startDate'].should.equal(execution['startDate']) + executions['executions'][0]['stateMachineArn'].should.equal(sm['stateMachineArn']) + executions['executions'][0]['status'].should.equal('RUNNING') + executions['executions'][0].shouldnt.have('stopDate') + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_executions_when_none_exist(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + executions = client.list_executions(stateMachineArn=sm['stateMachineArn']) + # + executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + executions['executions'].should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_describe_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + description = client.describe_execution(executionArn=execution['executionArn']) + # + description['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + description['executionArn'].should.equal(execution['executionArn']) + description['input'].should.equal("{}") + description['name'].shouldnt.be.empty + description['startDate'].should.equal(execution['startDate']) + description['stateMachineArn'].should.equal(sm['stateMachineArn']) + description['status'].should.equal('RUNNING') + description.shouldnt.have('stopDate') + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_unknown_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_execution = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':execution:unknown' + client.describe_execution(executionArn=unknown_execution) + exc.exception.response['Error']['Code'].should.equal('ExecutionDoesNotExist') + exc.exception.response['Error']['Message'].should.equal("Execution Does Not Exist: '" + unknown_execution + "'") + exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_can_be_described_by_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + desc = client.describe_state_machine_for_execution(executionArn=execution['executionArn']) + desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + desc['definition'].should.equal(str(simple_definition)) + desc['name'].should.equal('name') + desc['roleArn'].should.equal(default_stepfunction_role) + desc['stateMachineArn'].should.equal(sm['stateMachineArn']) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_unknown_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_execution = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':execution:unknown' + client.describe_state_machine_for_execution(executionArn=unknown_execution) + exc.exception.response['Error']['Code'].should.equal('ExecutionDoesNotExist') + exc.exception.response['Error']['Message'].should.equal("Execution Does Not Exist: '" + unknown_execution + "'") + exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_stop_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + start = client.start_execution(stateMachineArn=sm['stateMachineArn']) + stop = client.stop_execution(executionArn=start['executionArn']) + print(stop) + # + stop['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + stop['stopDate'].should.be.a(datetime) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_describe_execution_after_stoppage(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + client.stop_execution(executionArn=execution['executionArn']) + description = client.describe_execution(executionArn=execution['executionArn']) + # + description['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + description['status'].should.equal('SUCCEEDED') + description['stopDate'].should.be.a(datetime) From 6a1a8df7ccd172f67308b99f6ccf7b1d2d4d1f6d Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sat, 7 Sep 2019 16:37:55 +0100 Subject: [PATCH 36/67] Step Functions - Simplify tests --- moto/stepfunctions/exceptions.py | 21 ++-- moto/stepfunctions/models.py | 22 ++-- .../test_stepfunctions/test_stepfunctions.py | 116 +++++++----------- 3 files changed, 59 insertions(+), 100 deletions(-) diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py index 133d0cc83..8af4686c7 100644 --- a/moto/stepfunctions/exceptions.py +++ b/moto/stepfunctions/exceptions.py @@ -3,38 +3,33 @@ import json class AWSError(Exception): - CODE = None + TYPE = None STATUS = 400 - def __init__(self, message, code=None, status=None): + def __init__(self, message, type=None, status=None): self.message = message - self.code = code if code is not None else self.CODE + self.type = type if type is not None else self.TYPE self.status = status if status is not None else self.STATUS def response(self): - return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) - - -class AccessDeniedException(AWSError): - CODE = 'AccessDeniedException' - STATUS = 400 + return json.dumps({'__type': self.type, 'message': self.message}), dict(status=self.status) class ExecutionDoesNotExist(AWSError): - CODE = 'ExecutionDoesNotExist' + TYPE = 'ExecutionDoesNotExist' STATUS = 400 class InvalidArn(AWSError): - CODE = 'InvalidArn' + TYPE = 'InvalidArn' STATUS = 400 class InvalidName(AWSError): - CODE = 'InvalidName' + TYPE = 'InvalidName' STATUS = 400 class StateMachineDoesNotExist(AWSError): - CODE = 'StateMachineDoesNotExist' + TYPE = 'StateMachineDoesNotExist' STATUS = 400 diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index fd272624f..8db9db1a1 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -5,7 +5,7 @@ from datetime import datetime from moto.core import BaseBackend from moto.core.utils import iso_8601_datetime_without_milliseconds from uuid import uuid4 -from .exceptions import AccessDeniedException, ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist +from .exceptions import ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist class StateMachine(): @@ -98,7 +98,11 @@ class StepFunctionBackend(BaseBackend): def start_execution(self, state_machine_arn): state_machine_name = self.describe_state_machine(state_machine_arn).name - execution = Execution(region_name=self.region_name, account_id=self._get_account_id(), state_machine_name=state_machine_name, execution_name=str(uuid4()), state_machine_arn=state_machine_arn) + execution = Execution(region_name=self.region_name, + account_id=self._get_account_id(), + state_machine_name=state_machine_name, + execution_name=str(uuid4()), + state_machine_arn=state_machine_arn) self.executions.append(execution) return execution @@ -134,29 +138,23 @@ class StepFunctionBackend(BaseBackend): def _validate_role_arn(self, role_arn): self._validate_arn(arn=role_arn, regex=self.accepted_role_arn_format, - invalid_msg="Invalid Role Arn: '" + role_arn + "'", - access_denied_msg='Cross-account pass role is not allowed.') + invalid_msg="Invalid Role Arn: '" + role_arn + "'") def _validate_machine_arn(self, machine_arn): self._validate_arn(arn=machine_arn, regex=self.accepted_mchn_arn_format, - invalid_msg="Invalid Role Arn: '" + machine_arn + "'", - access_denied_msg='User moto is not authorized to access this resource') + invalid_msg="Invalid Role Arn: '" + machine_arn + "'") def _validate_execution_arn(self, execution_arn): self._validate_arn(arn=execution_arn, regex=self.accepted_exec_arn_format, - invalid_msg="Execution Does Not Exist: '" + execution_arn + "'", - access_denied_msg='User moto is not authorized to access this resource') + invalid_msg="Execution Does Not Exist: '" + execution_arn + "'") - def _validate_arn(self, arn, regex, invalid_msg, access_denied_msg): + def _validate_arn(self, arn, regex, invalid_msg): match = regex.match(arn) if not arn or not match: raise InvalidArn(invalid_msg) - if self._get_account_id() != match.group('account_id'): - raise AccessDeniedException(access_denied_msg) - def _get_account_id(self): if self._account_id: return self._account_id diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py index bf5c92570..10953ce2d 100644 --- a/tests/test_stepfunctions/test_stepfunctions.py +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -6,7 +6,6 @@ import datetime from datetime import datetime from botocore.exceptions import ClientError -from moto.config.models import DEFAULT_ACCOUNT_ID from nose.tools import assert_raises from moto import mock_sts, mock_stepfunctions @@ -17,7 +16,7 @@ simple_definition = '{"Comment": "An example of the Amazon States Language using '"StartAt": "DefaultState",' \ '"States": ' \ '{"DefaultState": {"Type": "Fail","Error": "DefaultStateError","Cause": "No Matches!"}}}' -default_stepfunction_role = 'arn:aws:iam:' + str(DEFAULT_ACCOUNT_ID) + ':role/unknown_sf_role' +account_id = None @mock_stepfunctions @@ -28,7 +27,7 @@ def test_state_machine_creation_succeeds(): # response = client.create_state_machine(name=name, definition=str(simple_definition), - roleArn=default_stepfunction_role) + roleArn=_get_default_role()) # response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) response['creationDate'].should.be.a(datetime) @@ -68,10 +67,7 @@ def test_state_machine_creation_fails_with_invalid_names(): with assert_raises(ClientError) as exc: client.create_state_machine(name=invalid_name, definition=str(simple_definition), - roleArn=default_stepfunction_role) - exc.exception.response['Error']['Code'].should.equal('InvalidName') - exc.exception.response['Error']['Message'].should.equal("Invalid Name: '" + invalid_name + "'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + roleArn=_get_default_role()) @mock_stepfunctions @@ -83,24 +79,6 @@ def test_state_machine_creation_requires_valid_role_arn(): client.create_state_machine(name=name, definition=str(simple_definition), roleArn='arn:aws:iam:1234:role/unknown_role') - exc.exception.response['Error']['Code'].should.equal('InvalidArn') - exc.exception.response['Error']['Message'].should.equal("Invalid Role Arn: 'arn:aws:iam:1234:role/unknown_role'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - - -@mock_stepfunctions -@mock_sts -def test_state_machine_creation_requires_role_in_same_account(): - client = boto3.client('stepfunctions', region_name=region) - name = 'example_step_function' - # - with assert_raises(ClientError) as exc: - client.create_state_machine(name=name, - definition=str(simple_definition), - roleArn='arn:aws:iam:000000000000:role/unknown_role') - exc.exception.response['Error']['Code'].should.equal('AccessDeniedException') - exc.exception.response['Error']['Message'].should.equal('Cross-account pass role is not allowed.') - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @@ -118,10 +96,10 @@ def test_state_machine_list_returns_created_state_machines(): # machine2 = client.create_state_machine(name='name2', definition=str(simple_definition), - roleArn=default_stepfunction_role) + roleArn=_get_default_role()) machine1 = client.create_state_machine(name='name1', definition=str(simple_definition), - roleArn=default_stepfunction_role, + roleArn=_get_default_role(), tags=[{'key': 'tag_key', 'value': 'tag_value'}]) list = client.list_state_machines() # @@ -142,15 +120,15 @@ def test_state_machine_list_returns_created_state_machines(): def test_state_machine_creation_is_idempotent_by_name(): client = boto3.client('stepfunctions', region_name=region) # - client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) sm_list = client.list_state_machines() sm_list['stateMachines'].should.have.length_of(1) # - client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) sm_list = client.list_state_machines() sm_list['stateMachines'].should.have.length_of(1) # - client.create_state_machine(name='diff_name', definition=str(simple_definition), roleArn=default_stepfunction_role) + client.create_state_machine(name='diff_name', definition=str(simple_definition), roleArn=_get_default_role()) sm_list = client.list_state_machines() sm_list['stateMachines'].should.have.length_of(2) @@ -160,13 +138,13 @@ def test_state_machine_creation_is_idempotent_by_name(): def test_state_machine_creation_can_be_described(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) desc = client.describe_state_machine(stateMachineArn=sm['stateMachineArn']) desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) desc['creationDate'].should.equal(sm['creationDate']) desc['definition'].should.equal(str(simple_definition)) desc['name'].should.equal('name') - desc['roleArn'].should.equal(default_stepfunction_role) + desc['roleArn'].should.equal(_get_default_role()) desc['stateMachineArn'].should.equal(sm['stateMachineArn']) desc['status'].should.equal('ACTIVE') @@ -177,12 +155,8 @@ def test_state_machine_throws_error_when_describing_unknown_machine(): client = boto3.client('stepfunctions', region_name=region) # with assert_raises(ClientError) as exc: - unknown_state_machine = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':stateMachine:unknown' + unknown_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown' client.describe_state_machine(stateMachineArn=unknown_state_machine) - exc.exception.response['Error']['Code'].should.equal('StateMachineDoesNotExist') - exc.exception.response['Error']['Message'].\ - should.equal("State Machine Does Not Exist: '" + unknown_state_machine + "'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @@ -193,16 +167,13 @@ def test_state_machine_throws_error_when_describing_machine_in_different_account with assert_raises(ClientError) as exc: unknown_state_machine = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown' client.describe_state_machine(stateMachineArn=unknown_state_machine) - exc.exception.response['Error']['Code'].should.equal('AccessDeniedException') - exc.exception.response['Error']['Message'].should.contain('is not authorized to access this resource') - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @mock_sts def test_state_machine_can_be_deleted(): client = boto3.client('stepfunctions', region_name=region) - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) # response = client.delete_state_machine(stateMachineArn=sm['stateMachineArn']) response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) @@ -224,19 +195,6 @@ def test_state_machine_can_deleted_nonexisting_machine(): sm_list['stateMachines'].should.have.length_of(0) -@mock_stepfunctions -@mock_sts -def test_state_machine_deletion_validates_arn(): - client = boto3.client('stepfunctions', region_name=region) - # - with assert_raises(ClientError) as exc: - unknown_account_id = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown' - client.delete_state_machine(stateMachineArn=unknown_account_id) - exc.exception.response['Error']['Code'].should.equal('AccessDeniedException') - exc.exception.response['Error']['Message'].should.contain('is not authorized to access this resource') - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - - @mock_stepfunctions @mock_sts def test_state_machine_list_tags_for_created_machine(): @@ -244,7 +202,7 @@ def test_state_machine_list_tags_for_created_machine(): # machine = client.create_state_machine(name='name1', definition=str(simple_definition), - roleArn=default_stepfunction_role, + roleArn=_get_default_role(), tags=[{'key': 'tag_key', 'value': 'tag_value'}]) response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) tags = response['tags'] @@ -259,7 +217,7 @@ def test_state_machine_list_tags_for_machine_without_tags(): # machine = client.create_state_machine(name='name1', definition=str(simple_definition), - roleArn=default_stepfunction_role) + roleArn=_get_default_role()) response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) tags = response['tags'] tags.should.have.length_of(0) @@ -270,7 +228,7 @@ def test_state_machine_list_tags_for_machine_without_tags(): def test_state_machine_list_tags_for_nonexisting_machine(): client = boto3.client('stepfunctions', region_name=region) # - non_existing_state_machine = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':stateMachine:unknown' + non_existing_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown' response = client.list_tags_for_resource(resourceArn=non_existing_state_machine) tags = response['tags'] tags.should.have.length_of(0) @@ -281,11 +239,11 @@ def test_state_machine_list_tags_for_nonexisting_machine(): def test_state_machine_start_execution(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) # execution['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - expected_exec_name = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':execution:name:[a-zA-Z0-9-]+' + expected_exec_name = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:name:[a-zA-Z0-9-]+' execution['executionArn'].should.match(expected_exec_name) execution['startDate'].should.be.a(datetime) @@ -295,7 +253,7 @@ def test_state_machine_start_execution(): def test_state_machine_list_executions(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) execution_arn = execution['executionArn'] execution_name = execution_arn[execution_arn.rindex(':')+1:] @@ -316,7 +274,7 @@ def test_state_machine_list_executions(): def test_state_machine_list_executions_when_none_exist(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) executions = client.list_executions(stateMachineArn=sm['stateMachineArn']) # executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200) @@ -328,7 +286,7 @@ def test_state_machine_list_executions_when_none_exist(): def test_state_machine_describe_execution(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) description = client.describe_execution(executionArn=execution['executionArn']) # @@ -348,11 +306,8 @@ def test_state_machine_throws_error_when_describing_unknown_machine(): client = boto3.client('stepfunctions', region_name=region) # with assert_raises(ClientError) as exc: - unknown_execution = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':execution:unknown' + unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown' client.describe_execution(executionArn=unknown_execution) - exc.exception.response['Error']['Code'].should.equal('ExecutionDoesNotExist') - exc.exception.response['Error']['Message'].should.equal("Execution Does Not Exist: '" + unknown_execution + "'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @@ -360,13 +315,13 @@ def test_state_machine_throws_error_when_describing_unknown_machine(): def test_state_machine_can_be_described_by_execution(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) desc = client.describe_state_machine_for_execution(executionArn=execution['executionArn']) desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) desc['definition'].should.equal(str(simple_definition)) desc['name'].should.equal('name') - desc['roleArn'].should.equal(default_stepfunction_role) + desc['roleArn'].should.equal(_get_default_role()) desc['stateMachineArn'].should.equal(sm['stateMachineArn']) @@ -376,11 +331,8 @@ def test_state_machine_throws_error_when_describing_unknown_execution(): client = boto3.client('stepfunctions', region_name=region) # with assert_raises(ClientError) as exc: - unknown_execution = 'arn:aws:states:' + region + ':' + str(DEFAULT_ACCOUNT_ID) + ':execution:unknown' + unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown' client.describe_state_machine_for_execution(executionArn=unknown_execution) - exc.exception.response['Error']['Code'].should.equal('ExecutionDoesNotExist') - exc.exception.response['Error']['Message'].should.equal("Execution Does Not Exist: '" + unknown_execution + "'") - exc.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) @mock_stepfunctions @@ -388,10 +340,9 @@ def test_state_machine_throws_error_when_describing_unknown_execution(): def test_state_machine_stop_execution(): client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) start = client.start_execution(stateMachineArn=sm['stateMachineArn']) stop = client.stop_execution(executionArn=start['executionArn']) - print(stop) # stop['ResponseMetadata']['HTTPStatusCode'].should.equal(200) stop['stopDate'].should.be.a(datetime) @@ -400,9 +351,10 @@ def test_state_machine_stop_execution(): @mock_stepfunctions @mock_sts def test_state_machine_describe_execution_after_stoppage(): + account_id client = boto3.client('stepfunctions', region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=default_stepfunction_role) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) client.stop_execution(executionArn=execution['executionArn']) description = client.describe_execution(executionArn=execution['executionArn']) @@ -410,3 +362,17 @@ def test_state_machine_describe_execution_after_stoppage(): description['ResponseMetadata']['HTTPStatusCode'].should.equal(200) description['status'].should.equal('SUCCEEDED') description['stopDate'].should.be.a(datetime) + + +def _get_account_id(): + global account_id + if account_id: + return account_id + sts = boto3.client("sts") + identity = sts.get_caller_identity() + account_id = identity['Account'] + return account_id + + +def _get_default_role(): + return 'arn:aws:iam:' + _get_account_id() + ':role/unknown_sf_role' From f4df7a48eea6c8706dcaecd2553d907b51b4b5cc Mon Sep 17 00:00:00 2001 From: Julian Graham Date: Mon, 9 Sep 2019 19:08:16 -0400 Subject: [PATCH 37/67] Prevent overlapping expr name prefixes from corrupting projection expr h/t @beheh. This patch handles the case when ProjectionExpression looks like "#1, ..., #10" - the previous code used `replace`, which would make the resulting projection into "foo, ..., foo0". --- moto/dynamodb2/responses.py | 11 ++++++- tests/test_dynamodb2/test_dynamodb.py | 47 +++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 3e9fbb553..15c1130f8 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -356,9 +356,18 @@ class DynamoHandler(BaseResponse): if projection_expression and expression_attribute_names: expressions = [x.strip() for x in projection_expression.split(',')] + projection_expression = None for expression in expressions: + if projection_expression is not None: + projection_expression = projection_expression + ", " + else: + projection_expression = "" + if expression in expression_attribute_names: - projection_expression = projection_expression.replace(expression, expression_attribute_names[expression]) + projection_expression = projection_expression + \ + expression_attribute_names[expression] + else: + projection_expression = projection_expression + expression filter_kwargs = {} diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index fb6c0e17d..1044f0d50 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -973,6 +973,53 @@ def test_query_filter(): assert response['Count'] == 2 +@mock_dynamodb2 +def test_query_filter_overlapping_expression_prefixes(): + client = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + + # Create the DynamoDB table. + client.create_table( + TableName='test1', + AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], + KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], + ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + ) + + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'nested': {'M': { + 'version': {'S': 'version1'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, + }) + + table = dynamodb.Table('test1') + response = table.query( + KeyConditionExpression=Key('client').eq('client1') & Key('app').eq('app1'), + ProjectionExpression='#1, #10, nested', + ExpressionAttributeNames={ + '#1': 'client', + '#10': 'app', + } + ) + + assert response['Count'] == 1 + assert response['Items'][0] == { + 'client': 'client1', + 'app': 'app1', + 'nested': { + 'version': 'version1', + 'contents': ['value1', 'value2'] + } + } + + @mock_dynamodb2 def test_scan_filter(): client = boto3.client('dynamodb', region_name='us-east-1') From 0d4d2b70415c4c2ab03ef6a35bfc8946db6f8ff2 Mon Sep 17 00:00:00 2001 From: William Harvey Date: Tue, 10 Sep 2019 14:24:00 -0400 Subject: [PATCH 38/67] Fix/tighten AWS Batch test_reregister_task_definition unit test --- tests/test_batch/test_batch.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py index 89a8d4d0e..5487cfb91 100644 --- a/tests/test_batch/test_batch.py +++ b/tests/test_batch/test_batch.py @@ -563,6 +563,38 @@ def test_reregister_task_definition(): resp2['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) + resp3 = batch_client.register_job_definition( + jobDefinitionName='sleep10', + type='container', + containerProperties={ + 'image': 'busybox', + 'vcpus': 1, + 'memory': 42, + 'command': ['sleep', '10'] + } + ) + resp3['revision'].should.equal(3) + + resp3['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) + resp3['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn']) + + resp4 = batch_client.register_job_definition( + jobDefinitionName='sleep10', + type='container', + containerProperties={ + 'image': 'busybox', + 'vcpus': 1, + 'memory': 41, + 'command': ['sleep', '10'] + } + ) + resp4['revision'].should.equal(4) + + resp4['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) + resp4['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn']) + resp4['jobDefinitionArn'].should_not.equal(resp3['jobDefinitionArn']) + + @mock_ec2 @mock_ecs From 21933052d3f797b7b82eae3d5a77d7804c17241d Mon Sep 17 00:00:00 2001 From: Bruno Oliveira Date: Tue, 10 Sep 2019 23:43:50 -0300 Subject: [PATCH 39/67] Fix multiple IAM Policy Statement creation with empty sid --- moto/iam/policy_validation.py | 6 ++++-- tests/test_iam/test_iam_policies.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/moto/iam/policy_validation.py b/moto/iam/policy_validation.py index 6ee286072..d9a4b0282 100644 --- a/moto/iam/policy_validation.py +++ b/moto/iam/policy_validation.py @@ -152,8 +152,10 @@ class IAMPolicyDocumentValidator: sids = [] for statement in self._statements: if "Sid" in statement: - assert statement["Sid"] not in sids - sids.append(statement["Sid"]) + statementId = statement["Sid"] + if statementId: + assert statementId not in sids + sids.append(statementId) def _validate_statements_syntax(self): assert "Statement" in self._policy_json diff --git a/tests/test_iam/test_iam_policies.py b/tests/test_iam/test_iam_policies.py index e1924a559..adb8bd990 100644 --- a/tests/test_iam/test_iam_policies.py +++ b/tests/test_iam/test_iam_policies.py @@ -1827,6 +1827,23 @@ valid_policy_documents = [ "Resource": ["*"] } ] + }, + { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "", + "Effect": "Allow", + "Action": "rds:*", + "Resource": ["arn:aws:rds:region:*:*"] + }, + { + "Sid": "", + "Effect": "Allow", + "Action": ["rds:Describe*"], + "Resource": ["*"] + } + ] } ] From efe676dbd51d1b91d1208d25216b7279a3f7bfd1 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Wed, 11 Sep 2019 22:07:24 -0500 Subject: [PATCH 40/67] Add comment. --- moto/core/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/moto/core/models.py b/moto/core/models.py index 7ac8adba1..63287608d 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -198,6 +198,7 @@ class CallbackResponse(responses.CallbackResponse): botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send') responses_mock = responses._default_mock # Add passthrough to allow any other requests to work +# Since this uses .startswith, it applies to http and https requests. responses_mock.add_passthru("http") From eea67543d154d050ea791f2c0889971eea7260ab Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Thu, 12 Sep 2019 17:54:02 +0800 Subject: [PATCH 41/67] MaxKeys limits the sum of folders and keys --- moto/s3/responses.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index ee047a14f..a192cf511 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -458,10 +458,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: result_folders = self._get_results_from_token(result_folders, limit) - if not delimiter: - result_keys, is_truncated, next_continuation_token = self._truncate_result(result_keys, max_keys) - else: - result_folders, is_truncated, next_continuation_token = self._truncate_result(result_folders, max_keys) + tagged_keys = [(key, True) for key in result_keys] + tagged_folders = [(folder, False) for folder in result_folders] + all_keys = tagged_keys + tagged_folders + all_keys.sort() + result_keys, result_folders, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys) key_count = len(result_keys) + len(result_folders) @@ -487,16 +488,19 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): continuation_index += 1 return result_keys[continuation_index:] - def _truncate_result(self, result_keys, max_keys): - if len(result_keys) > max_keys: + def _truncate_result(self, all_keys, max_keys): + if len(all_keys) > max_keys: is_truncated = 'true' - result_keys = result_keys[:max_keys] - item = result_keys[-1] + all_keys = all_keys[:max_keys] + item = all_keys[-1][0] next_continuation_token = (item.name if isinstance(item, FakeKey) else item) else: is_truncated = 'false' next_continuation_token = None - return result_keys, is_truncated, next_continuation_token + result_keys, result_folders = [], [] + for (key, is_key) in all_keys: + (result_keys if is_key else result_folders).append(key) + return result_keys, result_folders, is_truncated, next_continuation_token def _bucket_response_put(self, request, body, region_name, bucket_name, querystring): if not request.headers.get('Content-Length'): From d6ef01b9fdfb1521c2adb1cf92ede0933e1648d2 Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Thu, 12 Sep 2019 18:40:07 +0800 Subject: [PATCH 42/67] lint --- moto/s3/responses.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index a192cf511..f4640023e 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -2,6 +2,8 @@ from __future__ import unicode_literals import re +from itertools import chain + import six from moto.core.utils import str_to_rfc_1123_datetime @@ -458,11 +460,10 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: result_folders = self._get_results_from_token(result_folders, limit) - tagged_keys = [(key, True) for key in result_keys] - tagged_folders = [(folder, False) for folder in result_folders] - all_keys = tagged_keys + tagged_folders - all_keys.sort() - result_keys, result_folders, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys) + tagged_keys = ((key, True) for key in result_keys) + tagged_folders = ((folder, False) for folder in result_folders) + sorted_keys = sorted(chain(tagged_keys, tagged_folders)) + result_keys, result_folders, is_truncated, next_continuation_token = self._truncate_result(sorted_keys, max_keys) key_count = len(result_keys) + len(result_folders) @@ -488,17 +489,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): continuation_index += 1 return result_keys[continuation_index:] - def _truncate_result(self, all_keys, max_keys): - if len(all_keys) > max_keys: + def _truncate_result(self, sorted_keys, max_keys): + if len(sorted_keys) > max_keys: is_truncated = 'true' - all_keys = all_keys[:max_keys] - item = all_keys[-1][0] + sorted_keys = sorted_keys[:max_keys] + item = sorted_keys[-1][0] next_continuation_token = (item.name if isinstance(item, FakeKey) else item) else: is_truncated = 'false' next_continuation_token = None result_keys, result_folders = [], [] - for (key, is_key) in all_keys: + for (key, is_key) in sorted_keys: (result_keys if is_key else result_folders).append(key) return result_keys, result_folders, is_truncated, next_continuation_token From 2983a63c0d2cefa01de2c7e9038eddab69078179 Mon Sep 17 00:00:00 2001 From: Jessie Nadler Date: Thu, 12 Sep 2019 11:24:47 -0400 Subject: [PATCH 43/67] Allow fixed-response action type for elbv2 --- moto/elbv2/exceptions.py | 2 +- moto/elbv2/models.py | 14 +++- tests/test_elbv2/test_elbv2.py | 129 +++++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 5 deletions(-) diff --git a/moto/elbv2/exceptions.py b/moto/elbv2/exceptions.py index 11dcbcb21..f67db5143 100644 --- a/moto/elbv2/exceptions.py +++ b/moto/elbv2/exceptions.py @@ -131,7 +131,7 @@ class InvalidActionTypeError(ELBClientError): def __init__(self, invalid_name, index): super(InvalidActionTypeError, self).__init__( "ValidationError", - "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect]" % (invalid_name, index) + "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect, fixed-response]" % (invalid_name, index) ) diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 726799fe5..28c91c3e5 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -6,7 +6,7 @@ from jinja2 import Template from moto.compat import OrderedDict from moto.core.exceptions import RESTError from moto.core import BaseBackend, BaseModel -from moto.core.utils import camelcase_to_underscores +from moto.core.utils import camelcase_to_underscores, underscores_to_camelcase from moto.ec2.models import ec2_backends from moto.acm.models import acm_backends from .utils import make_arn_for_target_group @@ -220,9 +220,9 @@ class FakeListener(BaseModel): action_type = action['Type'] if action_type == 'forward': default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']}) - elif action_type in ['redirect', 'authenticate-cognito']: + elif action_type in ['redirect', 'authenticate-cognito', 'fixed-response']: redirect_action = {'type': action_type} - key = 'RedirectConfig' if action_type == 'redirect' else 'AuthenticateCognitoConfig' + key = underscores_to_camelcase(action_type.capitalize().replace('-', '_')) + 'Config' for redirect_config_key, redirect_config_value in action[key].items(): # need to match the output of _get_list_prefix redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value @@ -258,6 +258,12 @@ class FakeAction(BaseModel): {{ action.data["authenticate_cognito_config._user_pool_client_id"] }} {{ action.data["authenticate_cognito_config._user_pool_domain"] }} + {% elif action.type == "fixed-response" %} + + {{ action.data["fixed_response_config._content_type"] }} + {{ action.data["fixed_response_config._message_body"] }} + {{ action.data["fixed_response_config._status_code"] }} + {% endif %} """) return template.render(action=self) @@ -482,7 +488,7 @@ class ELBv2Backend(BaseBackend): action_target_group_arn = action.data['target_group_arn'] if action_target_group_arn not in target_group_arns: raise ActionTargetGroupNotFoundError(action_target_group_arn) - elif action_type in ['redirect', 'authenticate-cognito']: + elif action_type in ['redirect', 'authenticate-cognito', 'fixed-response']: pass else: raise InvalidActionTypeError(action_type, index) diff --git a/tests/test_elbv2/test_elbv2.py b/tests/test_elbv2/test_elbv2.py index b2512a3f1..538a2d911 100644 --- a/tests/test_elbv2/test_elbv2.py +++ b/tests/test_elbv2/test_elbv2.py @@ -2017,3 +2017,132 @@ def test_cognito_action_listener_rule_cloudformation(): 'UserPoolDomain': 'testpool', } },]) + + +@mock_elbv2 +@mock_ec2 +def test_fixed_response_action_listener_rule(): + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.128/26', + AvailabilityZone='us-east-1b') + + response = conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + + action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '404', + } + } + response = conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[action]) + + listener = response.get('Listeners')[0] + listener.get('DefaultActions')[0].should.equal(action) + listener_arn = listener.get('ListenerArn') + + describe_rules_response = conn.describe_rules(ListenerArn=listener_arn) + describe_rules_response['Rules'][0]['Actions'][0].should.equal(action) + + describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ]) + describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'][0] + describe_listener_actions.should.equal(action) + + +@mock_elbv2 +@mock_cloudformation +def test_fixed_response_action_listener_rule_cloudformation(): + cnf_conn = boto3.client('cloudformation', region_name='us-east-1') + elbv2_client = boto3.client('elbv2', region_name='us-east-1') + + template = { + "AWSTemplateFormatVersion": "2010-09-09", + "Description": "ECS Cluster Test CloudFormation", + "Resources": { + "testVPC": { + "Type": "AWS::EC2::VPC", + "Properties": { + "CidrBlock": "10.0.0.0/16", + }, + }, + "subnet1": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "CidrBlock": "10.0.0.0/24", + "VpcId": {"Ref": "testVPC"}, + "AvalabilityZone": "us-east-1b", + }, + }, + "subnet2": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "CidrBlock": "10.0.1.0/24", + "VpcId": {"Ref": "testVPC"}, + "AvalabilityZone": "us-east-1b", + }, + }, + "testLb": { + "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", + "Properties": { + "Name": "my-lb", + "Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}], + "Type": "application", + "SecurityGroups": [], + } + }, + "testListener": { + "Type": "AWS::ElasticLoadBalancingV2::Listener", + "Properties": { + "LoadBalancerArn": {"Ref": "testLb"}, + "Port": 80, + "Protocol": "HTTP", + "DefaultActions": [{ + "Type": "fixed-response", + "FixedResponseConfig": { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '404', + } + }] + } + + } + } + } + template_json = json.dumps(template) + cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json) + + describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',]) + load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn'] + describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn) + + describe_listeners_response['Listeners'].should.have.length_of(1) + describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{ + 'Type': 'fixed-response', + "FixedResponseConfig": { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '404', + } + },]) From 2b8bdc9bca85d97852e0bd9b4fddccd009554d35 Mon Sep 17 00:00:00 2001 From: Jessie Nadler Date: Thu, 12 Sep 2019 12:29:03 -0400 Subject: [PATCH 44/67] Validate elbv2 FixedResponseConfig attributes --- moto/elbv2/exceptions.py | 15 ++++ moto/elbv2/models.py | 26 +++++- tests/test_elbv2/test_elbv2.py | 149 ++++++++++++++++++++++++++++++++- 3 files changed, 186 insertions(+), 4 deletions(-) diff --git a/moto/elbv2/exceptions.py b/moto/elbv2/exceptions.py index f67db5143..ccbfd06dd 100644 --- a/moto/elbv2/exceptions.py +++ b/moto/elbv2/exceptions.py @@ -190,3 +190,18 @@ class InvalidModifyRuleArgumentsError(ELBClientError): "ValidationError", "Either conditions or actions must be specified" ) + + +class InvalidStatusCodeActionTypeError(ELBClientError): + def __init__(self, msg): + super(InvalidStatusCodeActionTypeError, self).__init__( + "ValidationError", msg + ) + + +class InvalidLoadBalancerActionException(ELBClientError): + + def __init__(self, msg): + super(InvalidLoadBalancerActionException, self).__init__( + "InvalidLoadBalancerAction", msg + ) diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 28c91c3e5..636cc56a1 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals import datetime import re from jinja2 import Template +from botocore.exceptions import ParamValidationError from moto.compat import OrderedDict from moto.core.exceptions import RESTError from moto.core import BaseBackend, BaseModel @@ -31,8 +32,8 @@ from .exceptions import ( RuleNotFoundError, DuplicatePriorityError, InvalidTargetGroupNameError, - InvalidModifyRuleArgumentsError -) + InvalidModifyRuleArgumentsError, + InvalidStatusCodeActionTypeError, InvalidLoadBalancerActionException) class FakeHealthStatus(BaseModel): @@ -488,11 +489,30 @@ class ELBv2Backend(BaseBackend): action_target_group_arn = action.data['target_group_arn'] if action_target_group_arn not in target_group_arns: raise ActionTargetGroupNotFoundError(action_target_group_arn) - elif action_type in ['redirect', 'authenticate-cognito', 'fixed-response']: + elif action_type == 'fixed-response': + self._validate_fixed_response_action(action, i, index) + elif action_type in ['redirect', 'authenticate-cognito']: pass else: raise InvalidActionTypeError(action_type, index) + def _validate_fixed_response_action(self, action, i, index): + status_code = action.data.get('fixed_response_config._status_code') + if status_code is None: + raise ParamValidationError( + report='Missing required parameter in Actions[%s].FixedResponseConfig: "StatusCode"' % i) + if not re.match(r'^(2|4|5)\d\d$', status_code): + raise InvalidStatusCodeActionTypeError( + "1 validation error detected: Value '%s' at 'actions.%s.member.fixedResponseConfig.statusCode' failed to satisfy constraint: \ +Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, index) + ) + content_type = action.data['fixed_response_config._content_type'] + if content_type and content_type not in ['text/plain', 'text/css', 'text/html', 'application/javascript', + 'application/json']: + raise InvalidLoadBalancerActionException( + "The ContentType must be one of:'text/html', 'application/json', 'application/javascript', 'text/css', 'text/plain'" + ) + def create_target_group(self, name, **kwargs): if len(name) > 32: raise InvalidTargetGroupNameError( diff --git a/tests/test_elbv2/test_elbv2.py b/tests/test_elbv2/test_elbv2.py index 538a2d911..97b876fec 100644 --- a/tests/test_elbv2/test_elbv2.py +++ b/tests/test_elbv2/test_elbv2.py @@ -4,7 +4,7 @@ import json import os import boto3 import botocore -from botocore.exceptions import ClientError +from botocore.exceptions import ClientError, ParamValidationError from nose.tools import assert_raises import sure # noqa @@ -2146,3 +2146,150 @@ def test_fixed_response_action_listener_rule_cloudformation(): 'StatusCode': '404', } },]) + + +@mock_elbv2 +@mock_ec2 +def test_fixed_response_action_listener_rule_validates_status_code(): + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.128/26', + AvailabilityZone='us-east-1b') + + response = conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + + missing_status_code_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + } + } + with assert_raises(ParamValidationError): + conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[missing_status_code_action]) + + invalid_status_code_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '100' + } + } + + @mock_elbv2 + @mock_ec2 + def test_fixed_response_action_listener_rule_validates_status_code(): + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.128/26', + AvailabilityZone='us-east-1b') + + response = conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + + missing_status_code_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + } + } + with assert_raises(ParamValidationError): + conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[missing_status_code_action]) + + invalid_status_code_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '100' + } + } + + with assert_raises(ClientError) as invalid_status_code_exception: + conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[invalid_status_code_action]) + + invalid_status_code_exception.exception.response['Error']['Code'].should.equal('ValidationError') + + +@mock_elbv2 +@mock_ec2 +def test_fixed_response_action_listener_rule_validates_content_type(): + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.128/26', + AvailabilityZone='us-east-1b') + + response = conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + + invalid_content_type_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'Fake content type', + 'MessageBody': 'This page does not exist', + 'StatusCode': '200' + } + } + with assert_raises(ClientError) as invalid_content_type_exception: + conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[invalid_content_type_action]) + invalid_content_type_exception.exception.response['Error']['Code'].should.equal('InvalidLoadBalancerAction') From f4c5dfbdfb09122b9d39b227cf774e2f74e15ece Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 27 Aug 2019 23:52:44 -0700 Subject: [PATCH 45/67] reorganize and centralize key ID validation --- moto/kms/models.py | 30 +++-- moto/kms/responses.py | 248 +++++++++++++++++++++++++------------ tests/test_kms/test_kms.py | 82 ++++++++---- 3 files changed, 249 insertions(+), 111 deletions(-) diff --git a/moto/kms/models.py b/moto/kms/models.py index 5f89407f5..1f5c77560 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -1,13 +1,17 @@ from __future__ import unicode_literals import os -import boto.kms -from moto.core import BaseBackend, BaseModel -from moto.core.utils import iso_8601_datetime_without_milliseconds -from .utils import decrypt, encrypt, generate_key_id, generate_master_key +import re from collections import defaultdict from datetime import datetime, timedelta +import boto.kms + +from moto.core import BaseBackend, BaseModel +from moto.core.utils import iso_8601_datetime_without_milliseconds + +from .utils import decrypt, encrypt, generate_key_id, generate_master_key + class Key(BaseModel): def __init__(self, policy, key_usage, description, tags, region): @@ -18,7 +22,7 @@ class Key(BaseModel): self.description = description self.enabled = True self.region = region - self.account_id = "0123456789012" + self.account_id = "012345678912" self.key_rotation_status = False self.deletion_date = None self.tags = tags or {} @@ -116,13 +120,21 @@ class KmsBackend(BaseBackend): def list_keys(self): return self.keys.values() - def get_key_id(self, key_id): + @staticmethod + def get_key_id(key_id): # Allow use of ARN as well as pure KeyId - return str(key_id).split(r":key/")[1] if r":key/" in str(key_id).lower() else key_id + if key_id.startswith("arn:") and ":key/" in key_id: + return key_id.split(":key/")[1] - def get_alias_name(self, alias_name): + return key_id + + @staticmethod + def get_alias_name(alias_name): # Allow use of ARN as well as alias name - return str(alias_name).split(r":alias/")[1] if r":alias/" in str(alias_name).lower() else alias_name + if alias_name.startswith("arn:") and ":alias/" in alias_name: + return alias_name.split(":alias/")[1] + + return alias_name def any_id_to_key_id(self, key_id): """Go from any valid key ID to the raw key ID. diff --git a/moto/kms/responses.py b/moto/kms/responses.py index fecb391d3..201084fd7 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -11,6 +11,7 @@ from moto.core.responses import BaseResponse from .models import kms_backends from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException +ACCOUNT_ID = "012345678912" reserved_aliases = [ 'alias/aws/ebs', 'alias/aws/s3', @@ -35,7 +36,74 @@ class KmsResponse(BaseResponse): def kms_backend(self): return kms_backends[self.region] + def _display_arn(self, key_id): + if key_id.startswith("arn:"): + return key_id + + if key_id.startswith("alias/"): + id_type = "" + else: + id_type = "key/" + + return "arn:aws:kms:{region}:{account}:{id_type}{key_id}".format( + region=self.region, account=ACCOUNT_ID, id_type=id_type, key_id=key_id + ) + + def _validate_cmk_id(self, key_id): + """Determine whether a CMK ID exists. + + - raw key ID + - key ARN + """ + is_arn = key_id.startswith("arn:") and ":key/" in key_id + is_raw_key_id = re.match(r"^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$", key_id, re.IGNORECASE) + + if not is_arn and not is_raw_key_id: + raise NotFoundException("Invalid keyId {key_id}".format(key_id=key_id)) + + cmk_id = self.kms_backend.get_key_id(key_id) + + if cmk_id not in self.kms_backend.keys: + raise NotFoundException("Key '{key_id}' does not exist".format(key_id=self._display_arn(key_id))) + + def _validate_alias(self, key_id): + """Determine whether an alias exists. + + - alias name + - alias ARN + """ + error = NotFoundException("Alias {key_id} is not found.".format(key_id=self._display_arn(key_id))) + + is_arn = key_id.startswith("arn:") and ":alias/" in key_id + is_name = key_id.startswith("alias/") + + if not is_arn and not is_name: + raise error + + alias_name = self.kms_backend.get_alias_name(key_id) + cmk_id = self.kms_backend.get_key_id_from_alias(alias_name) + if cmk_id is None: + raise error + + def _validate_key_id(self, key_id): + """Determine whether or not a key ID exists. + + - raw key ID + - key ARN + - alias name + - alias ARN + """ + is_alias_arn = key_id.startswith("arn:") and ":alias/" in key_id + is_alias_name = key_id.startswith("alias/") + + if is_alias_arn or is_alias_name: + self._validate_alias(key_id) + return + + self._validate_cmk_id(key_id) + def create_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html""" policy = self.parameters.get('Policy') key_usage = self.parameters.get('KeyUsage') description = self.parameters.get('Description') @@ -46,20 +114,31 @@ class KmsResponse(BaseResponse): return json.dumps(key.to_dict()) def update_key_description(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html""" key_id = self.parameters.get('KeyId') description = self.parameters.get('Description') + self._validate_cmk_id(key_id) + self.kms_backend.update_key_description(key_id, description) return json.dumps(None) def tag_resource(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html""" key_id = self.parameters.get('KeyId') tags = self.parameters.get('Tags') + + self._validate_cmk_id(key_id) + self.kms_backend.tag_resource(key_id, tags) return json.dumps({}) def list_resource_tags(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" key_id = self.parameters.get('KeyId') + + self._validate_cmk_id(key_id) + tags = self.kms_backend.list_resource_tags(key_id) return json.dumps({ "Tags": tags, @@ -68,7 +147,11 @@ class KmsResponse(BaseResponse): }) def describe_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" key_id = self.parameters.get('KeyId') + + self._validate_key_id(key_id) + try: key = self.kms_backend.describe_key( self.kms_backend.get_key_id(key_id)) @@ -79,6 +162,7 @@ class KmsResponse(BaseResponse): return json.dumps(key.to_dict()) def list_keys(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html""" keys = self.kms_backend.list_keys() return json.dumps({ @@ -93,6 +177,7 @@ class KmsResponse(BaseResponse): }) def create_alias(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html""" alias_name = self.parameters['AliasName'] target_key_id = self.parameters['TargetKeyId'] @@ -118,27 +203,31 @@ class KmsResponse(BaseResponse): raise AlreadyExistsException('An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} ' 'already exists'.format(region=self.region, alias_name=alias_name)) + self._validate_cmk_id(target_key_id) + self.kms_backend.add_alias(target_key_id, alias_name) return json.dumps(None) def delete_alias(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html""" alias_name = self.parameters['AliasName'] if not alias_name.startswith('alias/'): raise ValidationException('Invalid identifier') - if not self.kms_backend.alias_exists(alias_name): - raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:' - '{alias_name} is not found.'.format(region=self.region, alias_name=alias_name)) + self._validate_alias(alias_name) self.kms_backend.delete_alias(alias_name) return json.dumps(None) def list_aliases(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html""" region = self.region + # TODO: The actual API can filter on KeyId. + response_aliases = [ { 'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region, @@ -163,79 +252,76 @@ class KmsResponse(BaseResponse): }) def enable_key_rotation(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.enable_key_rotation(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.enable_key_rotation(key_id) return json.dumps(None) def disable_key_rotation(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.disable_key_rotation(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.disable_key_rotation(key_id) + return json.dumps(None) def get_key_rotation_status(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - rotation_enabled = self.kms_backend.get_key_rotation_status(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + rotation_enabled = self.kms_backend.get_key_rotation_status(key_id) + return json.dumps({'KeyRotationEnabled': rotation_enabled}) def put_key_policy(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html""" key_id = self.parameters.get('KeyId') policy_name = self.parameters.get('PolicyName') policy = self.parameters.get('Policy') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) _assert_default_policy(policy_name) - try: - self.kms_backend.put_key_policy(key_id, policy) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + self._validate_cmk_id(key_id) + + self.kms_backend.put_key_policy(key_id, policy) return json.dumps(None) def get_key_policy(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html""" key_id = self.parameters.get('KeyId') policy_name = self.parameters.get('PolicyName') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) _assert_default_policy(policy_name) - try: - return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)}) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + self._validate_cmk_id(key_id) + + return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)}) def list_key_policies(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.describe_key(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.describe_key(key_id) return json.dumps({'Truncated': False, 'PolicyNames': ['default']}) def encrypt(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html""" key_id = self.parameters.get("KeyId") encryption_context = self.parameters.get('EncryptionContext', {}) plaintext = self.parameters.get("Plaintext") + self._validate_key_id(key_id) + if isinstance(plaintext, six.text_type): plaintext = plaintext.encode('utf-8') @@ -249,6 +335,7 @@ class KmsResponse(BaseResponse): return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn}) def decrypt(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html""" ciphertext_blob = self.parameters.get("CiphertextBlob") encryption_context = self.parameters.get('EncryptionContext', {}) @@ -262,11 +349,14 @@ class KmsResponse(BaseResponse): return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn}) def re_encrypt(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html""" ciphertext_blob = self.parameters.get("CiphertextBlob") source_encryption_context = self.parameters.get("SourceEncryptionContext", {}) destination_key_id = self.parameters.get("DestinationKeyId") destination_encryption_context = self.parameters.get("DestinationEncryptionContext", {}) + self._validate_cmk_id(destination_key_id) + new_ciphertext_blob, decrypting_arn, encrypting_arn = self.kms_backend.re_encrypt( ciphertext_blob=ciphertext_blob, source_encryption_context=source_encryption_context, @@ -281,52 +371,52 @@ class KmsResponse(BaseResponse): ) def disable_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.disable_key(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.disable_key(key_id) + return json.dumps(None) def enable_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.enable_key(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.enable_key(key_id) + return json.dumps(None) def cancel_key_deletion(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.cancel_key_deletion(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.cancel_key_deletion(key_id) + return json.dumps({'KeyId': key_id}) def schedule_key_deletion(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html""" key_id = self.parameters.get('KeyId') if self.parameters.get('PendingWindowInDays') is None: pending_window_in_days = 30 else: pending_window_in_days = self.parameters.get('PendingWindowInDays') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - return json.dumps({ - 'KeyId': key_id, - 'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days) - }) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + return json.dumps({ + 'KeyId': key_id, + 'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days) + }) def generate_data_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html""" key_id = self.parameters.get('KeyId') encryption_context = self.parameters.get('EncryptionContext', {}) number_of_bytes = self.parameters.get('NumberOfBytes') @@ -334,15 +424,9 @@ class KmsResponse(BaseResponse): grant_tokens = self.parameters.get('GrantTokens') # Param validation - if key_id.startswith('alias'): - if self.kms_backend.get_key_id_from_alias(key_id) is None: - raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:{alias_name} is not found.'.format( - region=self.region, alias_name=key_id)) - else: - if self.kms_backend.get_key_id(key_id) not in self.kms_backend.keys: - raise NotFoundException('Invalid keyId') + self._validate_key_id(key_id) - if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 0): + if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): raise ValidationException(( "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " "to satisfy constraint: Member must have value less than or " @@ -357,6 +441,7 @@ class KmsResponse(BaseResponse): ).format(key_spec=key_spec)) if not key_spec and not number_of_bytes: raise ValidationException("Please specify either number of bytes or key spec.") + if key_spec and number_of_bytes: raise ValidationException("Please specify either number of bytes or key spec.") @@ -378,14 +463,24 @@ class KmsResponse(BaseResponse): }) def generate_data_key_without_plaintext(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html""" result = json.loads(self.generate_data_key()) del result['Plaintext'] return json.dumps(result) def generate_random(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html""" number_of_bytes = self.parameters.get("NumberOfBytes") + if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): + raise ValidationException(( + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes) + ) + entropy = os.urandom(number_of_bytes) response_entropy = base64.b64encode(entropy).decode("utf-8") @@ -393,11 +488,6 @@ class KmsResponse(BaseResponse): return json.dumps({"Plaintext": response_entropy}) -def _assert_valid_key_id(key_id): - if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE): - raise NotFoundException('Invalid keyId') - - def _assert_default_policy(policy_name): if policy_name != 'default': raise NotFoundException("No such policy exists") diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 4daeaa7cf..49e6e1fdf 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -68,6 +68,7 @@ def test_describe_key_via_alias(): @mock_kms_deprecated def test_describe_key_via_alias_not_found(): + # TODO: Fix in next commit: bug. Should (and now does) throw NotFoundError conn = boto.kms.connect_to_region("us-west-2") key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) @@ -89,6 +90,7 @@ def test_describe_key_via_arn(): @mock_kms_deprecated def test_describe_missing_key(): + # TODO: Fix in next commit: bug. Should (and now does) throw NotFoundError conn = boto.kms.connect_to_region("us-west-2") conn.describe_key.when.called_with("not-a-key").should.throw(JSONResponseError) @@ -493,16 +495,17 @@ def test__create_alias__raises_if_alias_has_colon_character(): ex.status.should.equal(400) +@parameterized(( + ("alias/my-alias_/",), + ("alias/my_alias-/",), +)) @mock_kms_deprecated -def test__create_alias__accepted_characters(): +def test__create_alias__accepted_characters(alias_name): kms = boto.connect_kms() create_resp = kms.create_key() key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_accepted_characters = ["alias/my-alias_/", "alias/my_alias-/"] - - for alias_name in alias_names_with_accepted_characters: - kms.create_alias(alias_name, key_id) + kms.create_alias(alias_name, key_id) @mock_kms_deprecated @@ -575,14 +578,16 @@ def test__delete_alias__raises_if_alias_is_not_found(): with assert_raises(NotFoundException) as err: kms.delete_alias(alias_name) + expected_message_match = r"Alias arn:aws:kms:{region}:[0-9]{{12}}:{alias_name} is not found.".format( + region=region, + alias_name=alias_name + ) ex = err.exception ex.body["__type"].should.equal("NotFoundException") - ex.body["message"].should.match( - r"Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.".format(**locals()) - ) + ex.body["message"].should.match(expected_message_match) ex.box_usage.should.be.none ex.error_code.should.be.none - ex.message.should.match(r"Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.".format(**locals())) + ex.message.should.match(expected_message_match) ex.reason.should.equal("Bad Request") ex.request_id.should.be.none ex.status.should.equal(400) @@ -635,13 +640,19 @@ def test__list_aliases(): len(aliases).should.equal(7) -@mock_kms_deprecated -def test__assert_valid_key_id(): - from moto.kms.responses import _assert_valid_key_id - import uuid +@parameterized(( + ("not-a-uuid",), + ("alias/DoesNotExist",), + ("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",), + ("d25652e4-d2d2-49f7-929a-671ccda580c6",), + ("arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",), +)) +@mock_kms +def test_invalid_key_ids(key_id): + client = boto3.client("kms", region_name="us-east-1") - _assert_valid_key_id.when.called_with("not-a-key").should.throw(MotoNotFoundException) - _assert_valid_key_id.when.called_with(str(uuid.uuid4())).should_not.throw(MotoNotFoundException) + with assert_raises(client.exceptions.NotFoundException): + client.generate_data_key(KeyId=key_id, NumberOfBytes=5) @mock_kms_deprecated @@ -781,6 +792,8 @@ def test_list_resource_tags(): (dict(KeySpec="AES_256"), 32), (dict(KeySpec="AES_128"), 16), (dict(NumberOfBytes=64), 64), + (dict(NumberOfBytes=1), 1), + (dict(NumberOfBytes=1024), 1024), )) @mock_kms def test_generate_data_key_sizes(kwargs, expected_key_length): @@ -807,6 +820,7 @@ def test_generate_data_key_decrypt(): (dict(KeySpec="AES_257"),), (dict(KeySpec="AES_128", NumberOfBytes=16),), (dict(NumberOfBytes=2048),), + (dict(NumberOfBytes=0),), (dict(),), )) @mock_kms @@ -814,20 +828,42 @@ def test_generate_data_key_invalid_size_params(kwargs): client = boto3.client("kms", region_name="us-east-1") key = client.create_key(Description="generate-data-key-size") - with assert_raises(botocore.exceptions.ClientError) as err: + with assert_raises((botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError)) as err: client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs) +@parameterized(( + ("alias/DoesNotExist",), + ("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",), + ("d25652e4-d2d2-49f7-929a-671ccda580c6",), + ("arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",), +)) @mock_kms -def test_generate_data_key_invalid_key(): +def test_generate_data_key_invalid_key(key_id): client = boto3.client("kms", region_name="us-east-1") - key = client.create_key(Description="generate-data-key-size") with assert_raises(client.exceptions.NotFoundException): - client.generate_data_key(KeyId="alias/randomnonexistantkey", KeySpec="AES_256") + client.generate_data_key(KeyId=key_id, KeySpec="AES_256") - with assert_raises(client.exceptions.NotFoundException): - client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"] + "4", KeySpec="AES_256") + +@parameterized(( + ("alias/DoesExist", False), + ("arn:aws:kms:us-east-1:012345678912:alias/DoesExist", False), + ("", True), + ("arn:aws:kms:us-east-1:012345678912:key/", True), +)) +@mock_kms +def test_generate_data_key_all_valid_key_ids(prefix, append_key_id): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key() + key_id = key["KeyMetadata"]["KeyId"] + client.create_alias(AliasName="alias/DoesExist", TargetKeyId=key_id) + + target_id = prefix + if append_key_id: + target_id += key_id + + client.generate_data_key(KeyId=key_id, NumberOfBytes=32) @mock_kms @@ -904,11 +940,11 @@ def test_re_encrypt_to_invalid_destination(): with assert_raises(client.exceptions.NotFoundException): client.re_encrypt( CiphertextBlob=encrypt_response["CiphertextBlob"], - DestinationKeyId="8327948729348", + DestinationKeyId="alias/DoesNotExist", ) -@parameterized(((12,), (44,), (91,))) +@parameterized(((12,), (44,), (91,), (1,), (1024,))) @mock_kms def test_generate_random(number_of_bytes): client = boto3.client("kms", region_name="us-west-2") From aa6b505415631fdc9d96b46dbb3f1dab6d351734 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 27 Aug 2019 23:55:08 -0700 Subject: [PATCH 46/67] fix tests to expect the correct error --- tests/test_kms/test_kms.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 49e6e1fdf..61c19afbd 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -68,12 +68,11 @@ def test_describe_key_via_alias(): @mock_kms_deprecated def test_describe_key_via_alias_not_found(): - # TODO: Fix in next commit: bug. Should (and now does) throw NotFoundError conn = boto.kms.connect_to_region("us-west-2") key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) - conn.describe_key.when.called_with("alias/not-found-alias").should.throw(JSONResponseError) + conn.describe_key.when.called_with("alias/not-found-alias").should.throw(NotFoundException) @mock_kms_deprecated @@ -90,9 +89,8 @@ def test_describe_key_via_arn(): @mock_kms_deprecated def test_describe_missing_key(): - # TODO: Fix in next commit: bug. Should (and now does) throw NotFoundError conn = boto.kms.connect_to_region("us-west-2") - conn.describe_key.when.called_with("not-a-key").should.throw(JSONResponseError) + conn.describe_key.when.called_with("not-a-key").should.throw(NotFoundException) @mock_kms_deprecated From a2c2a831988a48f0a3e03cf3f21552c6d3d85b74 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Thu, 12 Sep 2019 18:04:18 -0700 Subject: [PATCH 47/67] fix linting issues --- moto/kms/models.py | 1 - moto/kms/responses.py | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/moto/kms/models.py b/moto/kms/models.py index 1f5c77560..9e1b08bf9 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -1,7 +1,6 @@ from __future__ import unicode_literals import os -import re from collections import defaultdict from datetime import datetime, timedelta diff --git a/moto/kms/responses.py b/moto/kms/responses.py index 201084fd7..baa552953 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -475,11 +475,10 @@ class KmsResponse(BaseResponse): if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): raise ValidationException(( - "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " - "to satisfy constraint: Member must have value less than or " - "equal to 1024" - ).format(number_of_bytes=number_of_bytes) - ) + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes)) entropy = os.urandom(number_of_bytes) From 9a095d731a313a187362ab8305dbd4612814ce9c Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Fri, 13 Sep 2019 13:01:55 -0700 Subject: [PATCH 48/67] add tests for invalid aliases in describe_key --- tests/test_kms/test_kms.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 61c19afbd..31ba717e3 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -75,6 +75,20 @@ def test_describe_key_via_alias_not_found(): conn.describe_key.when.called_with("alias/not-found-alias").should.throw(NotFoundException) +@parameterized(( + ("alias/does-not-exist",), + ("arn:aws:kms:us-east-1:012345678912:alias/does-not-exist",), + ("invalid",), +)) +@mock_kms +def test_describe_key_via_alias_invalid_alias(key_id): + client = boto3.client("kms", region_name="us-east-1") + client.create_key(Description="key") + + with assert_raises(client.exceptions.NotFoundException): + client.describe_key(KeyId=key_id) + + @mock_kms_deprecated def test_describe_key_via_arn(): conn = boto.kms.connect_to_region("us-west-2") From c44178f2f789e22b43a35c4091105999a079c571 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Fri, 13 Sep 2019 13:19:11 -0700 Subject: [PATCH 49/67] add tests for invalid values passed to generate_random --- tests/test_kms/test_kms.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 31ba717e3..d9389094b 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -971,6 +971,21 @@ def test_generate_random(number_of_bytes): len(response["Plaintext"]).should.equal(number_of_bytes) +@parameterized(( + (2048, botocore.exceptions.ClientError), + (1025, botocore.exceptions.ClientError), + (0, botocore.exceptions.ParamValidationError), + (-1, botocore.exceptions.ParamValidationError), + (-1024, botocore.exceptions.ParamValidationError) +)) +@mock_kms +def test_generate_random_invalid_number_of_bytes(number_of_bytes, error_type): + client = boto3.client("kms", region_name="us-west-2") + + with assert_raises(error_type): + client.generate_random(NumberOfBytes=number_of_bytes) + + @mock_kms def test_enable_key_rotation_key_not_found(): client = boto3.client("kms", region_name="us-east-1") From 24832982d484986ff8891b8792053cee1551cf04 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Fri, 13 Sep 2019 13:32:19 -0700 Subject: [PATCH 50/67] convert tests from boto to boto3 and add unicode plaintext vector to test auto-conversion --- tests/test_kms/test_kms.py | 43 ++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index d9389094b..150dfae8f 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -23,6 +23,7 @@ from moto import mock_kms, mock_kms_deprecated PLAINTEXT_VECTORS = ( (b"some encodeable plaintext",), (b"some unencodeable plaintext \xec\x8a\xcf\xb6r\xe9\xb5\xeb\xff\xa23\x16",), + (u"some unicode characters ø˚∆øˆˆ∆ßçøˆˆçßøˆ¨¥",), ) @@ -215,15 +216,15 @@ def test_boto3_generate_data_key(): @parameterized(PLAINTEXT_VECTORS) -@mock_kms_deprecated +@mock_kms def test_encrypt(plaintext): - conn = boto.kms.connect_to_region("us-west-2") + client = boto3.client("kms", region_name="us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = client.create_key(Description="key") key_id = key["KeyMetadata"]["KeyId"] key_arn = key["KeyMetadata"]["Arn"] - response = conn.encrypt(key_id, plaintext) + response = client.encrypt(KeyId=key_id, Plaintext=plaintext) response["CiphertextBlob"].should_not.equal(plaintext) # CiphertextBlob must NOT be base64-encoded @@ -234,27 +235,33 @@ def test_encrypt(plaintext): @parameterized(PLAINTEXT_VECTORS) -@mock_kms_deprecated +@mock_kms def test_decrypt(plaintext): - conn = boto.kms.connect_to_region("us-west-2") + client = boto3.client("kms", region_name="us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = client.create_key(Description="key") key_id = key["KeyMetadata"]["KeyId"] key_arn = key["KeyMetadata"]["Arn"] - encrypt_response = conn.encrypt(key_id, plaintext) + encrypt_response = client.encrypt(KeyId=key_id, Plaintext=plaintext) + try: + encoded_plaintext = plaintext.encode("utf-8") + except AttributeError: + encoded_plaintext = plaintext + + client.create_key(Description="key") # CiphertextBlob must NOT be base64-encoded with assert_raises(Exception): base64.b64decode(encrypt_response["CiphertextBlob"], validate=True) - decrypt_response = conn.decrypt(encrypt_response["CiphertextBlob"]) + decrypt_response = client.decrypt(CiphertextBlob=encrypt_response["CiphertextBlob"]) # Plaintext must NOT be base64-encoded with assert_raises(Exception): base64.b64decode(decrypt_response["Plaintext"], validate=True) - decrypt_response["Plaintext"].should.equal(plaintext) + decrypt_response["Plaintext"].should.equal(encoded_plaintext) decrypt_response["KeyId"].should.equal(key_arn) @@ -682,8 +689,13 @@ def test_kms_encrypt_boto3(plaintext): key = client.create_key(Description="key") response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext) + try: + encoded_plaintext = plaintext.encode("utf-8") + except AttributeError: + encoded_plaintext = plaintext + response = client.decrypt(CiphertextBlob=response["CiphertextBlob"]) - response["Plaintext"].should.equal(plaintext) + response["Plaintext"].should.equal(encoded_plaintext) @mock_kms @@ -906,6 +918,11 @@ def test_re_encrypt_decrypt(plaintext): EncryptionContext={"encryption": "context"}, ) + try: + encoded_plaintext = plaintext.encode("utf-8") + except AttributeError: + encoded_plaintext = plaintext + re_encrypt_response = client.re_encrypt( CiphertextBlob=encrypt_response["CiphertextBlob"], SourceEncryptionContext={"encryption": "context"}, @@ -924,14 +941,14 @@ def test_re_encrypt_decrypt(plaintext): CiphertextBlob=encrypt_response["CiphertextBlob"], EncryptionContext={"encryption": "context"}, ) - decrypt_response_1["Plaintext"].should.equal(plaintext) + decrypt_response_1["Plaintext"].should.equal(encoded_plaintext) decrypt_response_1["KeyId"].should.equal(key_1_arn) decrypt_response_2 = client.decrypt( CiphertextBlob=re_encrypt_response["CiphertextBlob"], EncryptionContext={"another": "context"}, ) - decrypt_response_2["Plaintext"].should.equal(plaintext) + decrypt_response_2["Plaintext"].should.equal(encoded_plaintext) decrypt_response_2["KeyId"].should.equal(key_2_arn) decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"]) From 4f34af95bc0a3ac1a3b980ba772118f5e8e81d41 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Fri, 13 Sep 2019 13:35:55 -0700 Subject: [PATCH 51/67] remove dead code because the key ID validation is now centralized, by the time this code would have been reached, we know that the key ID exists, so a KeyError will never be thrown --- moto/kms/responses.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/moto/kms/responses.py b/moto/kms/responses.py index baa552953..998d5cc4b 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -152,13 +152,10 @@ class KmsResponse(BaseResponse): self._validate_key_id(key_id) - try: - key = self.kms_backend.describe_key( - self.kms_backend.get_key_id(key_id)) - except KeyError: - headers = dict(self.headers) - headers['status'] = 404 - return "{}", headers + key = self.kms_backend.describe_key( + self.kms_backend.get_key_id(key_id) + ) + return json.dumps(key.to_dict()) def list_keys(self): From 4c7cdec96542d0d2bddc0e5528612c126a9a2e46 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Fri, 13 Sep 2019 14:08:26 -0700 Subject: [PATCH 52/67] fix encoding for Python 2 in KMS tests --- tests/test_kms/test_kms.py | 33 +++++++++++++-------------------- 1 file changed, 13 insertions(+), 20 deletions(-) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 150dfae8f..49c0f886e 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -1,15 +1,16 @@ +# -*- coding: utf-8 -*- from __future__ import unicode_literals from datetime import date from datetime import datetime from dateutil.tz import tzutc import base64 -import binascii import os import re import boto3 import boto.kms import botocore.exceptions +import six import sure # noqa from boto.exception import JSONResponseError from boto.kms.exceptions import AlreadyExistsException, NotFoundException @@ -27,6 +28,13 @@ PLAINTEXT_VECTORS = ( ) +def _get_encoded_value(plaintext): + if isinstance(plaintext, six.binary_type): + return plaintext + + return plaintext.encode("utf-8") + + @mock_kms def test_create_key(): conn = boto3.client("kms", region_name="us-east-1") @@ -245,11 +253,6 @@ def test_decrypt(plaintext): encrypt_response = client.encrypt(KeyId=key_id, Plaintext=plaintext) - try: - encoded_plaintext = plaintext.encode("utf-8") - except AttributeError: - encoded_plaintext = plaintext - client.create_key(Description="key") # CiphertextBlob must NOT be base64-encoded with assert_raises(Exception): @@ -261,7 +264,7 @@ def test_decrypt(plaintext): with assert_raises(Exception): base64.b64decode(decrypt_response["Plaintext"], validate=True) - decrypt_response["Plaintext"].should.equal(encoded_plaintext) + decrypt_response["Plaintext"].should.equal(_get_encoded_value(plaintext)) decrypt_response["KeyId"].should.equal(key_arn) @@ -689,13 +692,8 @@ def test_kms_encrypt_boto3(plaintext): key = client.create_key(Description="key") response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext) - try: - encoded_plaintext = plaintext.encode("utf-8") - except AttributeError: - encoded_plaintext = plaintext - response = client.decrypt(CiphertextBlob=response["CiphertextBlob"]) - response["Plaintext"].should.equal(encoded_plaintext) + response["Plaintext"].should.equal(_get_encoded_value(plaintext)) @mock_kms @@ -918,11 +916,6 @@ def test_re_encrypt_decrypt(plaintext): EncryptionContext={"encryption": "context"}, ) - try: - encoded_plaintext = plaintext.encode("utf-8") - except AttributeError: - encoded_plaintext = plaintext - re_encrypt_response = client.re_encrypt( CiphertextBlob=encrypt_response["CiphertextBlob"], SourceEncryptionContext={"encryption": "context"}, @@ -941,14 +934,14 @@ def test_re_encrypt_decrypt(plaintext): CiphertextBlob=encrypt_response["CiphertextBlob"], EncryptionContext={"encryption": "context"}, ) - decrypt_response_1["Plaintext"].should.equal(encoded_plaintext) + decrypt_response_1["Plaintext"].should.equal(_get_encoded_value(plaintext)) decrypt_response_1["KeyId"].should.equal(key_1_arn) decrypt_response_2 = client.decrypt( CiphertextBlob=re_encrypt_response["CiphertextBlob"], EncryptionContext={"another": "context"}, ) - decrypt_response_2["Plaintext"].should.equal(encoded_plaintext) + decrypt_response_2["Plaintext"].should.equal(_get_encoded_value(plaintext)) decrypt_response_2["KeyId"].should.equal(key_2_arn) decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"]) From a36b84b3aa99686a19b19d352983c903039297ba Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Mon, 16 Sep 2019 11:35:36 +0800 Subject: [PATCH 53/67] fix MaxKeys in list_objects_v2 --- moto/s3/responses.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index f4640023e..e5b5cac0d 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -2,8 +2,6 @@ from __future__ import unicode_literals import re -from itertools import chain - import six from moto.core.utils import str_to_rfc_1123_datetime @@ -460,10 +458,10 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: result_folders = self._get_results_from_token(result_folders, limit) - tagged_keys = ((key, True) for key in result_keys) - tagged_folders = ((folder, False) for folder in result_folders) - sorted_keys = sorted(chain(tagged_keys, tagged_folders)) - result_keys, result_folders, is_truncated, next_continuation_token = self._truncate_result(sorted_keys, max_keys) + all_keys = [(key, True) for key in result_keys] + [(folder, False) for folder in result_folders] + all_keys.sort(key=lambda tagged_key: tagged_key if isinstance(tagged_key[0], str) else tagged_key[0].name) + truncated_keys, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys) + result_keys, result_folders = self._split_truncated_keys(truncated_keys) key_count = len(result_keys) + len(result_folders) @@ -481,6 +479,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): start_after=None if continuation_token else start_after ) + def _split_truncated_keys(self, truncated_keys): + result_keys = [] + result_folders = [] + for key in truncated_keys: + if key[1]: + result_keys.append(key[0]) + else: + result_folders.append(key[0]) + return result_keys, result_folders + def _get_results_from_token(self, result_keys, token): continuation_index = 0 for key in result_keys: @@ -489,19 +497,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): continuation_index += 1 return result_keys[continuation_index:] - def _truncate_result(self, sorted_keys, max_keys): - if len(sorted_keys) > max_keys: + def _truncate_result(self, result_keys, max_keys): + if len(result_keys) > max_keys: is_truncated = 'true' - sorted_keys = sorted_keys[:max_keys] - item = sorted_keys[-1][0] + result_keys = result_keys[:max_keys] + item = (result_keys[-1][0] if isinstance(result_keys[-1], tuple) else result_keys[-1]) next_continuation_token = (item.name if isinstance(item, FakeKey) else item) else: is_truncated = 'false' next_continuation_token = None - result_keys, result_folders = [], [] - for (key, is_key) in sorted_keys: - (result_keys if is_key else result_folders).append(key) - return result_keys, result_folders, is_truncated, next_continuation_token + return result_keys, is_truncated, next_continuation_token def _bucket_response_put(self, request, body, region_name, bucket_name, querystring): if not request.headers.get('Content-Length'): From 47635dc82e5658e9faf2119ef63cb75d69867000 Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Mon, 16 Sep 2019 13:33:53 +0800 Subject: [PATCH 54/67] update key of sort --- moto/s3/responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index e5b5cac0d..7005e15df 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -459,7 +459,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): result_folders = self._get_results_from_token(result_folders, limit) all_keys = [(key, True) for key in result_keys] + [(folder, False) for folder in result_folders] - all_keys.sort(key=lambda tagged_key: tagged_key if isinstance(tagged_key[0], str) else tagged_key[0].name) + all_keys.sort(key=lambda tagged_key: tagged_key[0].name if isinstance(tagged_key[0], FakeKey) else tagged_key[0]) truncated_keys, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys) result_keys, result_folders = self._split_truncated_keys(truncated_keys) From 59f87e30ba071b383db6bed75995634d9d4bc7ef Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Mon, 16 Sep 2019 15:20:24 +0800 Subject: [PATCH 55/67] split truncated keys by type --- moto/s3/responses.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 7005e15df..11c7750a5 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -2,6 +2,8 @@ from __future__ import unicode_literals import re +from collections import namedtuple + import six from moto.core.utils import str_to_rfc_1123_datetime @@ -92,6 +94,7 @@ ACTION_MAP = { } +TaggedKey = namedtuple("TaggedKey", ("entity", "is_key")) def parse_key_name(pth): return pth.lstrip("/") @@ -458,8 +461,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: result_folders = self._get_results_from_token(result_folders, limit) - all_keys = [(key, True) for key in result_keys] + [(folder, False) for folder in result_folders] - all_keys.sort(key=lambda tagged_key: tagged_key[0].name if isinstance(tagged_key[0], FakeKey) else tagged_key[0]) + all_keys = result_keys + result_folders + all_keys.sort(key=self._get_key_name) # sort by name, lexicographical order truncated_keys, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys) result_keys, result_folders = self._split_truncated_keys(truncated_keys) @@ -479,14 +482,22 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): start_after=None if continuation_token else start_after ) - def _split_truncated_keys(self, truncated_keys): + @staticmethod + def _get_key_name(key): + if isinstance(key, FakeKey): + return key.name + else: + return key + + @staticmethod + def _split_truncated_keys(truncated_keys): result_keys = [] result_folders = [] for key in truncated_keys: - if key[1]: - result_keys.append(key[0]) + if isinstance(key, FakeKey): + result_keys.append(key) else: - result_folders.append(key[0]) + result_folders.append(key) return result_keys, result_folders def _get_results_from_token(self, result_keys, token): @@ -501,7 +512,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if len(result_keys) > max_keys: is_truncated = 'true' result_keys = result_keys[:max_keys] - item = (result_keys[-1][0] if isinstance(result_keys[-1], tuple) else result_keys[-1]) + item = result_keys[-1] next_continuation_token = (item.name if isinstance(item, FakeKey) else item) else: is_truncated = 'false' From 4946f8b853c7620441be55e86147ccec258dd0d4 Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Mon, 16 Sep 2019 15:31:57 +0800 Subject: [PATCH 56/67] 'lint' --- moto/s3/responses.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 11c7750a5..582fe8ec7 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -2,8 +2,6 @@ from __future__ import unicode_literals import re -from collections import namedtuple - import six from moto.core.utils import str_to_rfc_1123_datetime @@ -94,7 +92,6 @@ ACTION_MAP = { } -TaggedKey = namedtuple("TaggedKey", ("entity", "is_key")) def parse_key_name(pth): return pth.lstrip("/") @@ -462,7 +459,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): result_folders = self._get_results_from_token(result_folders, limit) all_keys = result_keys + result_folders - all_keys.sort(key=self._get_key_name) # sort by name, lexicographical order + all_keys.sort(key=self._get_key_name) truncated_keys, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys) result_keys, result_folders = self._split_truncated_keys(truncated_keys) From 84715e9a2aadff785f5ba47a931f558ec4182bf7 Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Mon, 16 Sep 2019 16:46:19 +0800 Subject: [PATCH 57/67] add truncate unite test --- tests/test_s3/test_s3.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 0c0721f01..23e305bcc 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1392,6 +1392,28 @@ def test_boto3_list_objects_v2_fetch_owner(): assert len(owner.keys()) == 2 +@mock_s3 +def test_boto3_list_objects_v2_truncate_combined_keys_and_folders(): + s3 = boto3.client('s3', region_name='us-east-1') + s3.create_bucket(Bucket='mybucket') + s3.put_object(Bucket='mybucket', Key="A/a", Body="folder/a") + s3.put_object(Bucket='mybucket', Key="A/aa", Body="folder/aa") + s3.put_object(Bucket='mybucket', Key="A/B/a", Body="nested/folder/a") + s3.put_object(Bucket='mybucket', Key="c", Body="plain c") + + resp = s3.list_objects_v2(Bucket='mybucket', Prefix="A", MaxKeys=2) + + assert "Prefix" in resp + result_keys = [key["Key"] for key in resp["Contents"]] + + # Test truncate combination of keys and folders + assert len(result_keys) == 2 + + # Test lexicographical order + assert "A/B/a" == result_keys[0] + assert "A/a" == result_keys[1] + + @mock_s3 def test_boto3_bucket_create(): s3 = boto3.resource('s3', region_name='us-east-1') From c04c72d435e72d51a78fb4f0a15b9bb39ebca55f Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Mon, 16 Sep 2019 18:09:42 +0800 Subject: [PATCH 58/67] update MaxKeys unite test --- tests/test_s3/test_s3.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 23e305bcc..27005724d 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1401,17 +1401,11 @@ def test_boto3_list_objects_v2_truncate_combined_keys_and_folders(): s3.put_object(Bucket='mybucket', Key="A/B/a", Body="nested/folder/a") s3.put_object(Bucket='mybucket', Key="c", Body="plain c") - resp = s3.list_objects_v2(Bucket='mybucket', Prefix="A", MaxKeys=2) - + resp = s3.list_objects_v2(Bucket='mybucket', Prefix="A/", MaxKeys=1, Delimiter="/") assert "Prefix" in resp - result_keys = [key["Key"] for key in resp["Contents"]] - - # Test truncate combination of keys and folders - assert len(result_keys) == 2 - - # Test lexicographical order - assert "A/B/a" == result_keys[0] - assert "A/a" == result_keys[1] + assert "Delimiter" in resp + assert resp["IsTruncated"] is True + assert resp["KeyCount"] != 0 @mock_s3 From 7ee35a8510eff9d6a3c3385c8a8cf18c1819c712 Mon Sep 17 00:00:00 2001 From: Kiyonori Matsumoto Date: Mon, 16 Sep 2019 23:33:52 +0900 Subject: [PATCH 59/67] fix: raises ValueError on conditional and operator if lhs evaluates to false, rhs must be ignored, but rhs was evaluated then ValueError is occurred. --- moto/dynamodb2/comparisons.py | 3 +-- tests/test_dynamodb2/test_dynamodb.py | 30 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 151a314f1..dbc0bd57d 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -941,8 +941,7 @@ class OpAnd(Op): def expr(self, item): lhs = self.lhs.expr(item) - rhs = self.rhs.expr(item) - return lhs and rhs + return lhs and self.rhs.expr(item) class OpLessThan(Op): diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index fb6c0e17d..ba5c256c3 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -2034,6 +2034,36 @@ def test_condition_expression__or_order(): ) +@mock_dynamodb2 +def test_condition_expression__and_order(): + client = boto3.client('dynamodb', region_name='us-east-1') + + client.create_table( + TableName='test', + KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], + AttributeDefinitions=[ + {'AttributeName': 'forum_name', 'AttributeType': 'S'}, + ], + ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ) + + # ensure that the RHS of the AND expression is not evaluated if the LHS + # returns true (as it would result an error) + with assert_raises(client.exceptions.ConditionalCheckFailedException): + client.update_item( + TableName='test', + Key={ + 'forum_name': {'S': 'the-key'}, + }, + UpdateExpression='set #ttl=:ttl', + ConditionExpression='attribute_exists(#ttl) AND #ttl <= :old_ttl', + ExpressionAttributeNames={'#ttl': 'ttl'}, + ExpressionAttributeValues={ + ':ttl': {'N': '6'}, + ':old_ttl': {'N': '5'}, + } + ) + @mock_dynamodb2 def test_query_gsi_with_range_key(): dynamodb = boto3.client('dynamodb', region_name='us-east-1') From 1c36e1e2c5029d16112a121d1dd8bc39cc445fe2 Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Tue, 17 Sep 2019 10:42:10 +0800 Subject: [PATCH 60/67] update unit test and fix StartAfter --- moto/s3/responses.py | 13 ++++++------- tests/test_s3/test_s3.py | 22 +++++++++++++--------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 582fe8ec7..4c546c595 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -451,15 +451,14 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): continuation_token = querystring.get('continuation-token', [None])[0] start_after = querystring.get('start-after', [None])[0] + # sort the combination of folders and keys into lexicographical order + all_keys = result_keys + result_folders + all_keys.sort(key=self._get_name) + if continuation_token or start_after: limit = continuation_token or start_after - if not delimiter: - result_keys = self._get_results_from_token(result_keys, limit) - else: - result_folders = self._get_results_from_token(result_folders, limit) + all_keys = self._get_results_from_token(all_keys, limit) - all_keys = result_keys + result_folders - all_keys.sort(key=self._get_key_name) truncated_keys, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys) result_keys, result_folders = self._split_truncated_keys(truncated_keys) @@ -480,7 +479,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): ) @staticmethod - def _get_key_name(key): + def _get_name(key): if isinstance(key, FakeKey): return key.name else: diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 27005724d..a8cec737c 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1396,16 +1396,20 @@ def test_boto3_list_objects_v2_fetch_owner(): def test_boto3_list_objects_v2_truncate_combined_keys_and_folders(): s3 = boto3.client('s3', region_name='us-east-1') s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key="A/a", Body="folder/a") - s3.put_object(Bucket='mybucket', Key="A/aa", Body="folder/aa") - s3.put_object(Bucket='mybucket', Key="A/B/a", Body="nested/folder/a") - s3.put_object(Bucket='mybucket', Key="c", Body="plain c") + s3.put_object(Bucket='mybucket', Key='1/2', Body='') + s3.put_object(Bucket='mybucket', Key='2', Body='') + s3.put_object(Bucket='mybucket', Key='3/4', Body='') + s3.put_object(Bucket='mybucket', Key='4', Body='') - resp = s3.list_objects_v2(Bucket='mybucket', Prefix="A/", MaxKeys=1, Delimiter="/") - assert "Prefix" in resp - assert "Delimiter" in resp - assert resp["IsTruncated"] is True - assert resp["KeyCount"] != 0 + resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Delimiter='/') + assert 'Delimiter' in resp + assert resp['IsTruncated'] is True + assert resp['KeyCount'] == 2 + + last_tail = resp['NextContinuationToken'] + resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Delimiter='/', StartAfter=last_tail) + assert resp['KeyCount'] == 2 + assert resp['IsTruncated'] is False @mock_s3 From a466ef2d1ba8796e38b9c1c9f834cfde36afcb14 Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Tue, 17 Sep 2019 12:42:33 +0800 Subject: [PATCH 61/67] check key & common prefix in unit test' --- tests/test_s3/test_s3.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index a8cec737c..bbe5e19a3 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1405,11 +1405,19 @@ def test_boto3_list_objects_v2_truncate_combined_keys_and_folders(): assert 'Delimiter' in resp assert resp['IsTruncated'] is True assert resp['KeyCount'] == 2 + assert len(resp['Contents']) == 1 + assert resp['Contents'][0]['Key'] == '2' + assert len(resp['CommonPrefixes']) == 1 + assert resp['CommonPrefixes'][0]['Prefix'] == '1/' last_tail = resp['NextContinuationToken'] resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Delimiter='/', StartAfter=last_tail) assert resp['KeyCount'] == 2 assert resp['IsTruncated'] is False + assert len(resp['Contents']) == 1 + assert resp['Contents'][0]['Key'] == '4' + assert len(resp['CommonPrefixes']) == 1 + assert resp['CommonPrefixes'][0]['Prefix'] == '3/' @mock_s3 From d8e69a9a36ecede13120079ba3731ef64b88b2b6 Mon Sep 17 00:00:00 2001 From: Gapex <1377762942@qq.com> Date: Tue, 17 Sep 2019 12:44:48 +0800 Subject: [PATCH 62/67] list with prifix --- tests/test_s3/test_s3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index bbe5e19a3..2764ee2c5 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1401,7 +1401,7 @@ def test_boto3_list_objects_v2_truncate_combined_keys_and_folders(): s3.put_object(Bucket='mybucket', Key='3/4', Body='') s3.put_object(Bucket='mybucket', Key='4', Body='') - resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Delimiter='/') + resp = s3.list_objects_v2(Bucket='mybucket', Prefix='', MaxKeys=2, Delimiter='/') assert 'Delimiter' in resp assert resp['IsTruncated'] is True assert resp['KeyCount'] == 2 @@ -1411,7 +1411,7 @@ def test_boto3_list_objects_v2_truncate_combined_keys_and_folders(): assert resp['CommonPrefixes'][0]['Prefix'] == '1/' last_tail = resp['NextContinuationToken'] - resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Delimiter='/', StartAfter=last_tail) + resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Prefix='', Delimiter='/', StartAfter=last_tail) assert resp['KeyCount'] == 2 assert resp['IsTruncated'] is False assert len(resp['Contents']) == 1 From b163f23a2236e6253f6d55e9bffc1b3369526804 Mon Sep 17 00:00:00 2001 From: Aleksandr Mangin Date: Mon, 23 Sep 2019 18:35:44 +0200 Subject: [PATCH 63/67] fix tail message problem in get_log_events --- moto/logs/models.py | 2 ++ tests/test_logs/test_logs.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/moto/logs/models.py b/moto/logs/models.py index 2fc4b0d8b..3c5360371 100644 --- a/moto/logs/models.py +++ b/moto/logs/models.py @@ -115,6 +115,8 @@ class LogStream: events_page = [event.to_response_dict() for event in events[next_index: next_index + limit]] if next_index + limit < len(self.events): next_index += limit + else: + next_index = len(self.events) back_index -= limit if back_index <= 0: diff --git a/tests/test_logs/test_logs.py b/tests/test_logs/test_logs.py index 49e593fdc..0a63308c2 100644 --- a/tests/test_logs/test_logs.py +++ b/tests/test_logs/test_logs.py @@ -190,6 +190,8 @@ def test_get_log_events(): resp['events'].should.have.length_of(10) resp.should.have.key('nextForwardToken') resp.should.have.key('nextBackwardToken') + resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000010') + resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000') for i in range(10): resp['events'][i]['timestamp'].should.equal(i) resp['events'][i]['message'].should.equal(str(i)) @@ -205,7 +207,8 @@ def test_get_log_events(): resp['events'].should.have.length_of(10) resp.should.have.key('nextForwardToken') resp.should.have.key('nextBackwardToken') - resp['nextForwardToken'].should.equal(next_token) + resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000020') + resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000') for i in range(10): resp['events'][i]['timestamp'].should.equal(i+10) resp['events'][i]['message'].should.equal(str(i+10)) From 38455c8e1943e2ffb8bf1f2306095b3c08cf559a Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Tue, 24 Sep 2019 14:36:34 +0100 Subject: [PATCH 64/67] Step Functions - Remove STS-client and refer to hardcoded account-id --- moto/stepfunctions/models.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index 8db9db1a1..7784919b0 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -1,9 +1,9 @@ import boto -import boto3 import re from datetime import datetime from moto.core import BaseBackend from moto.core.utils import iso_8601_datetime_without_milliseconds +from moto.sts.models import ACCOUNT_ID from uuid import uuid4 from .exceptions import ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist @@ -156,12 +156,7 @@ class StepFunctionBackend(BaseBackend): raise InvalidArn(invalid_msg) def _get_account_id(self): - if self._account_id: - return self._account_id - sts = boto3.client("sts") - identity = sts.get_caller_identity() - self._account_id = identity['Account'] - return self._account_id + return ACCOUNT_ID stepfunction_backends = {_region.name: StepFunctionBackend(_region.name) for _region in boto.awslambda.regions()} From 2df0309db5bb9b0e2fea020ca8cf8b3a6c649b17 Mon Sep 17 00:00:00 2001 From: Jesse Vogt Date: Tue, 24 Sep 2019 15:22:25 -0500 Subject: [PATCH 65/67] unquote key name multiple times until stable value --- moto/s3/utils.py | 17 +++++++++++++--- tests/test_s3/test_s3.py | 37 ++++++++++++++++++++++++++++++++++ tests/test_s3/test_s3_utils.py | 15 +++++++++++++- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/moto/s3/utils.py b/moto/s3/utils.py index 85a812aad..811e44f71 100644 --- a/moto/s3/utils.py +++ b/moto/s3/utils.py @@ -68,11 +68,22 @@ def metadata_from_headers(headers): return metadata -def clean_key_name(key_name): +def clean_key_name(key_name, attempts=4): if six.PY2: - return unquote(key_name.encode('utf-8')).decode('utf-8') + def uq(k): + return unquote(k.encode('utf-8')).decode('utf-8') + else: + uq = unquote - return unquote(key_name) + original = cleaned = key_name + last_attempt = attempts - 1 + for attempt in range(attempts): + cleaned = uq(key_name) + if cleaned == key_name: + return cleaned + if attempt != last_attempt: + key_name = cleaned + raise Exception('unable to fully clean name: original %s, last clean %s prior clean %s' % (original, cleaned, key_name)) class _VersionedKeyStore(dict): diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 2764ee2c5..336639a8c 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -21,6 +21,7 @@ from botocore.handlers import disable_signing from boto.s3.connection import S3Connection from boto.s3.key import Key from freezegun import freeze_time +from parameterized import parameterized import six import requests import tests.backport_assert_raises # noqa @@ -3046,3 +3047,39 @@ def test_root_dir_with_empty_name_works(): if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': raise SkipTest('Does not work in server mode due to error in Workzeug') store_and_read_back_a_key('/') + + +@parameterized([ + ('foo/bar/baz',), + ('foo',), + ('foo/run_dt%3D2019-01-01%252012%253A30%253A00',), +]) +@mock_s3 +def test_delete_objects_with_url_encoded_key(key): + s3 = boto3.client('s3', region_name='us-east-1') + bucket_name = 'mybucket' + body = b'Some body' + + s3.create_bucket(Bucket=bucket_name) + + def put_object(): + s3.put_object( + Bucket=bucket_name, + Key=key, + Body=body + ) + + def assert_deleted(): + with assert_raises(ClientError) as e: + s3.get_object(Bucket=bucket_name, Key=key) + + e.exception.response['Error']['Code'].should.equal('NoSuchKey') + + put_object() + s3.delete_object(Bucket=bucket_name, Key=key) + assert_deleted() + + put_object() + s3.delete_objects(Bucket=bucket_name, Delete={'Objects': [{'Key': key}]}) + assert_deleted() + diff --git a/tests/test_s3/test_s3_utils.py b/tests/test_s3/test_s3_utils.py index ce9f54c75..d55b28b6d 100644 --- a/tests/test_s3/test_s3_utils.py +++ b/tests/test_s3/test_s3_utils.py @@ -1,7 +1,8 @@ from __future__ import unicode_literals import os from sure import expect -from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url +from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url, clean_key_name +from parameterized import parameterized def test_base_url(): @@ -78,3 +79,15 @@ def test_parse_region_from_url(): 'https://s3.amazonaws.com/bucket', 'https://bucket.s3.amazonaws.com']: parse_region_from_url(url).should.equal(expected) + + +@parameterized([ + ('foo/bar/baz', + 'foo/bar/baz'), + ('foo', + 'foo'), + ('foo/run_dt%3D2019-01-01%252012%253A30%253A00', + 'foo/run_dt=2019-01-01 12:30:00'), +]) +def test_clean_key_name(key, expected): + clean_key_name(key).should.equal(expected) From 3b4cd1c27bb2fa0e06ba9d84be70b67e1c4f0198 Mon Sep 17 00:00:00 2001 From: Jesse Vogt Date: Tue, 24 Sep 2019 17:07:58 -0500 Subject: [PATCH 66/67] switch from calling clean in loop to undoing clean in delete_keys --- moto/s3/responses.py | 4 ++-- moto/s3/utils.py | 24 +++++++++--------------- tests/test_s3/test_s3_utils.py | 20 +++++++++++++++++--- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 4c546c595..61ebff9d0 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -20,7 +20,7 @@ from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, Missi MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent, ObjectNotInActiveTierError from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \ FakeTag -from .utils import bucket_name_from_url, clean_key_name, metadata_from_headers, parse_region_from_url +from .utils import bucket_name_from_url, clean_key_name, undo_clean_key_name, metadata_from_headers, parse_region_from_url from xml.dom import minidom @@ -711,7 +711,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): for k in keys: key_name = k.firstChild.nodeValue - success = self.backend.delete_key(bucket_name, key_name) + success = self.backend.delete_key(bucket_name, undo_clean_key_name(key_name)) if success: deleted_names.append(key_name) else: diff --git a/moto/s3/utils.py b/moto/s3/utils.py index 811e44f71..3bdd24cc4 100644 --- a/moto/s3/utils.py +++ b/moto/s3/utils.py @@ -5,7 +5,7 @@ import os from boto.s3.key import Key import re import six -from six.moves.urllib.parse import urlparse, unquote +from six.moves.urllib.parse import urlparse, unquote, quote import sys @@ -68,22 +68,16 @@ def metadata_from_headers(headers): return metadata -def clean_key_name(key_name, attempts=4): +def clean_key_name(key_name): if six.PY2: - def uq(k): - return unquote(k.encode('utf-8')).decode('utf-8') - else: - uq = unquote + return unquote(key_name.encode('utf-8')).decode('utf-8') + return unquote(key_name) - original = cleaned = key_name - last_attempt = attempts - 1 - for attempt in range(attempts): - cleaned = uq(key_name) - if cleaned == key_name: - return cleaned - if attempt != last_attempt: - key_name = cleaned - raise Exception('unable to fully clean name: original %s, last clean %s prior clean %s' % (original, cleaned, key_name)) + +def undo_clean_key_name(key_name): + if six.PY2: + return quote(key_name.encode('utf-8')).decode('utf-8') + return quote(key_name) class _VersionedKeyStore(dict): diff --git a/tests/test_s3/test_s3_utils.py b/tests/test_s3/test_s3_utils.py index d55b28b6d..93a4683e6 100644 --- a/tests/test_s3/test_s3_utils.py +++ b/tests/test_s3/test_s3_utils.py @@ -1,7 +1,7 @@ from __future__ import unicode_literals import os from sure import expect -from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url, clean_key_name +from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url, clean_key_name, undo_clean_key_name from parameterized import parameterized @@ -87,7 +87,21 @@ def test_parse_region_from_url(): ('foo', 'foo'), ('foo/run_dt%3D2019-01-01%252012%253A30%253A00', - 'foo/run_dt=2019-01-01 12:30:00'), + 'foo/run_dt=2019-01-01%2012%3A30%3A00'), ]) def test_clean_key_name(key, expected): - clean_key_name(key).should.equal(expected) + clean_key_name(key).should.equal(expected) + + +@parameterized([ + ('foo/bar/baz', + 'foo/bar/baz'), + ('foo', + 'foo'), + ('foo/run_dt%3D2019-01-01%252012%253A30%253A00', + 'foo/run_dt%253D2019-01-01%25252012%25253A30%25253A00'), +]) +def test_undo_clean_key_name(key, expected): + undo_clean_key_name(key).should.equal(expected) + + From 4497f18c1a32653db601a6da0e0a9cc036383605 Mon Sep 17 00:00:00 2001 From: Jack Danger Date: Fri, 27 Sep 2019 11:14:53 -0700 Subject: [PATCH 67/67] fixing ErrorResponse top-level tag (#2434) In the golang SDK the previous format throws an unmarshaling error: /usr/local/Cellar/go/1.12.6/libexec/src/encoding/xml/read.go:209 &errors.errorString{s:"unknown error response tag, {{ Response} []}"} err: <*>SerializationError: failed to unmarshal error message --- moto/core/exceptions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moto/core/exceptions.py b/moto/core/exceptions.py index 06cfd8895..a81d89093 100644 --- a/moto/core/exceptions.py +++ b/moto/core/exceptions.py @@ -14,7 +14,7 @@ SINGLE_ERROR_RESPONSE = u""" """ ERROR_RESPONSE = u""" - + {{error_type}} @@ -23,7 +23,7 @@ ERROR_RESPONSE = u""" 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE - + """ ERROR_JSON_RESPONSE = u"""{