Merge branch 'master' into get-caller-identity

Commit 24dcdb7453 by Bendegúz Ács, 2019-08-21 12:36:40 +02:00, committed by GitHub
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
82 changed files with 3064 additions and 310 deletions

.gitignore

@@ -15,6 +15,7 @@ python_env
 .ropeproject/
 .pytest_cache/
 venv/
+env/
 .python-version
 .vscode/
 tests/file.tmp


@@ -47,11 +47,11 @@ deploy:
       - master
     skip_cleanup: true
     skip_existing: true
-  - provider: pypi
-    distributions: sdist bdist_wheel
-    user: spulec
-    password:
-      secure: NxnPylnTfekJmGyoufCw0lMoYRskSMJzvAIyAlJJVYKwEhmiCPOrdy5qV8i8mRZ1AkUsqU3jBZ/PD56n96clHW0E3d080UleRDj6JpyALVdeLfMqZl9kLmZ8bqakWzYq3VSJKw2zGP/L4tPGf8wTK1SUv9yl/YNDsBdCkjDverw=
-    on:
-      tags: true
-      skip_existing: true
+  # - provider: pypi
+  #   distributions: sdist bdist_wheel
+  #   user: spulec
+  #   password:
+  #     secure: NxnPylnTfekJmGyoufCw0lMoYRskSMJzvAIyAlJJVYKwEhmiCPOrdy5qV8i8mRZ1AkUsqU3jBZ/PD56n96clHW0E3d080UleRDj6JpyALVdeLfMqZl9kLmZ8bqakWzYq3VSJKw2zGP/L4tPGf8wTK1SUv9yl/YNDsBdCkjDverw=
+  #   on:
+  #     tags: true
+  #     skip_existing: true


@@ -181,7 +181,7 @@
 - [ ] test_invoke_method
 - [ ] untag_resource
 - [ ] update_account
-- [ ] update_api_key
+- [X] update_api_key
 - [ ] update_authorizer
 - [ ] update_base_path_mapping
 - [ ] update_client_certificate
@@ -815,16 +815,16 @@
 - [ ] update_user_profile
 ## cognito-identity - 0% implemented
-- [ ] create_identity_pool
+- [X] create_identity_pool
 - [ ] delete_identities
 - [ ] delete_identity_pool
 - [ ] describe_identity
 - [ ] describe_identity_pool
-- [ ] get_credentials_for_identity
-- [ ] get_id
+- [X] get_credentials_for_identity
+- [X] get_id
 - [ ] get_identity_pool_roles
-- [ ] get_open_id_token
-- [ ] get_open_id_token_for_developer_identity
+- [X] get_open_id_token
+- [X] get_open_id_token_for_developer_identity
 - [ ] list_identities
 - [ ] list_identity_pools
 - [ ] lookup_developer_identity
@@ -928,6 +928,7 @@
 - [ ] update_user_attributes
 - [ ] update_user_pool
 - [X] update_user_pool_client
+- [X] update_user_pool_domain
 - [ ] verify_software_token
 - [ ] verify_user_attribute
@@ -4127,7 +4128,7 @@
 ## sts - 42% implemented
 - [X] assume_role
 - [ ] assume_role_with_saml
-- [ ] assume_role_with_web_identity
+- [X] assume_role_with_web_identity
 - [ ] decode_authorization_message
 - [ ] get_caller_identity
 - [X] get_federation_token


@@ -5,6 +5,9 @@
 [![Build Status](https://travis-ci.org/spulec/moto.svg?branch=master)](https://travis-ci.org/spulec/moto)
 [![Coverage Status](https://coveralls.io/repos/spulec/moto/badge.svg?branch=master)](https://coveralls.io/r/spulec/moto)
 [![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org)
+![PyPI](https://img.shields.io/pypi/v/moto.svg)
+![PyPI - Python Version](https://img.shields.io/pypi/pyversions/moto.svg)
+![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg)

 # In a nutshell
@@ -75,6 +78,7 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
 | Cognito Identity Provider | @mock_cognitoidp | basic endpoints done |
 |-------------------------------------------------------------------------------------|
 | Config | @mock_config | basic endpoints done |
+| | | core endpoints done |
 |-------------------------------------------------------------------------------------|
 | Data Pipeline | @mock_datapipeline | basic endpoints done |
 |-------------------------------------------------------------------------------------|
@@ -293,6 +297,96 @@ def test_describe_instances_allowed():

 See [the related test suite](https://github.com/spulec/moto/blob/master/tests/test_core/test_auth.py) for more examples.

+## Very Important -- Recommended Usage
+There are some important caveats to be aware of when using moto:
+
+*Failure to follow these guidelines could result in your tests mutating your __REAL__ infrastructure!*
+
+### How do I prevent tests from mutating my real infrastructure?
+You need to ensure that the mocks are actually in place. Changes made to recent versions of `botocore`
+have altered some of the mock behavior. In short, you need to ensure that you _always_ do the following:
+
+1. Ensure that your tests have dummy environment variables set up:
+
+        export AWS_ACCESS_KEY_ID='testing'
+        export AWS_SECRET_ACCESS_KEY='testing'
+        export AWS_SECURITY_TOKEN='testing'
+        export AWS_SESSION_TOKEN='testing'
+
+1. __VERY IMPORTANT__: ensure that you have your mocks set up __BEFORE__ your `boto3` client is established.
+   This can typically happen if you import a module that has a `boto3` client instantiated outside of a function.
+   See the pesky imports section below on how to work around this.
+
+### Example usage
+If you are a user of [pytest](https://pytest.org/en/latest/), you can leverage [pytest fixtures](https://pytest.org/en/latest/fixture.html#fixture)
+to help set up your mocks and other AWS resources that you would need.
+
+Here is an example:
+```python
+@pytest.fixture(scope='function')
+def aws_credentials():
+    """Mocked AWS Credentials for moto."""
+    os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
+    os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
+    os.environ['AWS_SECURITY_TOKEN'] = 'testing'
+    os.environ['AWS_SESSION_TOKEN'] = 'testing'
+
+@pytest.fixture(scope='function')
+def s3(aws_credentials):
+    with mock_s3():
+        yield boto3.client('s3', region_name='us-east-1')
+
+@pytest.fixture(scope='function')
+def sts(aws_credentials):
+    with mock_sts():
+        yield boto3.client('sts', region_name='us-east-1')
+
+@pytest.fixture(scope='function')
+def cloudwatch(aws_credentials):
+    with mock_cloudwatch():
+        yield boto3.client('cloudwatch', region_name='us-east-1')
+
+... etc.
+```
+
+In the code sample above, all of the AWS/mocked fixtures take in a parameter of `aws_credentials`,
+which sets the proper fake environment variables. The fake environment variables are used so that
+`botocore` doesn't try to locate real credentials on your system.
+
+Next, once you need to do anything with the mocked AWS environment, do something like:
+```python
+def test_create_bucket(s3):
+    # s3 is a fixture defined above that yields a boto3 s3 client.
+    # Feel free to instantiate another boto3 S3 client -- Keep note of the region though.
+    s3.create_bucket(Bucket="somebucket")
+
+    result = s3.list_buckets()
+    assert len(result['Buckets']) == 1
+    assert result['Buckets'][0]['Name'] == 'somebucket'
+```
+
+### What about those pesky imports?
+Recall that, as mentioned earlier, mocks should be established __BEFORE__ the clients are set up. One way
+to avoid import issues is to make use of local Python imports -- i.e. import the module inside of the unit
+test you want to run vs. importing at the top of the file.
+
+Example:
+```python
+def test_something(s3):
+    from some.package.that.does.something.with.s3 import some_func  # <-- Local import for unit test
+    # ^^ Importing here ensures that the mock has been established.
+
+    some_func()  # The mock has been established from the "s3" pytest fixture, so this function that uses
+                 # a package-level S3 client will properly use the mock and not reach out to AWS.
+```
+
+### Other caveats
+For Tox, Travis CI, and other build systems, you might need to also perform a `touch ~/.aws/credentials`
+command before running the tests. As long as that file is present (empty preferably) and the environment
+variables above are set, you should be good to go.
+
 ## Stand-alone Server Mode
 Moto also has a stand-alone server mode. This allows you to utilize


@@ -3,7 +3,7 @@ import logging
 # logging.getLogger('boto').setLevel(logging.CRITICAL)

 __title__ = 'moto'
-__version__ = '1.3.11'
+__version__ = '1.3.14.dev'

 from .acm import mock_acm  # flake8: noqa
 from .apigateway import mock_apigateway, mock_apigateway_deprecated  # flake8: noqa


@@ -309,6 +309,25 @@ class ApiKey(BaseModel, dict):
         self['createdDate'] = self['lastUpdatedDate'] = int(time.time())
         self['stageKeys'] = stageKeys

+    def update_operations(self, patch_operations):
+        for op in patch_operations:
+            if op['op'] == 'replace':
+                if '/name' in op['path']:
+                    self['name'] = op['value']
+                elif '/customerId' in op['path']:
+                    self['customerId'] = op['value']
+                elif '/description' in op['path']:
+                    self['description'] = op['value']
+                elif '/enabled' in op['path']:
+                    self['enabled'] = self._str2bool(op['value'])
+            else:
+                raise Exception(
+                    'Patch operation "%s" not implemented' % op['op'])
+        return self
+
+    def _str2bool(self, v):
+        return v.lower() == "true"
+

 class UsagePlan(BaseModel, dict):
@@ -599,6 +618,10 @@ class APIGatewayBackend(BaseBackend):
     def get_apikey(self, api_key_id):
         return self.keys[api_key_id]

+    def update_apikey(self, api_key_id, patch_operations):
+        key = self.keys[api_key_id]
+        return key.update_operations(patch_operations)
+
     def delete_apikey(self, api_key_id):
         self.keys.pop(api_key_id)
         return {}


@@ -245,6 +245,9 @@ class APIGatewayResponse(BaseResponse):
         if self.method == 'GET':
             apikey_response = self.backend.get_apikey(apikey)
+        elif self.method == 'PATCH':
+            patch_operations = self._get_param('patchOperations')
+            apikey_response = self.backend.update_apikey(apikey, patch_operations)
         elif self.method == 'DELETE':
             apikey_response = self.backend.delete_apikey(apikey)
         return 200, {}, json.dumps(apikey_response)
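The PATCH branch above feeds API Gateway-style `patchOperations` into the stored key. A minimal standalone sketch of the `replace` semantics (the function and dict below are illustrative stand-ins, not moto's actual classes):

```python
def apply_replace_ops(api_key, patch_operations):
    """Apply API Gateway-style 'replace' patch operations to an API-key dict."""
    for op in patch_operations:
        if op['op'] != 'replace':
            raise NotImplementedError('Patch operation "%s" not implemented' % op['op'])
        field = op['path'].lstrip('/')
        if field == 'enabled':
            # API Gateway transmits booleans as the strings "true"/"false"
            api_key['enabled'] = op['value'].lower() == 'true'
        elif field in ('name', 'customerId', 'description'):
            api_key[field] = op['value']
        else:
            raise NotImplementedError('Unsupported path "%s"' % op['path'])
    return api_key
```

Note the string-to-bool conversion for `enabled`: that is why the diff adds a `_str2bool` helper rather than assigning the raw value.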


@@ -1,9 +1,10 @@
 from __future__ import unicode_literals
 import six
 import random
+import string


 def create_id():
     size = 10
-    chars = list(range(10)) + ['A-Z']
+    chars = list(range(10)) + list(string.ascii_lowercase)
     return ''.join(six.text_type(random.choice(chars)) for x in range(size))
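This fix matters because the old candidate list contained the literal string `'A-Z'` rather than a character range, so generated ids could embed the text "A-Z". A standalone sketch of the corrected generator (without moto's `six` dependency; the default size is assumed from the diff):

```python
import random
import string


def create_id(size=10):
    # Candidate characters: digits 0-9 plus every lowercase ASCII letter,
    # mirroring list(range(10)) + list(string.ascii_lowercase) in the diff.
    chars = [str(d) for d in range(10)] + list(string.ascii_lowercase)
    return ''.join(random.choice(chars) for _ in range(size))
```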


@@ -13,3 +13,12 @@ class ResourceContentionError(RESTError):
         super(ResourceContentionError, self).__init__(
             "ResourceContentionError",
             "You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).")
+
+
+class InvalidInstanceError(AutoscalingClientError):
+
+    def __init__(self, instance_id):
+        super(InvalidInstanceError, self).__init__(
+            "ValidationError",
+            "Instance [{0}] is invalid."
+            .format(instance_id))


@@ -3,6 +3,8 @@ from __future__ import unicode_literals
 import random

 from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping
+
+from moto.ec2.exceptions import InvalidInstanceIdError
 from moto.compat import OrderedDict
 from moto.core import BaseBackend, BaseModel
 from moto.ec2 import ec2_backends
@@ -10,7 +12,7 @@ from moto.elb import elb_backends
 from moto.elbv2 import elbv2_backends
 from moto.elb.exceptions import LoadBalancerNotFoundError
 from .exceptions import (
-    AutoscalingClientError, ResourceContentionError,
+    AutoscalingClientError, ResourceContentionError, InvalidInstanceError
 )

 # http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown
@@ -73,6 +75,26 @@ class FakeLaunchConfiguration(BaseModel):
         self.associate_public_ip_address = associate_public_ip_address
         self.block_device_mapping_dict = block_device_mapping_dict

+    @classmethod
+    def create_from_instance(cls, name, instance, backend):
+        config = backend.create_launch_configuration(
+            name=name,
+            image_id=instance.image_id,
+            kernel_id='',
+            ramdisk_id='',
+            key_name=instance.key_name,
+            security_groups=instance.security_groups,
+            user_data=instance.user_data,
+            instance_type=instance.instance_type,
+            instance_monitoring=False,
+            instance_profile_name=None,
+            spot_price=None,
+            ebs_optimized=instance.ebs_optimized,
+            associate_public_ip_address=instance.associate_public_ip,
+            block_device_mappings=instance.block_device_mapping
+        )
+        return config
+
     @classmethod
     def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
         properties = cloudformation_json['Properties']
@@ -279,6 +301,12 @@ class FakeAutoScalingGroup(BaseModel):
         if min_size is not None:
             self.min_size = min_size

+        if desired_capacity is None:
+            if min_size is not None and min_size > len(self.instance_states):
+                desired_capacity = min_size
+            if max_size is not None and max_size < len(self.instance_states):
+                desired_capacity = max_size
+
         if launch_config_name:
             self.launch_config = self.autoscaling_backend.launch_configurations[
                 launch_config_name]
@@ -414,7 +442,8 @@ class AutoScalingBackend(BaseBackend):
                                   health_check_type, load_balancers,
                                   target_group_arns, placement_group,
                                   termination_policies, tags,
-                                  new_instances_protected_from_scale_in=False):
+                                  new_instances_protected_from_scale_in=False,
+                                  instance_id=None):

         def make_int(value):
             return int(value) if value is not None else value
@@ -427,6 +456,13 @@
             health_check_period = 300
         else:
             health_check_period = make_int(health_check_period)
+
+        if launch_config_name is None and instance_id is not None:
+            try:
+                instance = self.ec2_backend.get_instance(instance_id)
+                launch_config_name = name
+                FakeLaunchConfiguration.create_from_instance(launch_config_name, instance, self)
+            except InvalidInstanceIdError:
+                raise InvalidInstanceError(instance_id)

         group = FakeAutoScalingGroup(
             name=name,
@@ -684,6 +720,18 @@ class AutoScalingBackend(BaseBackend):
         for instance in protected_instances:
             instance.protected_from_scale_in = protected_from_scale_in

+    def notify_terminate_instances(self, instance_ids):
+        for autoscaling_group_name, autoscaling_group in self.autoscaling_groups.items():
+            original_instance_count = len(autoscaling_group.instance_states)
+            autoscaling_group.instance_states = list(filter(
+                lambda i_state: i_state.instance.id not in instance_ids,
+                autoscaling_group.instance_states
+            ))
+            difference = original_instance_count - len(autoscaling_group.instance_states)
+            if difference > 0:
+                autoscaling_group.replace_autoscaling_group_instances(difference, autoscaling_group.get_propagated_tags())
+                self.update_attached_elbs(autoscaling_group_name)
+

 autoscaling_backends = {}
 for region, ec2_backend in ec2_backends.items():
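The update logic above now derives a desired capacity when none is given: if the group currently holds fewer instances than the new MinSize it scales up to it, and if it holds more than the new MaxSize it scales down. A standalone sketch of that clamping rule (the function name is illustrative, not moto's API):

```python
def resolve_desired_capacity(desired_capacity, min_size, max_size, current_count):
    """Mirror the implicit desired-capacity rule from the diff: clamp the
    current instance count into [min_size, max_size], but only when the
    caller did not pass DesiredCapacity explicitly."""
    if desired_capacity is None:
        if min_size is not None and min_size > current_count:
            desired_capacity = min_size
        if max_size is not None and max_size < current_count:
            desired_capacity = max_size
    return desired_capacity
```

Returning `None` when the current count already sits within the bounds means "leave the group size alone", matching the surrounding update code that only acts on a non-None value.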


@@ -48,7 +48,7 @@ class AutoScalingResponse(BaseResponse):
             start = all_names.index(marker) + 1
         else:
             start = 0
-        max_records = self._get_param('MaxRecords', 50)  # the default is 100, but using 50 to make testing easier
+        max_records = self._get_int_param('MaxRecords', 50)  # the default is 100, but using 50 to make testing easier
         launch_configurations_resp = all_launch_configurations[start:start + max_records]
         next_token = None
         if len(all_launch_configurations) > start + max_records:
@@ -74,6 +74,7 @@
             desired_capacity=self._get_int_param('DesiredCapacity'),
             max_size=self._get_int_param('MaxSize'),
             min_size=self._get_int_param('MinSize'),
+            instance_id=self._get_param('InstanceId'),
             launch_config_name=self._get_param('LaunchConfigurationName'),
             vpc_zone_identifier=self._get_param('VPCZoneIdentifier'),
             default_cooldown=self._get_int_param('DefaultCooldown'),


@@ -514,10 +514,13 @@ class BatchBackend(BaseBackend):
         return self._job_definitions.get(arn)

     def get_job_definition_by_name(self, name):
-        for comp_env in self._job_definitions.values():
-            if comp_env.name == name:
-                return comp_env
-        return None
+        latest_revision = -1
+        latest_job = None
+        for job_def in self._job_definitions.values():
+            if job_def.name == name and job_def.revision > latest_revision:
+                latest_job = job_def
+                latest_revision = job_def.revision
+        return latest_job

     def get_job_definition_by_name_revision(self, name, revision):
         for job_def in self._job_definitions.values():
@@ -534,10 +537,13 @@
         :return: Job definition or None
         :rtype: JobDefinition or None
         """
-        env = self.get_job_definition_by_arn(identifier)
-        if env is None:
-            env = self.get_job_definition_by_name(identifier)
-        return env
+        job_def = self.get_job_definition_by_arn(identifier)
+        if job_def is None:
+            if ':' in identifier:
+                job_def = self.get_job_definition_by_name_revision(*identifier.split(':', 1))
+            else:
+                job_def = self.get_job_definition_by_name(identifier)
+        return job_def

     def get_job_definitions(self, identifier):
         """
@@ -984,9 +990,7 @@
         # TODO parameters, retries (which is a dict raw from request), job dependancies and container overrides are ignored for now

         # Look for job definition
-        job_def = self.get_job_definition_by_arn(job_def_id)
-        if job_def is None and ':' in job_def_id:
-            job_def = self.get_job_definition_by_name_revision(*job_def_id.split(':', 1))
+        job_def = self.get_job_definition(job_def_id)
         if job_def is None:
             raise ClientException('Job definition {0} does not exist'.format(job_def_id))
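The rewritten `get_job_definition_by_name` no longer returns the first match; it keeps only the highest revision for the requested name. The selection logic can be sketched independently of moto's `JobDefinition` objects (simple `(name, revision)` tuples stand in for them here):

```python
def latest_job_definition(job_definitions, name):
    """job_definitions: iterable of (name, revision) pairs; returns the pair
    with the highest revision for the requested name, or None if absent."""
    latest_revision = -1
    latest_job = None
    for job_def in job_definitions:
        if job_def[0] == name and job_def[1] > latest_revision:
            latest_job = job_def
            latest_revision = job_def[1]
    return latest_job
```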


@@ -95,6 +95,15 @@ class CognitoIdentityBackend(BaseBackend):
         })
         return response

+    def get_open_id_token(self, identity_id):
+        response = json.dumps(
+            {
+                "IdentityId": identity_id,
+                "Token": get_random_identity_id(self.region)
+            }
+        )
+        return response
+

 cognitoidentity_backends = {}
 for region in boto.cognito.identity.regions():


@@ -35,3 +35,8 @@ class CognitoIdentityResponse(BaseResponse):
         return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity(
             self._get_param('IdentityId') or get_random_identity_id(self.region)
         )
+
+    def get_open_id_token(self):
+        return cognitoidentity_backends[self.region].get_open_id_token(
+            self._get_param("IdentityId") or get_random_identity_id(self.region)
+        )
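The mocked GetOpenIdToken response above is just the identity id plus a random token. A standalone sketch of the response shape (a region-prefixed `uuid4` stands in for moto's `get_random_identity_id` helper; the token content is opaque to callers):

```python
import json
import uuid


def get_open_id_token(identity_id, region='us-east-1'):
    # Stand-in for the backend method: callers only rely on the two keys
    # being present, not on the token's contents.
    return json.dumps({
        "IdentityId": identity_id,
        "Token": "{}:{}".format(region, uuid.uuid4()),
    })
```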


@@ -2,6 +2,7 @@ from __future__ import unicode_literals

 import datetime
 import functools
+import hashlib
 import itertools
 import json
 import os
@@ -154,20 +155,37 @@ class CognitoIdpUserPool(BaseModel):

 class CognitoIdpUserPoolDomain(BaseModel):

-    def __init__(self, user_pool_id, domain):
+    def __init__(self, user_pool_id, domain, custom_domain_config=None):
         self.user_pool_id = user_pool_id
         self.domain = domain
+        self.custom_domain_config = custom_domain_config or {}

-    def to_json(self):
-        return {
-            "UserPoolId": self.user_pool_id,
-            "AWSAccountId": str(uuid.uuid4()),
-            "CloudFrontDistribution": None,
-            "Domain": self.domain,
-            "S3Bucket": None,
-            "Status": "ACTIVE",
-            "Version": None,
-        }
+    def _distribution_name(self):
+        if self.custom_domain_config and \
+                'CertificateArn' in self.custom_domain_config:
+            hash = hashlib.md5(
+                self.custom_domain_config['CertificateArn'].encode('utf-8')
+            ).hexdigest()
+            return "{hash}.cloudfront.net".format(hash=hash[:16])
+        return None
+
+    def to_json(self, extended=True):
+        distribution = self._distribution_name()
+        if extended:
+            return {
+                "UserPoolId": self.user_pool_id,
+                "AWSAccountId": str(uuid.uuid4()),
+                "CloudFrontDistribution": distribution,
+                "Domain": self.domain,
+                "S3Bucket": None,
+                "Status": "ACTIVE",
+                "Version": None,
+            }
+        elif distribution:
+            return {
+                "CloudFrontDomain": distribution,
+            }
+        return None


 class CognitoIdpUserPoolClient(BaseModel):
@@ -338,11 +356,13 @@ class CognitoIdpBackend(BaseBackend):
         del self.user_pools[user_pool_id]

     # User pool domain
-    def create_user_pool_domain(self, user_pool_id, domain):
+    def create_user_pool_domain(self, user_pool_id, domain, custom_domain_config=None):
         if user_pool_id not in self.user_pools:
             raise ResourceNotFoundError(user_pool_id)

-        user_pool_domain = CognitoIdpUserPoolDomain(user_pool_id, domain)
+        user_pool_domain = CognitoIdpUserPoolDomain(
+            user_pool_id, domain, custom_domain_config=custom_domain_config
+        )
         self.user_pool_domains[domain] = user_pool_domain
         return user_pool_domain
@@ -358,6 +378,14 @@
         del self.user_pool_domains[domain]

+    def update_user_pool_domain(self, domain, custom_domain_config):
+        if domain not in self.user_pool_domains:
+            raise ResourceNotFoundError(domain)
+
+        user_pool_domain = self.user_pool_domains[domain]
+        user_pool_domain.custom_domain_config = custom_domain_config
+        return user_pool_domain
+
     # User pool client
     def create_user_pool_client(self, user_pool_id, extended_config):
         user_pool = self.user_pools.get(user_pool_id)
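`_distribution_name` above fakes a CloudFront domain by hashing the certificate ARN, so the same custom-domain config always yields the same mocked hostname. The derivation as a standalone sketch (function name is illustrative):

```python
import hashlib


def distribution_name(certificate_arn):
    # First 16 hex chars of the ARN's MD5 digest, dressed up as a
    # cloudfront.net hostname -- same scheme as the diff's helper.
    digest = hashlib.md5(certificate_arn.encode('utf-8')).hexdigest()
    return "{hash}.cloudfront.net".format(hash=digest[:16])
```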


@@ -50,7 +50,13 @@ class CognitoIdpResponse(BaseResponse):
     def create_user_pool_domain(self):
         domain = self._get_param("Domain")
         user_pool_id = self._get_param("UserPoolId")
-        cognitoidp_backends[self.region].create_user_pool_domain(user_pool_id, domain)
+        custom_domain_config = self._get_param("CustomDomainConfig")
+        user_pool_domain = cognitoidp_backends[self.region].create_user_pool_domain(
+            user_pool_id, domain, custom_domain_config
+        )
+        domain_description = user_pool_domain.to_json(extended=False)
+        if domain_description:
+            return json.dumps(domain_description)
         return ""

     def describe_user_pool_domain(self):
@@ -69,6 +75,17 @@
         cognitoidp_backends[self.region].delete_user_pool_domain(domain)
         return ""

+    def update_user_pool_domain(self):
+        domain = self._get_param("Domain")
+        custom_domain_config = self._get_param("CustomDomainConfig")
+        user_pool_domain = cognitoidp_backends[self.region].update_user_pool_domain(
+            domain, custom_domain_config
+        )
+        domain_description = user_pool_domain.to_json(extended=False)
+        if domain_description:
+            return json.dumps(domain_description)
+        return ""
+
     # User pool client
     def create_user_pool_client(self):
         user_pool_id = self.parameters.pop("UserPoolId")


@@ -52,6 +52,18 @@ class InvalidResourceTypeException(JsonRESTError):
         super(InvalidResourceTypeException, self).__init__("ValidationException", message)

+class NoSuchConfigurationAggregatorException(JsonRESTError):
+    code = 400
+
+    def __init__(self, number=1):
+        if number == 1:
+            message = 'The configuration aggregator does not exist. Check the configuration aggregator name and try again.'
+        else:
+            message = 'At least one of the configuration aggregators does not exist. Check the configuration aggregator' \
+                      ' names and try again.'
+        super(NoSuchConfigurationAggregatorException, self).__init__("NoSuchConfigurationAggregatorException", message)
+

 class NoSuchConfigurationRecorderException(JsonRESTError):
     code = 400
@@ -78,6 +90,14 @@ class NoSuchBucketException(JsonRESTError):
         super(NoSuchBucketException, self).__init__("NoSuchBucketException", message)

+class InvalidNextTokenException(JsonRESTError):
+    code = 400
+
+    def __init__(self):
+        message = 'The nextToken provided is invalid'
+        super(InvalidNextTokenException, self).__init__("InvalidNextTokenException", message)
+

 class InvalidS3KeyPrefixException(JsonRESTError):
     code = 400
@@ -147,3 +167,66 @@ class LastDeliveryChannelDeleteFailedException(JsonRESTError):
         message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \
                   'because there is a running configuration recorder.'.format(name=name)
         super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message)
+
+
+class TooManyAccountSources(JsonRESTError):
+    code = 400
+
+    def __init__(self, length):
+        locations = ['com.amazonaws.xyz'] * length
+        message = 'Value \'[{locations}]\' at \'accountAggregationSources\' failed to satisfy constraint: ' \
+                  'Member must have length less than or equal to 1'.format(locations=', '.join(locations))
+        super(TooManyAccountSources, self).__init__("ValidationException", message)
+
+
+class DuplicateTags(JsonRESTError):
+    code = 400
+
+    def __init__(self):
+        super(DuplicateTags, self).__init__(
+            'InvalidInput', 'Duplicate tag keys found. Please note that Tag keys are case insensitive.')
+
+
+class TagKeyTooBig(JsonRESTError):
+    code = 400
+
+    def __init__(self, tag, param='tags.X.member.key'):
+        super(TagKeyTooBig, self).__init__(
+            'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy "
+                                   "constraint: Member must have length less than or equal to 128".format(tag, param))
+
+
+class TagValueTooBig(JsonRESTError):
+    code = 400
+
+    def __init__(self, tag):
+        super(TagValueTooBig, self).__init__(
+            'ValidationException', "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy "
+                                   "constraint: Member must have length less than or equal to 256".format(tag))
+
+
+class InvalidParameterValueException(JsonRESTError):
+    code = 400
+
+    def __init__(self, message):
+        super(InvalidParameterValueException, self).__init__('InvalidParameterValueException', message)
+
+
+class InvalidTagCharacters(JsonRESTError):
+    code = 400
+
+    def __init__(self, tag, param='tags.X.member.key'):
+        message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(tag, param)
+        message += 'constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+'
+        super(InvalidTagCharacters, self).__init__('ValidationException', message)
+
+
+class TooManyTags(JsonRESTError):
+    code = 400
+
+    def __init__(self, tags, param='tags'):
+        super(TooManyTags, self).__init__(
+            'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy "
+                                   "constraint: Member must have length less than or equal to 50.".format(tags, param))


@@ -1,6 +1,9 @@
import json
import re
import time
import pkg_resources
import random
import string
from datetime import datetime

@@ -12,37 +15,125 @@ from moto.config.exceptions import InvalidResourceTypeException, InvalidDelivery
    NoSuchConfigurationRecorderException, NoAvailableConfigurationRecorderException, \
    InvalidDeliveryChannelNameException, NoSuchBucketException, InvalidS3KeyPrefixException, \
    InvalidSNSTopicARNException, MaxNumberOfDeliveryChannelsExceededException, NoAvailableDeliveryChannelException, \
    NoSuchDeliveryChannelException, LastDeliveryChannelDeleteFailedException, TagKeyTooBig, \
    TooManyTags, TagValueTooBig, TooManyAccountSources, InvalidParameterValueException, InvalidNextTokenException, \
    NoSuchConfigurationAggregatorException, InvalidTagCharacters, DuplicateTags

from moto.core import BaseBackend, BaseModel

DEFAULT_ACCOUNT_ID = 123456789012

POP_STRINGS = [
    'capitalizeStart',
    'CapitalizeStart',
    'capitalizeArn',
    'CapitalizeArn',
    'capitalizeARN',
    'CapitalizeARN'
]

DEFAULT_PAGE_SIZE = 100


def datetime2int(date):
    return int(time.mktime(date.timetuple()))
def snake_to_camels(original, cap_start, cap_arn):
    parts = original.split('_')

    camel_cased = parts[0].lower() + ''.join(p.title() for p in parts[1:])

    if cap_arn:
        camel_cased = camel_cased.replace('Arn', 'ARN')  # Some config services use 'ARN' instead of 'Arn'

    if cap_start:
        camel_cased = camel_cased[0].upper() + camel_cased[1::]

    return camel_cased
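For reference, the casing rules implemented by this helper can be checked stand-alone (the two calls below are illustrative examples, not part of the diff):

```python
def snake_to_camels(original, cap_start, cap_arn):
    # Convert snake_case to camelCase, optionally capitalizing the first
    # letter and upper-casing 'Arn' -> 'ARN' to match the Config API shape.
    parts = original.split('_')
    camel_cased = parts[0].lower() + ''.join(p.title() for p in parts[1:])
    if cap_arn:
        camel_cased = camel_cased.replace('Arn', 'ARN')
    if cap_start:
        camel_cased = camel_cased[0].upper() + camel_cased[1:]
    return camel_cased

print(snake_to_camels('configuration_aggregator_arn', True, False))  # ConfigurationAggregatorArn
print(snake_to_camels('role_arn', False, True))  # roleARN
```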
def random_string():
"""Returns a random set of 8 lowercase letters for the Config Aggregator ARN"""
chars = []
for x in range(0, 8):
chars.append(random.choice(string.ascii_lowercase))
return "".join(chars)
def validate_tag_key(tag_key, exception_param='tags.X.member.key'):
"""Validates the tag key.
:param tag_key: The tag key to check against.
:param exception_param: The exception parameter to send over to help format the message. This is to reflect
the difference between the tag and untag APIs.
:return:
"""
# Validate that the key length is correct:
if len(tag_key) > 128:
raise TagKeyTooBig(tag_key, param=exception_param)
# Validate that the tag key fits the proper Regex:
# [\w\s_.:/=+\-@]+ SHOULD be the same as the Java regex on the AWS documentation: [\p{L}\p{Z}\p{N}_.:/=+\-@]+
match = re.findall(r'[\w\s_.:/=+\-@]+', tag_key)
# Kudos if you can come up with a better way of doing a global search :)
if not len(match) or len(match[0]) < len(tag_key):
raise InvalidTagCharacters(tag_key, param=exception_param)
def check_tag_duplicate(all_tags, tag_key):
"""Validates that a tag key is not a duplicate
:param all_tags: Dict to check if there is a duplicate tag.
:param tag_key: The tag key to check against.
:return:
"""
if all_tags.get(tag_key):
raise DuplicateTags()
def validate_tags(tags):
proper_tags = {}
if len(tags) > 50:
raise TooManyTags(tags)
for tag in tags:
# Validate the Key:
validate_tag_key(tag['Key'])
check_tag_duplicate(proper_tags, tag['Key'])
# Validate the Value:
if len(tag['Value']) > 256:
raise TagValueTooBig(tag['Value'])
proper_tags[tag['Key']] = tag['Value']
return proper_tags
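A minimal stand-alone sketch of the validation above, with a plain `ValueError` standing in for the moto exception classes (`TooManyTags`, `TagKeyTooBig`, etc.):

```python
import re

def validate_tags(tags):
    # Sketch of the checks above: enforce count, key length, key charset,
    # duplicate keys, and value length, returning a plain dict of Key -> Value.
    proper_tags = {}
    if len(tags) > 50:
        raise ValueError('too many tags')
    for tag in tags:
        if len(tag['Key']) > 128:
            raise ValueError('tag key too long')
        match = re.findall(r'[\w\s_.:/=+\-@]+', tag['Key'])
        if not match or len(match[0]) < len(tag['Key']):
            raise ValueError('invalid characters in tag key')
        if tag['Key'] in proper_tags:
            raise ValueError('duplicate tag key')
        if len(tag['Value']) > 256:
            raise ValueError('tag value too long')
        proper_tags[tag['Key']] = tag['Value']
    return proper_tags

print(validate_tags([{'Key': 'Env', 'Value': 'dev'}]))  # {'Env': 'dev'}
```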
class ConfigEmptyDictable(BaseModel):
    """Base class to make serialization easy. This assumes that the sub-class will NOT return 'None's in the JSON."""

    def __init__(self, capitalize_start=False, capitalize_arn=True):
        """Assists with the serialization of the config object

        :param capitalize_start: For some Config services, the first letter is lowercase -- for others it's capital
        :param capitalize_arn: For some Config services, the API expects 'ARN' and for others, it expects 'Arn'
        """
        self.capitalize_start = capitalize_start
        self.capitalize_arn = capitalize_arn

    def to_dict(self):
        data = {}
        for item, value in self.__dict__.items():
            if value is not None:
                if isinstance(value, ConfigEmptyDictable):
                    data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value.to_dict()
                else:
                    data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value

        # Cleanse the extra properties:
        for prop in POP_STRINGS:
            data.pop(prop, None)

        return data
@@ -50,8 +141,9 @@ class ConfigEmptyDictable(BaseModel):

class ConfigRecorderStatus(ConfigEmptyDictable):

    def __init__(self, name):
        super(ConfigRecorderStatus, self).__init__()

        self.name = name
        self.recording = False
        self.last_start_time = None
        self.last_stop_time = None
@@ -75,12 +167,16 @@ class ConfigRecorderStatus(ConfigEmptyDictable):

class ConfigDeliverySnapshotProperties(ConfigEmptyDictable):

    def __init__(self, delivery_frequency):
        super(ConfigDeliverySnapshotProperties, self).__init__()

        self.delivery_frequency = delivery_frequency


class ConfigDeliveryChannel(ConfigEmptyDictable):

    def __init__(self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None):
        super(ConfigDeliveryChannel, self).__init__()

        self.name = name
        self.s3_bucket_name = s3_bucket_name
        self.s3_key_prefix = prefix
@@ -91,6 +187,8 @@ class ConfigDeliveryChannel(ConfigEmptyDictable):

class RecordingGroup(ConfigEmptyDictable):

    def __init__(self, all_supported=True, include_global_resource_types=False, resource_types=None):
        super(RecordingGroup, self).__init__()

        self.all_supported = all_supported
        self.include_global_resource_types = include_global_resource_types
        self.resource_types = resource_types
@@ -99,6 +197,8 @@ class RecordingGroup(ConfigEmptyDictable):

class ConfigRecorder(ConfigEmptyDictable):

    def __init__(self, role_arn, recording_group, name='default', status=None):
        super(ConfigRecorder, self).__init__()

        self.name = name
        self.role_arn = role_arn
        self.recording_group = recording_group

@@ -109,18 +209,118 @@ class ConfigRecorder(ConfigEmptyDictable):
        self.status = status
class AccountAggregatorSource(ConfigEmptyDictable):
def __init__(self, account_ids, aws_regions=None, all_aws_regions=None):
super(AccountAggregatorSource, self).__init__(capitalize_start=True)
# Can't have both the regions and all_regions flag present -- also can't have them both missing:
if aws_regions and all_aws_regions:
raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies '
'the use of all regions. You must choose one of these options.')
if not (aws_regions or all_aws_regions):
raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported '
'regions and try again.')
self.account_ids = account_ids
self.aws_regions = aws_regions
if not all_aws_regions:
all_aws_regions = False
self.all_aws_regions = all_aws_regions
class OrganizationAggregationSource(ConfigEmptyDictable):
def __init__(self, role_arn, aws_regions=None, all_aws_regions=None):
super(OrganizationAggregationSource, self).__init__(capitalize_start=True, capitalize_arn=False)
# Can't have both the regions and all_regions flag present -- also can't have them both missing:
if aws_regions and all_aws_regions:
raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies '
'the use of all regions. You must choose one of these options.')
if not (aws_regions or all_aws_regions):
raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported '
'regions and try again.')
self.role_arn = role_arn
self.aws_regions = aws_regions
if not all_aws_regions:
all_aws_regions = False
self.all_aws_regions = all_aws_regions
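Both aggregation-source classes share the same mutual-exclusion rule: a caller must supply a region list or the all-regions flag, but not both and not neither. A stand-alone sketch of that rule (`validate_region_choice` is a hypothetical helper, with `ValueError` in place of `InvalidParameterValueException`):

```python
def validate_region_choice(aws_regions=None, all_aws_regions=None):
    # Exactly one of a region list or the all-regions flag must be provided.
    if aws_regions and all_aws_regions:
        raise ValueError('both a region list and the all-regions flag were given')
    if not (aws_regions or all_aws_regions):
        raise ValueError('no regions specified')
    return aws_regions or 'ALL'

print(validate_region_choice(aws_regions=['eu-west-1']))  # ['eu-west-1']
```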
class ConfigAggregator(ConfigEmptyDictable):
def __init__(self, name, region, account_sources=None, org_source=None, tags=None):
super(ConfigAggregator, self).__init__(capitalize_start=True, capitalize_arn=False)
self.configuration_aggregator_name = name
self.configuration_aggregator_arn = 'arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}'.format(
region=region,
id=DEFAULT_ACCOUNT_ID,
random=random_string()
)
self.account_aggregation_sources = account_sources
self.organization_aggregation_source = org_source
self.creation_time = datetime2int(datetime.utcnow())
self.last_updated_time = datetime2int(datetime.utcnow())
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
self.tags = tags or {}
# Override the to_dict so that we can format the tags properly...
def to_dict(self):
result = super(ConfigAggregator, self).to_dict()
# Override the account aggregation sources if present:
if self.account_aggregation_sources:
result['AccountAggregationSources'] = [a.to_dict() for a in self.account_aggregation_sources]
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
# if self.tags:
# result['Tags'] = [{'Key': key, 'Value': value} for key, value in self.tags.items()]
return result
class ConfigAggregationAuthorization(ConfigEmptyDictable):
def __init__(self, current_region, authorized_account_id, authorized_aws_region, tags=None):
super(ConfigAggregationAuthorization, self).__init__(capitalize_start=True, capitalize_arn=False)
self.aggregation_authorization_arn = 'arn:aws:config:{region}:{id}:aggregation-authorization/' \
'{auth_account}/{auth_region}'.format(region=current_region,
id=DEFAULT_ACCOUNT_ID,
auth_account=authorized_account_id,
auth_region=authorized_aws_region)
self.authorized_account_id = authorized_account_id
self.authorized_aws_region = authorized_aws_region
self.creation_time = datetime2int(datetime.utcnow())
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
self.tags = tags or {}
class ConfigBackend(BaseBackend):

    def __init__(self):
        self.recorders = {}
        self.delivery_channels = {}
        self.config_aggregators = {}
        self.aggregation_authorizations = {}

    @staticmethod
    def _validate_resource_types(resource_list):
        # Load the service file:
        resource_package = 'botocore'
        resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json'))
        config_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path))

        # Verify that each entry exists in the supported list:
        bad_list = []

@@ -128,11 +328,11 @@ class ConfigBackend(BaseBackend):
            # For PY2:
            r_str = str(resource)
            if r_str not in config_schema['shapes']['ResourceType']['enum']:
                bad_list.append(r_str)

        if bad_list:
            raise InvalidResourceTypeException(bad_list, config_schema['shapes']['ResourceType']['enum'])

    @staticmethod
    def _validate_delivery_snapshot_properties(properties):

@@ -147,6 +347,158 @@ class ConfigBackend(BaseBackend):
            raise InvalidDeliveryFrequency(properties.get('deliveryFrequency', None),
                                           conifg_schema['shapes']['MaximumExecutionFrequency']['enum'])
def put_configuration_aggregator(self, config_aggregator, region):
# Validate the name:
if len(config_aggregator['ConfigurationAggregatorName']) > 256:
raise NameTooLongException(config_aggregator['ConfigurationAggregatorName'], 'configurationAggregatorName')
account_sources = None
org_source = None
# Tag validation:
tags = validate_tags(config_aggregator.get('Tags', []))
# Exception if both AccountAggregationSources and OrganizationAggregationSource are supplied:
if config_aggregator.get('AccountAggregationSources') and config_aggregator.get('OrganizationAggregationSource'):
raise InvalidParameterValueException('The configuration aggregator cannot be created because your request contains both the'
' AccountAggregationSource and the OrganizationAggregationSource. Include only '
'one aggregation source and try again.')
# If neither are supplied:
if not config_aggregator.get('AccountAggregationSources') and not config_aggregator.get('OrganizationAggregationSource'):
raise InvalidParameterValueException('The configuration aggregator cannot be created because your request is missing either '
'the AccountAggregationSource or the OrganizationAggregationSource. Include the '
'appropriate aggregation source and try again.')
if config_aggregator.get('AccountAggregationSources'):
# Currently, only 1 account aggregation source can be set:
if len(config_aggregator['AccountAggregationSources']) > 1:
raise TooManyAccountSources(len(config_aggregator['AccountAggregationSources']))
account_sources = []
for a in config_aggregator['AccountAggregationSources']:
account_sources.append(AccountAggregatorSource(a['AccountIds'], aws_regions=a.get('AwsRegions'),
all_aws_regions=a.get('AllAwsRegions')))
else:
org_source = OrganizationAggregationSource(config_aggregator['OrganizationAggregationSource']['RoleArn'],
aws_regions=config_aggregator['OrganizationAggregationSource'].get('AwsRegions'),
all_aws_regions=config_aggregator['OrganizationAggregationSource'].get(
'AllAwsRegions'))
# Grab the existing one if it exists and update it:
if not self.config_aggregators.get(config_aggregator['ConfigurationAggregatorName']):
aggregator = ConfigAggregator(config_aggregator['ConfigurationAggregatorName'], region, account_sources=account_sources,
org_source=org_source, tags=tags)
self.config_aggregators[config_aggregator['ConfigurationAggregatorName']] = aggregator
else:
aggregator = self.config_aggregators[config_aggregator['ConfigurationAggregatorName']]
aggregator.tags = tags
aggregator.account_aggregation_sources = account_sources
aggregator.organization_aggregation_source = org_source
aggregator.last_updated_time = datetime2int(datetime.utcnow())
return aggregator.to_dict()
def describe_configuration_aggregators(self, names, token, limit):
limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit
agg_list = []
result = {'ConfigurationAggregators': []}
if names:
for name in names:
if not self.config_aggregators.get(name):
raise NoSuchConfigurationAggregatorException(number=len(names))
agg_list.append(name)
else:
agg_list = list(self.config_aggregators.keys())
# Empty?
if not agg_list:
return result
# Sort by name:
sorted_aggregators = sorted(agg_list)
# Get the start:
if not token:
start = 0
else:
# Tokens for this moto feature are just the next names of the items in the list:
if not self.config_aggregators.get(token):
raise InvalidNextTokenException()
start = sorted_aggregators.index(token)
# Get the list of items to collect:
agg_list = sorted_aggregators[start:(start + limit)]
result['ConfigurationAggregators'] = [self.config_aggregators[agg].to_dict() for agg in agg_list]
if len(sorted_aggregators) > (start + limit):
result['NextToken'] = sorted_aggregators[start + limit]
return result
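The pagination scheme above uses the next item's name as the `NextToken`. A stand-alone sketch of that token scheme (`paginate` is a hypothetical helper, not part of the diff; `DEFAULT_PAGE_SIZE` mirrors the constant defined earlier in the file):

```python
DEFAULT_PAGE_SIZE = 100

def paginate(names_sorted, token=None, limit=None):
    # The NextToken is simply the name of the first item on the next page
    # of the sorted result set; an unknown token would raise in the backend.
    limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit
    start = 0 if not token else names_sorted.index(token)
    page = names_sorted[start:start + limit]
    next_token = names_sorted[start + limit] if len(names_sorted) > start + limit else None
    return page, next_token

names = sorted(['agg-a', 'agg-b', 'agg-c'])
page, token = paginate(names, limit=2)
print(page, token)  # ['agg-a', 'agg-b'] agg-c
```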
def delete_configuration_aggregator(self, config_aggregator):
if not self.config_aggregators.get(config_aggregator):
raise NoSuchConfigurationAggregatorException()
del self.config_aggregators[config_aggregator]
def put_aggregation_authorization(self, current_region, authorized_account, authorized_region, tags):
# Tag validation:
tags = validate_tags(tags or [])
# Does this already exist?
key = '{}/{}'.format(authorized_account, authorized_region)
agg_auth = self.aggregation_authorizations.get(key)
if not agg_auth:
agg_auth = ConfigAggregationAuthorization(current_region, authorized_account, authorized_region, tags=tags)
self.aggregation_authorizations['{}/{}'.format(authorized_account, authorized_region)] = agg_auth
else:
# Only update the tags:
agg_auth.tags = tags
return agg_auth.to_dict()
def describe_aggregation_authorizations(self, token, limit):
limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit
result = {'AggregationAuthorizations': []}
if not self.aggregation_authorizations:
return result
# Sort by name:
sorted_authorizations = sorted(self.aggregation_authorizations.keys())
# Get the start:
if not token:
start = 0
else:
# Tokens for this moto feature are just the next names of the items in the list:
if not self.aggregation_authorizations.get(token):
raise InvalidNextTokenException()
start = sorted_authorizations.index(token)
# Get the list of items to collect:
auth_list = sorted_authorizations[start:(start + limit)]
result['AggregationAuthorizations'] = [self.aggregation_authorizations[auth].to_dict() for auth in auth_list]
if len(sorted_authorizations) > (start + limit):
result['NextToken'] = sorted_authorizations[start + limit]
return result
def delete_aggregation_authorization(self, authorized_account, authorized_region):
# This will always return a 200 -- regardless if there is or isn't an existing
# aggregation authorization.
key = '{}/{}'.format(authorized_account, authorized_region)
self.aggregation_authorizations.pop(key, None)
    def put_configuration_recorder(self, config_recorder):
        # Validate the name:
        if not config_recorder.get('name'):


@@ -13,6 +13,39 @@ class ConfigResponse(BaseResponse):
        self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder'))
        return ""
def put_configuration_aggregator(self):
aggregator = self.config_backend.put_configuration_aggregator(json.loads(self.body), self.region)
schema = {'ConfigurationAggregator': aggregator}
return json.dumps(schema)
def describe_configuration_aggregators(self):
aggregators = self.config_backend.describe_configuration_aggregators(self._get_param('ConfigurationAggregatorNames'),
self._get_param('NextToken'),
self._get_param('Limit'))
return json.dumps(aggregators)
def delete_configuration_aggregator(self):
self.config_backend.delete_configuration_aggregator(self._get_param('ConfigurationAggregatorName'))
return ""
def put_aggregation_authorization(self):
agg_auth = self.config_backend.put_aggregation_authorization(self.region,
self._get_param('AuthorizedAccountId'),
self._get_param('AuthorizedAwsRegion'),
self._get_param('Tags'))
schema = {'AggregationAuthorization': agg_auth}
return json.dumps(schema)
def describe_aggregation_authorizations(self):
authorizations = self.config_backend.describe_aggregation_authorizations(self._get_param('NextToken'), self._get_param('Limit'))
return json.dumps(authorizations)
def delete_aggregation_authorization(self):
self.config_backend.delete_aggregation_authorization(self._get_param('AuthorizedAccountId'), self._get_param('AuthorizedAwsRegion'))
return ""
    def describe_configuration_recorders(self):
        recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames'))
        schema = {'ConfigurationRecorders': recorders}


@@ -12,6 +12,7 @@ from collections import defaultdict
from botocore.handlers import BUILTIN_HANDLERS
from botocore.awsrequest import AWSResponse

import mock
from moto import settings
import responses
from moto.packages.httpretty import HTTPretty

@@ -22,11 +23,6 @@ from .utils import (
)


class BaseMockAWS(object):
    nested_count = 0

@@ -42,6 +38,10 @@ class BaseMockAWS(object):
        self.backends_for_urls.update(self.backends)
        self.backends_for_urls.update(default_backends)

        # "Mock" the AWS credentials as they can't be mocked in Botocore currently
        FAKE_KEYS = {"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret"}
        self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)

        if self.__class__.nested_count == 0:
            self.reset()

@@ -52,11 +52,14 @@ class BaseMockAWS(object):

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        self.stop()

    def start(self, reset=True):
        self.env_variables_mocks.start()

        self.__class__.nested_count += 1
        if reset:
            for backend in self.backends.values():

@@ -65,6 +68,7 @@ class BaseMockAWS(object):
        self.enable_patching()

    def stop(self):
        self.env_variables_mocks.stop()
        self.__class__.nested_count -= 1

        if self.__class__.nested_count < 0:

@@ -465,10 +469,14 @@ class BaseModel(object):

class BaseBackend(object):

    def _reset_model_refs(self):
        # Remove all references to the models stored
        for service, models in model_data.items():
            for model_name, model in models.items():
                model.instances = []

    def reset(self):
        self._reset_model_refs()
        self.__dict__ = {}
        self.__init__()
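The `reset` idiom here wipes the backend's instance state and re-runs `__init__` so every test starts from a clean slate. A toy sketch of that pattern (`CounterBackend` is a hypothetical stand-in for a moto backend):

```python
class CounterBackend:
    # Sketch of the reset pattern above: replace __dict__ wholesale, then
    # re-run __init__ so the object is indistinguishable from a fresh one.
    def __init__(self):
        self.items = {}

    def reset(self):
        self.__dict__ = {}
        self.__init__()

b = CounterBackend()
b.items['x'] = 1
b.reset()
print(b.items)  # {}
```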


@@ -1004,8 +1004,7 @@ class OpOr(Op):
    def expr(self, item):
        lhs = self.lhs.expr(item)
        return lhs or self.rhs.expr(item)


class Func(object):
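The rewrite makes the OR short-circuit: the right-hand side is no longer evaluated when the left side is already truthy. A stand-alone sketch (`Const` and `Boom` are hypothetical stand-ins, not the moto `Op` classes):

```python
class Const:
    def __init__(self, value):
        self.value = value
    def expr(self, item):
        return self.value

class Boom:
    # Raises if evaluated -- used to prove the rhs is skipped.
    def expr(self, item):
        raise RuntimeError('rhs evaluated')

class OpOr:
    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs
    def expr(self, item):
        lhs = self.lhs.expr(item)
        return lhs or self.rhs.expr(item)  # rhs only runs if lhs is falsy

print(OpOr(Const(True), Boom()).expr({}))  # True (Boom never evaluated)
```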


@@ -298,7 +298,9 @@ class Item(BaseModel):
                new_value = list(update_action['Value'].values())[0]
                if action == 'PUT':
                    # TODO deal with other types
                    if isinstance(new_value, list):
                        self.attrs[attribute_name] = DynamoType({"L": new_value})
                    elif isinstance(new_value, set):
                        self.attrs[attribute_name] = DynamoType({"SS": new_value})
                    elif isinstance(new_value, dict):
                        self.attrs[attribute_name] = DynamoType({"M": new_value})
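The split matters because DynamoDB's typed JSON distinguishes lists (`"L"`) from string sets (`"SS"`); the old code wrongly stored lists as sets. A sketch of the dispatch above (`to_dynamo_descriptor` is a hypothetical helper returning the type descriptor rather than a `DynamoType`; the set is sorted here only to make the output deterministic, which the original does not do):

```python
def to_dynamo_descriptor(new_value):
    # Mirror the PUT dispatch above: lists -> "L", sets -> "SS", dicts -> "M".
    if isinstance(new_value, list):
        return {"L": new_value}
    elif isinstance(new_value, set):
        return {"SS": sorted(new_value)}
    elif isinstance(new_value, dict):
        return {"M": new_value}
    return new_value

print(to_dynamo_descriptor([{"S": "a"}, {"N": "1"}]))  # {'L': [{'S': 'a'}, {'N': '1'}]}
print(to_dynamo_descriptor({"x", "y"}))                # {'SS': ['x', 'y']}
```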


@@ -600,7 +600,7 @@ class DynamoHandler(BaseResponse):
        # E.g. `a = b + c` -> `a=b+c`
        if update_expression:
            update_expression = re.sub(
                r'\s*([=\+-])\s*', '\\1', update_expression)

        try:
            item = self.dynamodb_backend.update_item(
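The change here only adds the raw-string prefix so the pattern is not mangled by Python string escaping. A quick check of what the substitution does, using the stdlib `re` module:

```python
import re

# Strip whitespace around '=', '+' and '-' in an update expression.
expr = 'SET a = b + c, d = d - :one'
print(re.sub(r'\s*([=\+-])\s*', '\\1', expr))  # SET a=b+c, d=d-:one
```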


@@ -142,6 +142,8 @@ AMIS = json.load(
        __name__, 'resources/amis.json'), 'r')
)

OWNER_ID = "111122223333"


def utc_date_and_time():
    return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z')
@@ -201,7 +203,7 @@ class TaggedEC2Resource(BaseModel):

class NetworkInterface(TaggedEC2Resource):
    def __init__(self, ec2_backend, subnet, private_ip_address, device_index=0,
                 public_ip_auto_assign=True, group_ids=None, description=None):
        self.ec2_backend = ec2_backend
        self.id = random_eni_id()
        self.device_index = device_index

@@ -209,6 +211,7 @@ class NetworkInterface(TaggedEC2Resource):
        self.subnet = subnet
        self.instance = None
        self.attachment_id = None
        self.description = description
        self.public_ip = None
        self.public_ip_auto_assign = public_ip_auto_assign

@@ -246,11 +249,13 @@ class NetworkInterface(TaggedEC2Resource):
            subnet = None

        private_ip_address = properties.get('PrivateIpAddress', None)
        description = properties.get('Description', None)

        network_interface = ec2_backend.create_network_interface(
            subnet,
            private_ip_address,
            group_ids=security_group_ids,
            description=description
        )
        return network_interface

@@ -298,6 +303,8 @@ class NetworkInterface(TaggedEC2Resource):
            return [group.id for group in self._group_set]
        elif filter_name == 'availability-zone':
            return self.subnet.availability_zone
        elif filter_name == 'description':
            return self.description
        else:
            return super(NetworkInterface, self).get_filter_value(
                filter_name, 'DescribeNetworkInterfaces')

@@ -308,9 +315,9 @@ class NetworkInterfaceBackend(object):
        self.enis = {}
        super(NetworkInterfaceBackend, self).__init__()

    def create_network_interface(self, subnet, private_ip_address, group_ids=None, description=None, **kwargs):
        eni = NetworkInterface(
            self, subnet, private_ip_address, group_ids=group_ids, description=description, **kwargs)
        self.enis[eni.id] = eni
        return eni

@@ -343,6 +350,12 @@ class NetworkInterfaceBackend(object):
                        if group.id in _filter_value:
                            enis.append(eni)
                            break
            elif _filter == 'private-ip-address:':
                enis = [eni for eni in enis if eni.private_ip_address in _filter_value]
            elif _filter == 'subnet-id':
                enis = [eni for eni in enis if eni.subnet.id in _filter_value]
            elif _filter == 'description':
                enis = [eni for eni in enis if eni.description in _filter_value]
            else:
                self.raise_not_implemented_error(
                    "The filter '{0}' for DescribeNetworkInterfaces".format(_filter))
@@ -413,10 +426,10 @@ class Instance(TaggedEC2Resource, BotoInstance):
        self.instance_initiated_shutdown_behavior = kwargs.get("instance_initiated_shutdown_behavior", "stop")
        self.sriov_net_support = "simple"
        self._spot_fleet_id = kwargs.get("spot_fleet_id", None)
        self.associate_public_ip = kwargs.get("associate_public_ip", False)
        if in_ec2_classic:
            # If we are in EC2-Classic, autoassign a public IP
            self.associate_public_ip = True

        amis = self.ec2_backend.describe_images(filters={'image-id': image_id})
        ami = amis[0] if amis else None

@@ -447,9 +460,9 @@ class Instance(TaggedEC2Resource, BotoInstance):
                self.vpc_id = subnet.vpc_id
            self._placement.zone = subnet.availability_zone

            if self.associate_public_ip is None:
                # Mapping public ip hasn't been explicitly enabled or disabled
                self.associate_public_ip = subnet.map_public_ip_on_launch == 'true'
        elif placement:
            self._placement.zone = placement
        else:

@@ -461,7 +474,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
        self.prep_nics(
            kwargs.get("nics", {}),
            private_ip=kwargs.get("private_ip"),
            associate_public_ip=self.associate_public_ip
        )

    def __del__(self):
@@ -1076,7 +1089,7 @@ class TagBackend(object):

class Ami(TaggedEC2Resource):
    def __init__(self, ec2_backend, ami_id, instance=None, source_ami=None,
                 name=None, description=None, owner_id=OWNER_ID,
                 public=False, virtualization_type=None, architecture=None,
                 state='available', creation_date=None, platform=None,
                 image_type='machine', image_location=None, hypervisor=None,

@@ -1189,7 +1202,7 @@ class AmiBackend(object):
        ami = Ami(self, ami_id, instance=instance, source_ami=None,
                  name=name, description=description,
                  owner_id=context.get_current_user() if context else OWNER_ID)
        self.amis[ami_id] = ami
        return ami

@@ -1457,7 +1470,7 @@ class SecurityGroup(TaggedEC2Resource):
        self.egress_rules = [SecurityRule(-1, None, None, ['0.0.0.0/0'], [])]
        self.enis = {}
        self.vpc_id = vpc_id
        self.owner_id = OWNER_ID

    @classmethod
    def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):

@@ -1978,7 +1991,7 @@ class Volume(TaggedEC2Resource):

class Snapshot(TaggedEC2Resource):
    def __init__(self, ec2_backend, snapshot_id, volume, description, encrypted=False, owner_id=OWNER_ID):
        self.id = snapshot_id
        self.volume = volume
        self.description = description

@@ -2480,7 +2493,7 @@ class VPCPeeringConnectionBackend(object):

class Subnet(TaggedEC2Resource):
    def __init__(self, ec2_backend, subnet_id, vpc_id, cidr_block, availability_zone, default_for_az,
                 map_public_ip_on_launch, owner_id=OWNER_ID, assign_ipv6_address_on_creation=False):
        self.ec2_backend = ec2_backend
        self.id = subnet_id
        self.vpc_id = vpc_id

@@ -2646,7 +2659,7 @@ class SubnetBackend(object):
            raise InvalidAvailabilityZoneError(availability_zone, ", ".join([zone.name for zones in RegionsAndZonesBackend.zones.values() for zone in zones]))
        subnet = Subnet(self, subnet_id, vpc_id, cidr_block, availability_zone_data,
                        default_for_az, map_public_ip_on_launch,
                        owner_id=context.get_current_user() if context else OWNER_ID, assign_ipv6_address_on_creation=False)

        # AWS associates a new subnet with the default Network ACL
self.associate_default_network_acl_with_subnet(subnet_id, vpc_id) self.associate_default_network_acl_with_subnet(subnet_id, vpc_id)


@@ -10,9 +10,10 @@ class ElasticNetworkInterfaces(BaseResponse):
         private_ip_address = self._get_param('PrivateIpAddress')
         groups = self._get_multi_param('SecurityGroupId')
         subnet = self.ec2_backend.get_subnet(subnet_id)
+        description = self._get_param('Description')
         if self.is_not_dryrun('CreateNetworkInterface'):
             eni = self.ec2_backend.create_network_interface(
-                subnet, private_ip_address, groups)
+                subnet, private_ip_address, groups, description)
             template = self.response_template(
                 CREATE_NETWORK_INTERFACE_RESPONSE)
             return template.render(eni=eni)
@@ -78,7 +79,11 @@ CREATE_NETWORK_INTERFACE_RESPONSE = """
       <subnetId>{{ eni.subnet.id }}</subnetId>
       <vpcId>{{ eni.subnet.vpc_id }}</vpcId>
       <availabilityZone>us-west-2a</availabilityZone>
+      {% if eni.description %}
+      <description>{{ eni.description }}</description>
+      {% else %}
       <description/>
+      {% endif %}
       <ownerId>498654062920</ownerId>
       <requesterManaged>false</requesterManaged>
       <status>pending</status>
@@ -121,7 +126,7 @@ DESCRIBE_NETWORK_INTERFACES_RESPONSE = """<DescribeNetworkInterfacesResponse xml
         <subnetId>{{ eni.subnet.id }}</subnetId>
         <vpcId>{{ eni.subnet.vpc_id }}</vpcId>
         <availabilityZone>us-west-2a</availabilityZone>
-        <description>Primary network interface</description>
+        <description>{{ eni.description }}</description>
         <ownerId>190610284047</ownerId>
         <requesterManaged>false</requesterManaged>
         {% if eni.attachment_id %}


@@ -1,5 +1,7 @@
 from __future__ import unicode_literals
 from boto.ec2.instancetype import InstanceType
+from moto.autoscaling import autoscaling_backends
 from moto.core.responses import BaseResponse
 from moto.core.utils import camelcase_to_underscores
 from moto.ec2.utils import filters_from_querystring, \
@@ -65,6 +67,7 @@ class InstanceResponse(BaseResponse):
         instance_ids = self._get_multi_param('InstanceId')
         if self.is_not_dryrun('TerminateInstance'):
             instances = self.ec2_backend.terminate_instances(instance_ids)
+            autoscaling_backends[self.region].notify_terminate_instances(instance_ids)
             template = self.response_template(EC2_TERMINATE_INSTANCES)
             return template.render(instances=instances)
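The hunk above wires EC2 terminations into the autoscaling backend so scaling groups learn about lost capacity. A minimal standalone sketch of that notify-on-terminate pattern (the `AutoScalingBackend`/`EC2Backend` classes here are simplified stand-ins, not moto's actual implementations, and the notification is collapsed into `terminate_instances` for brevity):

```python
class AutoScalingBackend:
    """Tracks instance ids per scaling group."""

    def __init__(self):
        self.group_instances = {"asg-1": {"i-1", "i-2"}}

    def notify_terminate_instances(self, instance_ids):
        # Drop terminated instances from every group that tracked them.
        for instances in self.group_instances.values():
            instances.difference_update(instance_ids)


class EC2Backend:
    def __init__(self, autoscaling_backend):
        self.instances = {"i-1", "i-2", "i-3"}
        self.autoscaling_backend = autoscaling_backend

    def terminate_instances(self, instance_ids):
        terminated = [i for i in instance_ids if i in self.instances]
        self.instances.difference_update(terminated)
        # As in the diff: tell the autoscaling backend about the loss.
        self.autoscaling_backend.notify_terminate_instances(terminated)
        return terminated


asg = AutoScalingBackend()
ec2 = EC2Backend(asg)
ec2.terminate_instances(["i-1"])
```

The real change keeps the two backends decoupled by doing the notification in the response layer, which avoids an import cycle between the EC2 and autoscaling models.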


@@ -1,5 +1,5 @@
 from __future__ import unicode_literals
-from moto.core.exceptions import RESTError
+from moto.core.exceptions import RESTError, JsonRESTError


 class ServiceNotFoundException(RESTError):
@@ -11,3 +11,13 @@ class ServiceNotFoundException(RESTError):
             message="The service {0} does not exist".format(service_name),
             template='error_json',
         )
+
+
+class TaskDefinitionNotFoundException(JsonRESTError):
+    code = 400
+
+    def __init__(self):
+        super(TaskDefinitionNotFoundException, self).__init__(
+            error_type="ClientException",
+            message="The specified task definition does not exist.",
+        )


@@ -1,4 +1,5 @@
 from __future__ import unicode_literals
+import re
 import uuid
 from datetime import datetime
 from random import random, randint
@@ -7,10 +8,14 @@ import boto3
 import pytz
 from moto.core.exceptions import JsonRESTError
 from moto.core import BaseBackend, BaseModel
+from moto.core.utils import unix_time
 from moto.ec2 import ec2_backends
 from copy import copy

-from .exceptions import ServiceNotFoundException
+from .exceptions import (
+    ServiceNotFoundException,
+    TaskDefinitionNotFoundException
+)


 class BaseObject(BaseModel):
@@ -103,12 +108,13 @@ class Cluster(BaseObject):
 class TaskDefinition(BaseObject):

-    def __init__(self, family, revision, container_definitions, volumes=None):
+    def __init__(self, family, revision, container_definitions, volumes=None, tags=None):
         self.family = family
         self.revision = revision
         self.arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/{0}:{1}'.format(
             family, revision)
         self.container_definitions = container_definitions
+        self.tags = tags if tags is not None else []
         if volumes is None:
             self.volumes = []
         else:
@@ -119,6 +125,7 @@ class TaskDefinition(BaseObject):
         response_object = self.gen_response_object()
         response_object['taskDefinitionArn'] = response_object['arn']
         del response_object['arn']
+        del response_object['tags']
         return response_object

     @property
@@ -225,9 +232,9 @@ class Service(BaseObject):
         for deployment in response_object['deployments']:
             if isinstance(deployment['createdAt'], datetime):
-                deployment['createdAt'] = deployment['createdAt'].isoformat()
+                deployment['createdAt'] = unix_time(deployment['createdAt'].replace(tzinfo=None))
             if isinstance(deployment['updatedAt'], datetime):
-                deployment['updatedAt'] = deployment['updatedAt'].isoformat()
+                deployment['updatedAt'] = unix_time(deployment['updatedAt'].replace(tzinfo=None))
         return response_object
@@ -422,11 +429,9 @@ class EC2ContainerServiceBackend(BaseBackend):
             revision = int(revision)
         else:
             family = task_definition_name
-            revision = len(self.task_definitions.get(family, []))
+            revision = self._get_last_task_definition_revision_id(family)

-        if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]):
-            return self.task_definitions[family][revision - 1]
-        elif family in self.task_definitions and revision == -1:
+        if family in self.task_definitions and revision in self.task_definitions[family]:
             return self.task_definitions[family][revision]
         else:
             raise Exception(
@@ -466,15 +471,16 @@ class EC2ContainerServiceBackend(BaseBackend):
         else:
             raise Exception("{0} is not a cluster".format(cluster_name))

-    def register_task_definition(self, family, container_definitions, volumes):
+    def register_task_definition(self, family, container_definitions, volumes, tags=None):
         if family in self.task_definitions:
-            revision = len(self.task_definitions[family]) + 1
+            last_id = self._get_last_task_definition_revision_id(family)
+            revision = (last_id or 0) + 1
         else:
-            self.task_definitions[family] = []
+            self.task_definitions[family] = {}
             revision = 1
         task_definition = TaskDefinition(
-            family, revision, container_definitions, volumes)
-        self.task_definitions[family].append(task_definition)
+            family, revision, container_definitions, volumes, tags)
+        self.task_definitions[family][revision] = task_definition
         return task_definition
@@ -484,16 +490,18 @@ class EC2ContainerServiceBackend(BaseBackend):
         """
         task_arns = []
         for task_definition_list in self.task_definitions.values():
-            task_arns.extend(
-                [task_definition.arn for task_definition in task_definition_list])
+            task_arns.extend([
+                task_definition.arn
+                for task_definition in task_definition_list.values()
+            ])
         return task_arns

     def deregister_task_definition(self, task_definition_str):
         task_definition_name = task_definition_str.split('/')[-1]
         family, revision = task_definition_name.split(':')
         revision = int(revision)
-        if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]):
-            return self.task_definitions[family].pop(revision - 1)
+        if family in self.task_definitions and revision in self.task_definitions[family]:
+            return self.task_definitions[family].pop(revision)
         else:
             raise Exception(
                 "{0} is not a task_definition".format(task_definition_name))
@@ -950,6 +958,29 @@ class EC2ContainerServiceBackend(BaseBackend):
                 yield task_fam

+    def list_tags_for_resource(self, resource_arn):
+        """Currently only implemented for task definitions"""
+        match = re.match(
+            "^arn:aws:ecs:(?P<region>[^:]+):(?P<account_id>[^:]+):(?P<service>[^:]+)/(?P<id>.*)$",
+            resource_arn)
+        if not match:
+            raise JsonRESTError('InvalidParameterException', 'The ARN provided is invalid.')
+
+        service = match.group("service")
+        if service == "task-definition":
+            for task_definition in self.task_definitions.values():
+                for revision in task_definition.values():
+                    if revision.arn == resource_arn:
+                        return revision.tags
+            else:
+                raise TaskDefinitionNotFoundException()
+        raise NotImplementedError()
+
+    def _get_last_task_definition_revision_id(self, family):
+        definitions = self.task_definitions.get(family, {})
+        if definitions:
+            return max(definitions.keys())
+

 available_regions = boto3.session.Session().get_available_regions("ecs")
 ecs_backends = {region: EC2ContainerServiceBackend(region) for region in available_regions}
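The hunks above switch `task_definitions[family]` from a list to a dict keyed by revision number, so revision ids stay stable when an earlier revision is deregistered. A standalone sketch of that bookkeeping (plain functions with illustrative names, not moto's classes):

```python
# In-memory registry: family -> {revision number -> definition}.
task_definitions = {}


def last_revision(family):
    revisions = task_definitions.get(family, {})
    return max(revisions) if revisions else None


def register(family, definition):
    # Next revision is one past the highest ever issued, as in the diff.
    revision = (last_revision(family) or 0) + 1
    task_definitions.setdefault(family, {})[revision] = definition
    return revision


def deregister(family, revision):
    # Popping leaves a gap instead of renumbering later revisions.
    return task_definitions[family].pop(revision)


register("web", "containers-v1")   # becomes revision 1
register("web", "containers-v2")   # becomes revision 2
deregister("web", 1)
new_revision = register("web", "containers-v3")
```

With the old list-based scheme, deregistering revision 1 would have shifted indices and reissued revision numbers; the dict keeps `new_revision` at 3, matching ECS behavior.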


@@ -62,8 +62,9 @@ class EC2ContainerServiceResponse(BaseResponse):
         family = self._get_param('family')
         container_definitions = self._get_param('containerDefinitions')
         volumes = self._get_param('volumes')
+        tags = self._get_param('tags')
         task_definition = self.ecs_backend.register_task_definition(
-            family, container_definitions, volumes)
+            family, container_definitions, volumes, tags)
         return json.dumps({
             'taskDefinition': task_definition.response_object
         })
@@ -313,3 +314,8 @@ class EC2ContainerServiceResponse(BaseResponse):
         results = self.ecs_backend.list_task_definition_families(family_prefix, status, max_results, next_token)
         return json.dumps({'families': list(results)})
+
+    def list_tags_for_resource(self):
+        resource_arn = self._get_param('resourceArn')
+        tags = self.ecs_backend.list_tags_for_resource(resource_arn)
+        return json.dumps({'tags': tags})


@@ -2,9 +2,11 @@ from __future__ import unicode_literals
 import datetime
 import re

+from jinja2 import Template
 from moto.compat import OrderedDict
 from moto.core.exceptions import RESTError
 from moto.core import BaseBackend, BaseModel
+from moto.core.utils import camelcase_to_underscores
 from moto.ec2.models import ec2_backends
 from moto.acm.models import acm_backends
 from .utils import make_arn_for_target_group
@@ -35,12 +37,13 @@ from .exceptions import (
 class FakeHealthStatus(BaseModel):

-    def __init__(self, instance_id, port, health_port, status, reason=None):
+    def __init__(self, instance_id, port, health_port, status, reason=None, description=None):
         self.instance_id = instance_id
         self.port = port
         self.health_port = health_port
         self.status = status
         self.reason = reason
+        self.description = description


 class FakeTargetGroup(BaseModel):
@@ -69,7 +72,7 @@ class FakeTargetGroup(BaseModel):
         self.protocol = protocol
         self.port = port
         self.healthcheck_protocol = healthcheck_protocol or 'HTTP'
-        self.healthcheck_port = healthcheck_port or 'traffic-port'
+        self.healthcheck_port = healthcheck_port or str(self.port)
         self.healthcheck_path = healthcheck_path or '/'
         self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30
         self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5
@@ -112,10 +115,14 @@ class FakeTargetGroup(BaseModel):
             raise TooManyTagsError()
         self.tags[key] = value

-    def health_for(self, target):
+    def health_for(self, target, ec2_backend):
         t = self.targets.get(target['id'])
         if t is None:
             raise InvalidTargetError()
+        if t['id'].startswith("i-"):  # EC2 instance ID
+            instance = ec2_backend.get_instance_by_id(t['id'])
+            if instance.state == "stopped":
+                return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'unused', 'Target.InvalidState', 'Target is in the stopped state')
         return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'healthy')

     @classmethod
@@ -208,13 +215,12 @@ class FakeListener(BaseModel):
             action_type = action['Type']
             if action_type == 'forward':
                 default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']})
-            elif action_type == 'redirect':
-                redirect_action = {'type': action_type, }
-                for redirect_config_key, redirect_config_value in action['RedirectConfig'].items():
+            elif action_type in ['redirect', 'authenticate-cognito']:
+                redirect_action = {'type': action_type}
+                key = 'RedirectConfig' if action_type == 'redirect' else 'AuthenticateCognitoConfig'
+                for redirect_config_key, redirect_config_value in action[key].items():
                     # need to match the output of _get_list_prefix
-                    if redirect_config_key == 'StatusCode':
-                        redirect_config_key = 'status_code'
-                    redirect_action['redirect_config._' + redirect_config_key.lower()] = redirect_config_value
+                    redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value
                 default_actions.append(redirect_action)
             else:
                 raise InvalidActionTypeError(action_type, i + 1)
@@ -226,6 +232,32 @@ class FakeListener(BaseModel):
         return listener


+class FakeAction(BaseModel):
+    def __init__(self, data):
+        self.data = data
+        self.type = data.get("type")
+
+    def to_xml(self):
+        template = Template("""<Type>{{ action.type }}</Type>
+            {% if action.type == "forward" %}
+            <TargetGroupArn>{{ action.data["target_group_arn"] }}</TargetGroupArn>
+            {% elif action.type == "redirect" %}
+            <RedirectConfig>
+                <Protocol>{{ action.data["redirect_config._protocol"] }}</Protocol>
+                <Port>{{ action.data["redirect_config._port"] }}</Port>
+                <StatusCode>{{ action.data["redirect_config._status_code"] }}</StatusCode>
+            </RedirectConfig>
+            {% elif action.type == "authenticate-cognito" %}
+            <AuthenticateCognitoConfig>
+                <UserPoolArn>{{ action.data["authenticate_cognito_config._user_pool_arn"] }}</UserPoolArn>
+                <UserPoolClientId>{{ action.data["authenticate_cognito_config._user_pool_client_id"] }}</UserPoolClientId>
+                <UserPoolDomain>{{ action.data["authenticate_cognito_config._user_pool_domain"] }}</UserPoolDomain>
+            </AuthenticateCognitoConfig>
+            {% endif %}
+            """)
+        return template.render(action=self)
+
+
 class FakeRule(BaseModel):

     def __init__(self, listener_arn, conditions, priority, actions, is_default):
@@ -397,6 +429,7 @@ class ELBv2Backend(BaseBackend):
         return new_load_balancer

     def create_rule(self, listener_arn, conditions, priority, actions):
+        actions = [FakeAction(action) for action in actions]
         listeners = self.describe_listeners(None, [listener_arn])
         if not listeners:
             raise ListenerNotFoundError()
@@ -424,20 +457,7 @@ class ELBv2Backend(BaseBackend):
             if rule.priority == priority:
                 raise PriorityInUseError()

-        # validate Actions
-        target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
-        for i, action in enumerate(actions):
-            index = i + 1
-            action_type = action['type']
-            if action_type == 'forward':
-                action_target_group_arn = action['target_group_arn']
-                if action_target_group_arn not in target_group_arns:
-                    raise ActionTargetGroupNotFoundError(action_target_group_arn)
-            elif action_type == 'redirect':
-                # nothing to do
-                pass
-            else:
-                raise InvalidActionTypeError(action_type, index)
+        self._validate_actions(actions)

         # TODO: check for error 'TooManyRegistrationsForTargetId'
         # TODO: check for error 'TooManyRules'
@@ -447,6 +467,21 @@ class ELBv2Backend(BaseBackend):
         listener.register(rule)
         return [rule]

+    def _validate_actions(self, actions):
+        # validate Actions
+        target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
+        for i, action in enumerate(actions):
+            index = i + 1
+            action_type = action.type
+            if action_type == 'forward':
+                action_target_group_arn = action.data['target_group_arn']
+                if action_target_group_arn not in target_group_arns:
+                    raise ActionTargetGroupNotFoundError(action_target_group_arn)
+            elif action_type in ['redirect', 'authenticate-cognito']:
+                pass
+            else:
+                raise InvalidActionTypeError(action_type, index)
+
     def create_target_group(self, name, **kwargs):
         if len(name) > 32:
             raise InvalidTargetGroupNameError(
@@ -490,26 +525,22 @@ class ELBv2Backend(BaseBackend):
         return target_group

     def create_listener(self, load_balancer_arn, protocol, port, ssl_policy, certificate, default_actions):
+        default_actions = [FakeAction(action) for action in default_actions]
         balancer = self.load_balancers.get(load_balancer_arn)
         if balancer is None:
             raise LoadBalancerNotFoundError()
         if port in balancer.listeners:
             raise DuplicateListenerError()

+        self._validate_actions(default_actions)
+
         arn = load_balancer_arn.replace(':loadbalancer/', ':listener/') + "/%s%s" % (port, id(self))
         listener = FakeListener(load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions)
         balancer.listeners[listener.arn] = listener
-        for i, action in enumerate(default_actions):
-            action_type = action['type']
-            if action_type == 'forward':
-                if action['target_group_arn'] in self.target_groups.keys():
-                    target_group = self.target_groups[action['target_group_arn']]
-                    target_group.load_balancer_arns.append(load_balancer_arn)
-            elif action_type == 'redirect':
-                # nothing to do
-                pass
-            else:
-                raise InvalidActionTypeError(action_type, i + 1)
+        for action in default_actions:
+            if action.type == 'forward':
+                target_group = self.target_groups[action.data['target_group_arn']]
+                target_group.load_balancer_arns.append(load_balancer_arn)

         return listener
@@ -643,6 +674,7 @@ class ELBv2Backend(BaseBackend):
         raise ListenerNotFoundError()

     def modify_rule(self, rule_arn, conditions, actions):
+        actions = [FakeAction(action) for action in actions]
         # if conditions or actions is empty list, do not update the attributes
         if not conditions and not actions:
             raise InvalidModifyRuleArgumentsError()
@@ -668,20 +700,7 @@ class ELBv2Backend(BaseBackend):
             # TODO: check pattern of value for 'path-pattern'

         # validate Actions
-        target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
-        if actions:
-            for i, action in enumerate(actions):
-                index = i + 1
-                action_type = action['type']
-                if action_type == 'forward':
-                    action_target_group_arn = action['target_group_arn']
-                    if action_target_group_arn not in target_group_arns:
-                        raise ActionTargetGroupNotFoundError(action_target_group_arn)
-                elif action_type == 'redirect':
-                    # nothing to do
-                    pass
-                else:
-                    raise InvalidActionTypeError(action_type, index)
+        self._validate_actions(actions)

         # TODO: check for error 'TooManyRegistrationsForTargetId'
         # TODO: check for error 'TooManyRules'
@@ -712,7 +731,7 @@ class ELBv2Backend(BaseBackend):
         if not targets:
             targets = target_group.targets.values()
-        return [target_group.health_for(target) for target in targets]
+        return [target_group.health_for(target, self.ec2_backend) for target in targets]

     def set_rule_priorities(self, rule_priorities):
         # validate
@@ -846,6 +865,7 @@ class ELBv2Backend(BaseBackend):
         return target_group

     def modify_listener(self, arn, port=None, protocol=None, ssl_policy=None, certificates=None, default_actions=None):
+        default_actions = [FakeAction(action) for action in default_actions]
         for load_balancer in self.load_balancers.values():
             if arn in load_balancer.listeners:
                 break
@@ -912,7 +932,7 @@ class ELBv2Backend(BaseBackend):
         for listener in load_balancer.listeners.values():
             for rule in listener.rules:
                 for action in rule.actions:
-                    if action.get('target_group_arn') == target_group_arn:
+                    if action.data.get('target_group_arn') == target_group_arn:
                         return True
         return False
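The listener parsing above flattens each action's config into keys like `authenticate_cognito_config._user_pool_arn` via `camelcase_to_underscores`, so the parsed dict matches what `_get_list_prefix` produces for querystring input. A sketch of that flattening, with a simplified stand-in for moto's helper (moto's real `camelcase_to_underscores` lives in `moto.core.utils` and may handle edge cases differently):

```python
import re


def camelcase_to_underscores(name):
    # Simplified stand-in: insert "_" before each interior capital, lowercase.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


def flatten_action(action_type, config):
    # Pick the config key the way the parsing hunk above does.
    key = "RedirectConfig" if action_type == "redirect" else "AuthenticateCognitoConfig"
    action = {"type": action_type}
    for config_key, config_value in config.items():
        flat_key = camelcase_to_underscores(key) + "._" + camelcase_to_underscores(config_key)
        action[flat_key] = config_value
    return action


flattened = flatten_action(
    "authenticate-cognito",
    {
        "UserPoolArn": "arn:aws:cognito-idp:eu-west-1:123456789012:userpool/pool",
        "UserPoolDomain": "my-domain",
    },
)
```

Flattening both `redirect` and `authenticate-cognito` through one code path is what lets `FakeAction.to_xml()` render either action type from the same dict shape.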


@@ -775,16 +775,7 @@ CREATE_LISTENER_TEMPLATE = """<CreateListenerResponse xmlns="http://elasticloadb
       <DefaultActions>
         {% for action in listener.default_actions %}
         <member>
-          <Type>{{ action.type }}</Type>
-          {% if action["type"] == "forward" %}
-          <TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
-          {% elif action["type"] == "redirect" %}
-          <RedirectConfig>
-            <Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
-            <Port>{{ action["redirect_config._port"] }}</Port>
-            <StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
-          </RedirectConfig>
-          {% endif %}
+          {{ action.to_xml() }}
         </member>
         {% endfor %}
       </DefaultActions>
@@ -888,16 +879,7 @@ DESCRIBE_RULES_TEMPLATE = """<DescribeRulesResponse xmlns="http://elasticloadbal
           <Actions>
             {% for action in rule.actions %}
             <member>
-              <Type>{{ action["type"] }}</Type>
-              {% if action["type"] == "forward" %}
-              <TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
-              {% elif action["type"] == "redirect" %}
-              <RedirectConfig>
-                <Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
-                <Port>{{ action["redirect_config._port"] }}</Port>
-                <StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
-              </RedirectConfig>
-              {% endif %}
+              {{ action.to_xml() }}
             </member>
             {% endfor %}
           </Actions>
@@ -989,16 +971,7 @@ DESCRIBE_LISTENERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http://el
       <DefaultActions>
         {% for action in listener.default_actions %}
         <member>
-          <Type>{{ action.type }}</Type>
-          {% if action["type"] == "forward" %}
-          <TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>m
-          {% elif action["type"] == "redirect" %}
-          <RedirectConfig>
-            <Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
-            <Port>{{ action["redirect_config._port"] }}</Port>
-            <StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
-          </RedirectConfig>
-          {% endif %}
+          {{ action.to_xml() }}
         </member>
         {% endfor %}
       </DefaultActions>
@@ -1048,8 +1021,7 @@ MODIFY_RULE_TEMPLATE = """<ModifyRuleResponse xmlns="http://elasticloadbalancing
           <Actions>
             {% for action in rule.actions %}
             <member>
-              <Type>{{ action["type"] }}</Type>
-              <TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
+              {{ action.to_xml() }}
             </member>
             {% endfor %}
           </Actions>
@@ -1208,6 +1180,12 @@ DESCRIBE_TARGET_HEALTH_TEMPLATE = """<DescribeTargetHealthResponse xmlns="http:/
         <HealthCheckPort>{{ target_health.health_port }}</HealthCheckPort>
         <TargetHealth>
           <State>{{ target_health.status }}</State>
+          {% if target_health.reason %}
+          <Reason>{{ target_health.reason }}</Reason>
+          {% endif %}
+          {% if target_health.description %}
+          <Description>{{ target_health.description }}</Description>
+          {% endif %}
         </TargetHealth>
         <Target>
           <Port>{{ target_health.port }}</Port>
@@ -1426,16 +1404,7 @@ MODIFY_LISTENER_TEMPLATE = """<ModifyListenerResponse xmlns="http://elasticloadb
       <DefaultActions>
         {% for action in listener.default_actions %}
         <member>
-          <Type>{{ action.type }}</Type>
-          {% if action["type"] == "forward" %}
-          <TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
-          {% elif action["type"] == "redirect" %}
-          <RedirectConfig>
-            <Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
-            <Port>{{ action["redirect_config._port"] }}</Port>
-            <StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
-          </RedirectConfig>
-          {% endif %}
+          {{ action.to_xml() }}
         </member>
         {% endfor %}
       </DefaultActions>


@@ -141,6 +141,23 @@ class GlueResponse(BaseResponse):

         return json.dumps({'Partition': p.as_dict()})

+    def batch_get_partition(self):
+        database_name = self.parameters.get('DatabaseName')
+        table_name = self.parameters.get('TableName')
+        partitions_to_get = self.parameters.get('PartitionsToGet')
+
+        table = self.glue_backend.get_table(database_name, table_name)
+
+        partitions = []
+        for values in partitions_to_get:
+            try:
+                p = table.get_partition(values=values["Values"])
+                partitions.append(p.as_dict())
+            except PartitionNotFoundException:
+                continue
+
+        return json.dumps({'Partitions': partitions})
+
     def create_partition(self):
         database_name = self.parameters.get('DatabaseName')
         table_name = self.parameters.get('TableName')
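As in `batch_get_partition` above, partitions that cannot be found are skipped rather than failing the whole batch, mirroring the Glue `BatchGetPartition` API (which reports misses separately instead of erroring). A standalone sketch of that lookup pattern (the `FakeTable` class is an illustrative stand-in, not moto's Glue model):

```python
class PartitionNotFoundException(Exception):
    pass


class FakeTable:
    def __init__(self, partitions):
        # Maps a tuple of partition values to the partition's dict form.
        self._partitions = partitions

    def get_partition(self, values):
        try:
            return self._partitions[tuple(values)]
        except KeyError:
            raise PartitionNotFoundException(values)


def batch_get_partition(table, partitions_to_get):
    found = []
    for request in partitions_to_get:
        try:
            found.append(table.get_partition(values=request["Values"]))
        except PartitionNotFoundException:
            continue  # skip missing partitions instead of failing the batch
    return found


table = FakeTable({("2019", "08"): {"values": ["2019", "08"]}})
result = batch_get_partition(
    table,
    [{"Values": ["2019", "08"]}, {"Values": ["2019", "09"]}],
)
```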
@ -694,7 +694,6 @@ class IAMBackend(BaseBackend):
def _validate_tag_key(self, tag_key, exception_param='tags.X.member.key'): def _validate_tag_key(self, tag_key, exception_param='tags.X.member.key'):
"""Validates the tag key. """Validates the tag key.
:param all_tags: Dict to check if there is a duplicate tag.
:param tag_key: The tag key to check against. :param tag_key: The tag key to check against.
:param exception_param: The exception parameter to send over to help format the message. This is to reflect :param exception_param: The exception parameter to send over to help format the message. This is to reflect
the difference between the tag and untag APIs. the difference between the tag and untag APIs.
@ -3,7 +3,7 @@ from __future__ import unicode_literals
import os import os
import boto.kms import boto.kms
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds, unix_time from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import generate_key_id from .utils import generate_key_id
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -11,7 +11,7 @@ from datetime import datetime, timedelta
class Key(BaseModel): class Key(BaseModel):
def __init__(self, policy, key_usage, description, region): def __init__(self, policy, key_usage, description, tags, region):
self.id = generate_key_id() self.id = generate_key_id()
self.policy = policy self.policy = policy
self.key_usage = key_usage self.key_usage = key_usage
@ -22,7 +22,7 @@ class Key(BaseModel):
self.account_id = "0123456789012" self.account_id = "0123456789012"
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date = None
self.tags = {} self.tags = tags or {}
@property @property
def physical_resource_id(self): def physical_resource_id(self):
@ -37,7 +37,7 @@ class Key(BaseModel):
"KeyMetadata": { "KeyMetadata": {
"AWSAccountId": self.account_id, "AWSAccountId": self.account_id,
"Arn": self.arn, "Arn": self.arn,
"CreationDate": "%d" % unix_time(), "CreationDate": iso_8601_datetime_without_milliseconds(datetime.now()),
"Description": self.description, "Description": self.description,
"Enabled": self.enabled, "Enabled": self.enabled,
"KeyId": self.id, "KeyId": self.id,
@ -61,6 +61,7 @@ class Key(BaseModel):
policy=properties['KeyPolicy'], policy=properties['KeyPolicy'],
key_usage='ENCRYPT_DECRYPT', key_usage='ENCRYPT_DECRYPT',
description=properties['Description'], description=properties['Description'],
tags=properties.get('Tags'),
region=region_name, region=region_name,
) )
key.key_rotation_status = properties['EnableKeyRotation'] key.key_rotation_status = properties['EnableKeyRotation']
@ -80,8 +81,8 @@ class KmsBackend(BaseBackend):
self.keys = {} self.keys = {}
self.key_to_aliases = defaultdict(set) self.key_to_aliases = defaultdict(set)
def create_key(self, policy, key_usage, description, region): def create_key(self, policy, key_usage, description, tags, region):
key = Key(policy, key_usage, description, region) key = Key(policy, key_usage, description, tags, region)
self.keys[key.id] = key self.keys[key.id] = key
return key return key
@ -31,9 +31,10 @@ class KmsResponse(BaseResponse):
policy = self.parameters.get('Policy') policy = self.parameters.get('Policy')
key_usage = self.parameters.get('KeyUsage') key_usage = self.parameters.get('KeyUsage')
description = self.parameters.get('Description') description = self.parameters.get('Description')
tags = self.parameters.get('Tags')
key = self.kms_backend.create_key( key = self.kms_backend.create_key(
policy, key_usage, description, self.region) policy, key_usage, description, tags, self.region)
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def update_key_description(self): def update_key_description(self):
@ -237,7 +238,7 @@ class KmsResponse(BaseResponse):
value = self.parameters.get("CiphertextBlob") value = self.parameters.get("CiphertextBlob")
try: try:
return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8")}) return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8"), 'KeyId': 'key_id'})
except UnicodeDecodeError: except UnicodeDecodeError:
# Generate data key will produce random bytes which when decrypted is still returned as base64 # Generate data key will produce random bytes which when decrypted is still returned as base64
return json.dumps({"Plaintext": value}) return json.dumps({"Plaintext": value})
@ -98,17 +98,29 @@ class LogStream:
return True return True
def get_paging_token_from_index(index, back=False):
if index is not None:
return "b/{:056d}".format(index) if back else "f/{:056d}".format(index)
return 0
def get_index_from_paging_token(token):
if token is not None:
return int(token[2:])
return 0
events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head) events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head)
back_token = next_token next_index = get_index_from_paging_token(next_token)
if next_token is None: back_index = next_index
next_token = 0
events_page = [event.to_response_dict() for event in events[next_token: next_token + limit]] events_page = [event.to_response_dict() for event in events[next_index: next_index + limit]]
next_token += limit if next_index + limit < len(self.events):
if next_token >= len(self.events): next_index += limit
next_token = None
return events_page, back_token, next_token back_index -= limit
if back_index <= 0:
back_index = 0
return events_page, get_paging_token_from_index(back_index, True), get_paging_token_from_index(next_index)
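The opaque-token scheme introduced above can be exercised in isolation: a token is a direction prefix (`f/` forward, `b/` backward) followed by the event index zero-padded to 56 digits, echoing the shape of real CloudWatch Logs paging tokens. The helpers below are copied from the hunk above:

```python
def get_paging_token_from_index(index, back=False):
    # Encode an event index as an opaque forward/backward token.
    if index is not None:
        return "b/{:056d}".format(index) if back else "f/{:056d}".format(index)
    return 0


def get_index_from_paging_token(token):
    # Strip the two-character "f/" or "b/" prefix and recover the index.
    if token is not None:
        return int(token[2:])
    return 0


forward = get_paging_token_from_index(42)
backward = get_paging_token_from_index(42, back=True)
```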
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved): def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
def filter_func(event): def filter_func(event):
@ -2,6 +2,7 @@ from __future__ import unicode_literals
import datetime import datetime
import re import re
import json
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
@ -151,7 +152,6 @@ class FakeRoot(FakeOrganizationalUnit):
class FakeServiceControlPolicy(BaseModel): class FakeServiceControlPolicy(BaseModel):
def __init__(self, organization, **kwargs): def __init__(self, organization, **kwargs):
self.type = 'POLICY'
self.content = kwargs.get('Content') self.content = kwargs.get('Content')
self.description = kwargs.get('Description') self.description = kwargs.get('Description')
self.name = kwargs.get('Name') self.name = kwargs.get('Name')
@ -197,7 +197,38 @@ class OrganizationsBackend(BaseBackend):
def create_organization(self, **kwargs): def create_organization(self, **kwargs):
self.org = FakeOrganization(kwargs['FeatureSet']) self.org = FakeOrganization(kwargs['FeatureSet'])
self.ou.append(FakeRoot(self.org)) root_ou = FakeRoot(self.org)
self.ou.append(root_ou)
master_account = FakeAccount(
self.org,
AccountName='master',
Email=self.org.master_account_email,
)
master_account.id = self.org.master_account_id
self.accounts.append(master_account)
default_policy = FakeServiceControlPolicy(
self.org,
Name='FullAWSAccess',
Description='Allows access to every operation',
Type='SERVICE_CONTROL_POLICY',
Content=json.dumps(
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "*",
"Resource": "*"
}
]
}
)
)
default_policy.id = utils.DEFAULT_POLICY_ID
default_policy.aws_managed = True
self.policies.append(default_policy)
self.attach_policy(PolicyId=default_policy.id, TargetId=root_ou.id)
self.attach_policy(PolicyId=default_policy.id, TargetId=master_account.id)
return self.org.describe() return self.org.describe()
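The `create_organization` flow above now seeds three things: a root OU, a master account, and an AWS-managed `FullAWSAccess` service control policy attached to both. A minimal sketch of that bookkeeping (the `Target` class is a simplified stand-in for moto's OU/account models; the policy document is taken verbatim from the hunk above):

```python
import json

DEFAULT_POLICY_ID = "p-FullAWSAccess"

# The default policy document: allow every action on every resource.
FULL_AWS_ACCESS = {
    "Version": "2012-10-17",
    "Statement": [{"Effect": "Allow", "Action": "*", "Resource": "*"}],
}


class Target:
    """Stand-in for an OU or account that policies can attach to."""
    def __init__(self, target_id):
        self.id = target_id
        self.attached_policies = []


def attach_policy(target, policy_id):
    # Idempotent attach, as in the backend's attach_policy.
    if policy_id not in target.attached_policies:
        target.attached_policies.append(policy_id)


root_ou = Target("r-examplerootid1")
master_account = Target("123456789012")
for target in (root_ou, master_account):
    attach_policy(target, DEFAULT_POLICY_ID)
```

Newly created OUs and accounts get the same treatment, which is why `create_organizational_unit` and `create_account` now also call `attach_policy` with `utils.DEFAULT_POLICY_ID`.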
def describe_organization(self): def describe_organization(self):
@ -216,6 +247,7 @@ class OrganizationsBackend(BaseBackend):
def create_organizational_unit(self, **kwargs): def create_organizational_unit(self, **kwargs):
new_ou = FakeOrganizationalUnit(self.org, **kwargs) new_ou = FakeOrganizationalUnit(self.org, **kwargs)
self.ou.append(new_ou) self.ou.append(new_ou)
self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_ou.id)
return new_ou.describe() return new_ou.describe()
def get_organizational_unit_by_id(self, ou_id): def get_organizational_unit_by_id(self, ou_id):
@ -258,6 +290,7 @@ class OrganizationsBackend(BaseBackend):
def create_account(self, **kwargs): def create_account(self, **kwargs):
new_account = FakeAccount(self.org, **kwargs) new_account = FakeAccount(self.org, **kwargs)
self.accounts.append(new_account) self.accounts.append(new_account)
self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_account.id)
return new_account.create_account_status return new_account.create_account_status
def get_account_by_id(self, account_id): def get_account_by_id(self, account_id):
@ -358,8 +391,7 @@ class OrganizationsBackend(BaseBackend):
def attach_policy(self, **kwargs): def attach_policy(self, **kwargs):
policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None) policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None)
if (re.compile(utils.ROOT_ID_REGEX).match(kwargs['TargetId']) or if (re.compile(utils.ROOT_ID_REGEX).match(kwargs['TargetId']) or re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId'])):
re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId'])):
ou = next((ou for ou in self.ou if ou.id == kwargs['TargetId']), None) ou = next((ou for ou in self.ou if ou.id == kwargs['TargetId']), None)
if ou is not None: if ou is not None:
if ou not in ou.attached_policies: if ou not in ou.attached_policies:
@ -4,7 +4,8 @@ import random
import string import string
MASTER_ACCOUNT_ID = '123456789012' MASTER_ACCOUNT_ID = '123456789012'
MASTER_ACCOUNT_EMAIL = 'fakeorg@moto-example.com' MASTER_ACCOUNT_EMAIL = 'master@example.com'
DEFAULT_POLICY_ID = 'p-FullAWSAccess'
ORGANIZATION_ARN_FORMAT = 'arn:aws:organizations::{0}:organization/{1}' ORGANIZATION_ARN_FORMAT = 'arn:aws:organizations::{0}:organization/{1}'
MASTER_ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{0}' MASTER_ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{0}'
ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{2}' ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{2}'
@ -26,7 +27,7 @@ ROOT_ID_REGEX = r'r-[a-z0-9]{%s}' % ROOT_ID_SIZE
OU_ID_REGEX = r'ou-[a-z0-9]{%s}-[a-z0-9]{%s}' % (ROOT_ID_SIZE, OU_ID_SUFFIX_SIZE) OU_ID_REGEX = r'ou-[a-z0-9]{%s}-[a-z0-9]{%s}' % (ROOT_ID_SIZE, OU_ID_SUFFIX_SIZE)
ACCOUNT_ID_REGEX = r'[0-9]{%s}' % ACCOUNT_ID_SIZE ACCOUNT_ID_REGEX = r'[0-9]{%s}' % ACCOUNT_ID_SIZE
CREATE_ACCOUNT_STATUS_ID_REGEX = r'car-[a-z0-9]{%s}' % CREATE_ACCOUNT_STATUS_ID_SIZE CREATE_ACCOUNT_STATUS_ID_REGEX = r'car-[a-z0-9]{%s}' % CREATE_ACCOUNT_STATUS_ID_SIZE
SCP_ID_REGEX = r'p-[a-z0-9]{%s}' % SCP_ID_SIZE SCP_ID_REGEX = r'%s|p-[a-z0-9]{%s}' % (DEFAULT_POLICY_ID, SCP_ID_SIZE)
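The widened `SCP_ID_REGEX` above now accepts the literal default policy id alongside the randomly generated `p-...` ids. A quick check of the alternation (`SCP_ID_SIZE = 8` is an assumption here, matching the apparent length of generated suffixes):

```python
import re

SCP_ID_SIZE = 8  # assumed suffix length for generated policy ids
DEFAULT_POLICY_ID = "p-FullAWSAccess"
# Literal default id OR a lowercase-alphanumeric generated id.
SCP_ID_REGEX = r"%s|p-[a-z0-9]{%s}" % (DEFAULT_POLICY_ID, SCP_ID_SIZE)
```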
def make_random_org_id(): def make_random_org_id():
@ -95,7 +95,7 @@ class RDSResponse(BaseResponse):
start = all_ids.index(marker) + 1 start = all_ids.index(marker) + 1
else: else:
start = 0 start = 0
page_size = self._get_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier page_size = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
instances_resp = all_instances[start:start + page_size] instances_resp = all_instances[start:start + page_size]
next_marker = None next_marker = None
if len(all_instances) > start + page_size: if len(all_instances) > start + page_size:
@ -149,7 +149,14 @@ class Database(BaseModel):
<DBInstanceStatus>{{ database.status }}</DBInstanceStatus> <DBInstanceStatus>{{ database.status }}</DBInstanceStatus>
{% if database.db_name %}<DBName>{{ database.db_name }}</DBName>{% endif %} {% if database.db_name %}<DBName>{{ database.db_name }}</DBName>{% endif %}
<MultiAZ>{{ database.multi_az }}</MultiAZ> <MultiAZ>{{ database.multi_az }}</MultiAZ>
<VpcSecurityGroups/> <VpcSecurityGroups>
{% for vpc_security_group_id in database.vpc_security_group_ids %}
<VpcSecurityGroupMembership>
<Status>active</Status>
<VpcSecurityGroupId>{{ vpc_security_group_id }}</VpcSecurityGroupId>
</VpcSecurityGroupMembership>
{% endfor %}
</VpcSecurityGroups>
<DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier> <DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier>
<DbiResourceId>{{ database.dbi_resource_id }}</DbiResourceId> <DbiResourceId>{{ database.dbi_resource_id }}</DbiResourceId>
<InstanceCreateTime>{{ database.instance_create_time }}</InstanceCreateTime> <InstanceCreateTime>{{ database.instance_create_time }}</InstanceCreateTime>
@ -323,6 +330,7 @@ class Database(BaseModel):
"storage_encrypted": properties.get("StorageEncrypted"), "storage_encrypted": properties.get("StorageEncrypted"),
"storage_type": properties.get("StorageType"), "storage_type": properties.get("StorageType"),
"tags": properties.get("Tags"), "tags": properties.get("Tags"),
"vpc_security_group_ids": properties.get('VpcSecurityGroupIds', []),
} }
rds2_backend = rds2_backends[region_name] rds2_backend = rds2_backends[region_name]
@ -397,10 +405,12 @@ class Database(BaseModel):
"SecondaryAvailabilityZone": null, "SecondaryAvailabilityZone": null,
"StatusInfos": null, "StatusInfos": null,
"VpcSecurityGroups": [ "VpcSecurityGroups": [
{% for vpc_security_group_id in database.vpc_security_group_ids %}
{ {
"Status": "active", "Status": "active",
"VpcSecurityGroupId": "sg-123456" "VpcSecurityGroupId": "{{ vpc_security_group_id }}"
} }
{% endfor %}
], ],
"DBInstanceArn": "{{ database.db_instance_arn }}" "DBInstanceArn": "{{ database.db_instance_arn }}"
}""") }""")
@ -43,7 +43,7 @@ class RDS2Response(BaseResponse):
"security_groups": self._get_multi_param('DBSecurityGroups.DBSecurityGroupName'), "security_groups": self._get_multi_param('DBSecurityGroups.DBSecurityGroupName'),
"storage_encrypted": self._get_param("StorageEncrypted"), "storage_encrypted": self._get_param("StorageEncrypted"),
"storage_type": self._get_param("StorageType", 'standard'), "storage_type": self._get_param("StorageType", 'standard'),
# VpcSecurityGroupIds.member.N "vpc_security_group_ids": self._get_multi_param("VpcSecurityGroupIds.VpcSecurityGroupId"),
"tags": list(), "tags": list(),
} }
args['tags'] = self.unpack_complex_list_params( args['tags'] = self.unpack_complex_list_params(
@ -280,7 +280,7 @@ class RDS2Response(BaseResponse):
def describe_option_groups(self): def describe_option_groups(self):
kwargs = self._get_option_group_kwargs() kwargs = self._get_option_group_kwargs()
kwargs['max_records'] = self._get_param('MaxRecords') kwargs['max_records'] = self._get_int_param('MaxRecords')
kwargs['marker'] = self._get_param('Marker') kwargs['marker'] = self._get_param('Marker')
option_groups = self.backend.describe_option_groups(kwargs) option_groups = self.backend.describe_option_groups(kwargs)
template = self.response_template(DESCRIBE_OPTION_GROUP_TEMPLATE) template = self.response_template(DESCRIBE_OPTION_GROUP_TEMPLATE)
@ -329,7 +329,7 @@ class RDS2Response(BaseResponse):
def describe_db_parameter_groups(self): def describe_db_parameter_groups(self):
kwargs = self._get_db_parameter_group_kwargs() kwargs = self._get_db_parameter_group_kwargs()
kwargs['max_records'] = self._get_param('MaxRecords') kwargs['max_records'] = self._get_int_param('MaxRecords')
kwargs['marker'] = self._get_param('Marker') kwargs['marker'] = self._get_param('Marker')
db_parameter_groups = self.backend.describe_db_parameter_groups(kwargs) db_parameter_groups = self.backend.describe_db_parameter_groups(kwargs)
template = self.response_template( template = self.response_template(
@ -78,7 +78,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
super(Cluster, self).__init__(region_name, tags) super(Cluster, self).__init__(region_name, tags)
self.redshift_backend = redshift_backend self.redshift_backend = redshift_backend
self.cluster_identifier = cluster_identifier self.cluster_identifier = cluster_identifier
self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow())
self.status = 'available' self.status = 'available'
self.node_type = node_type self.node_type = node_type
self.master_username = master_username self.master_username = master_username
@ -10,6 +10,7 @@ from moto.ec2 import ec2_backends
from moto.elb import elb_backends from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends from moto.elbv2 import elbv2_backends
from moto.kinesis import kinesis_backends from moto.kinesis import kinesis_backends
from moto.kms import kms_backends
from moto.rds2 import rds2_backends from moto.rds2 import rds2_backends
from moto.glacier import glacier_backends from moto.glacier import glacier_backends
from moto.redshift import redshift_backends from moto.redshift import redshift_backends
@ -71,6 +72,13 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
""" """
return kinesis_backends[self.region_name] return kinesis_backends[self.region_name]
@property
def kms_backend(self):
"""
:rtype: moto.kms.models.KmsBackend
"""
return kms_backends[self.region_name]
@property @property
def rds_backend(self): def rds_backend(self):
""" """
@ -221,9 +229,6 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
if not resource_type_filters or 'elasticloadbalancer' in resource_type_filters or 'elasticloadbalancer:loadbalancer' in resource_type_filters: if not resource_type_filters or 'elasticloadbalancer' in resource_type_filters or 'elasticloadbalancer:loadbalancer' in resource_type_filters:
for elb in self.elbv2_backend.load_balancers.values(): for elb in self.elbv2_backend.load_balancers.values():
tags = get_elbv2_tags(elb.arn) tags = get_elbv2_tags(elb.arn)
# if 'elasticloadbalancer:loadbalancer' in resource_type_filters:
# from IPython import embed
# embed()
if not tag_filter(tags): # Skip if no tags, or invalid filter if not tag_filter(tags): # Skip if no tags, or invalid filter
continue continue
@ -235,6 +240,21 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
# Kinesis # Kinesis
# KMS
def get_kms_tags(kms_key_id):
result = []
for tag in self.kms_backend.list_resource_tags(kms_key_id):
result.append({'Key': tag['TagKey'], 'Value': tag['TagValue']})
return result
if not resource_type_filters or 'kms' in resource_type_filters:
for kms_key in self.kms_backend.list_keys():
tags = get_kms_tags(kms_key.id)
if not tag_filter(tags): # Skip if no tags, or invalid filter
continue
yield {'ResourceARN': '{0}'.format(kms_key.arn), 'Tags': tags}
# RDS Instance # RDS Instance
# RDS Reserved Database Instance # RDS Reserved Database Instance
# RDS Option Group # RDS Option Group
@ -370,7 +390,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
def get_resources(self, pagination_token=None, def get_resources(self, pagination_token=None,
resources_per_page=50, tags_per_page=100, resources_per_page=50, tags_per_page=100,
tag_filters=None, resource_type_filters=None): tag_filters=None, resource_type_filters=None):
# Simple range checning # Simple range checking
if 100 >= tags_per_page >= 500: if 100 >= tags_per_page >= 500:
raise RESTError('InvalidParameterException', 'TagsPerPage must be between 100 and 500') raise RESTError('InvalidParameterException', 'TagsPerPage must be between 100 and 500')
if 1 >= resources_per_page >= 50: if 1 >= resources_per_page >= 50:
@ -198,7 +198,7 @@ class FakeZone(BaseModel):
def upsert_rrset(self, record_set): def upsert_rrset(self, record_set):
new_rrset = RecordSet(record_set) new_rrset = RecordSet(record_set)
for i, rrset in enumerate(self.rrsets): for i, rrset in enumerate(self.rrsets):
if rrset.name == new_rrset.name and rrset.type_ == new_rrset.type_: if rrset.name == new_rrset.name and rrset.type_ == new_rrset.type_ and rrset.set_identifier == new_rrset.set_identifier:
self.rrsets[i] = new_rrset self.rrsets[i] = new_rrset
break break
else: else:
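The one-line change above fixes upserts for weighted/failover records: a record set is replaced only when name, type, AND `set_identifier` all match; otherwise a new one is appended. A self-contained sketch (the `RecordSet` class is a stand-in for moto's model):

```python
class RecordSet:
    def __init__(self, name, type_, set_identifier=None):
        self.name = name
        self.type_ = type_
        self.set_identifier = set_identifier


def upsert_rrset(rrsets, new_rrset):
    for i, rrset in enumerate(rrsets):
        if (rrset.name == new_rrset.name
                and rrset.type_ == new_rrset.type_
                and rrset.set_identifier == new_rrset.set_identifier):
            rrsets[i] = new_rrset  # replace the matching record set in place
            break
    else:
        rrsets.append(new_rrset)  # no match: add a new record set


rrsets = [RecordSet("db.example.com", "A", "primary")]
# Different set_identifier: appended as a second (failover) record set.
upsert_rrset(rrsets, RecordSet("db.example.com", "A", "secondary"))
# Same (name, type, set_identifier) triple: replaced, not duplicated.
upsert_rrset(rrsets, RecordSet("db.example.com", "A", "primary"))
```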
@ -60,6 +60,17 @@ class MissingKey(S3ClientError):
) )
class ObjectNotInActiveTierError(S3ClientError):
code = 403
def __init__(self, key_name):
super(ObjectNotInActiveTierError, self).__init__(
"ObjectNotInActiveTierError",
"The source object of the COPY operation is not in the active tier and is only stored in Amazon Glacier.",
Key=key_name,
)
class InvalidPartOrder(S3ClientError): class InvalidPartOrder(S3ClientError):
code = 400 code = 400
@ -28,7 +28,8 @@ MAX_BUCKET_NAME_LENGTH = 63
MIN_BUCKET_NAME_LENGTH = 3 MIN_BUCKET_NAME_LENGTH = 3
UPLOAD_ID_BYTES = 43 UPLOAD_ID_BYTES = 43
UPLOAD_PART_MIN_SIZE = 5242880 UPLOAD_PART_MIN_SIZE = 5242880
STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA"] STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA",
"INTELLIGENT_TIERING", "GLACIER", "DEEP_ARCHIVE"]
DEFAULT_KEY_BUFFER_SIZE = 16 * 1024 * 1024 DEFAULT_KEY_BUFFER_SIZE = 16 * 1024 * 1024
DEFAULT_TEXT_ENCODING = sys.getdefaultencoding() DEFAULT_TEXT_ENCODING = sys.getdefaultencoding()
@ -52,8 +53,17 @@ class FakeDeleteMarker(BaseModel):
class FakeKey(BaseModel): class FakeKey(BaseModel):
def __init__(self, name, value, storage="STANDARD", etag=None, is_versioned=False, version_id=0, def __init__(
max_buffer_size=DEFAULT_KEY_BUFFER_SIZE): self,
name,
value,
storage="STANDARD",
etag=None,
is_versioned=False,
version_id=0,
max_buffer_size=DEFAULT_KEY_BUFFER_SIZE,
multipart=None
):
self.name = name self.name = name
self.last_modified = datetime.datetime.utcnow() self.last_modified = datetime.datetime.utcnow()
self.acl = get_canned_acl('private') self.acl = get_canned_acl('private')
@ -65,6 +75,7 @@ class FakeKey(BaseModel):
self._version_id = version_id self._version_id = version_id
self._is_versioned = is_versioned self._is_versioned = is_versioned
self._tagging = FakeTagging() self._tagging = FakeTagging()
self.multipart = multipart
self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size) self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size)
self._max_buffer_size = max_buffer_size self._max_buffer_size = max_buffer_size
@ -754,7 +765,7 @@ class S3Backend(BaseBackend):
prefix=''): prefix=''):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
if any((delimiter, encoding_type, key_marker, version_id_marker)): if any((delimiter, key_marker, version_id_marker)):
raise NotImplementedError( raise NotImplementedError(
"Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker") "Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker")
@ -782,7 +793,15 @@ class S3Backend(BaseBackend):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
return bucket.website_configuration return bucket.website_configuration
def set_key(self, bucket_name, key_name, value, storage=None, etag=None): def set_key(
self,
bucket_name,
key_name,
value,
storage=None,
etag=None,
multipart=None,
):
key_name = clean_key_name(key_name) key_name = clean_key_name(key_name)
if storage is not None and storage not in STORAGE_CLASS: if storage is not None and storage not in STORAGE_CLASS:
raise InvalidStorageClass(storage=storage) raise InvalidStorageClass(storage=storage)
@ -795,7 +814,9 @@ class S3Backend(BaseBackend):
storage=storage, storage=storage,
etag=etag, etag=etag,
is_versioned=bucket.is_versioned, is_versioned=bucket.is_versioned,
version_id=str(uuid.uuid4()) if bucket.is_versioned else None) version_id=str(uuid.uuid4()) if bucket.is_versioned else None,
multipart=multipart,
)
keys = [ keys = [
key for key in bucket.keys.getlist(key_name, []) key for key in bucket.keys.getlist(key_name, [])
@ -812,7 +833,7 @@ class S3Backend(BaseBackend):
key.append_to_value(value) key.append_to_value(value)
return key return key
def get_key(self, bucket_name, key_name, version_id=None): def get_key(self, bucket_name, key_name, version_id=None, part_number=None):
key_name = clean_key_name(key_name) key_name = clean_key_name(key_name)
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
key = None key = None
@ -827,6 +848,9 @@ class S3Backend(BaseBackend):
key = key_version key = key_version
break break
if part_number and key.multipart:
key = key.multipart.parts[part_number]
if isinstance(key, FakeKey): if isinstance(key, FakeKey):
return key return key
else: else:
@ -890,7 +914,12 @@ class S3Backend(BaseBackend):
return return
del bucket.multiparts[multipart_id] del bucket.multiparts[multipart_id]
key = self.set_key(bucket_name, multipart.key_name, value, etag=etag) key = self.set_key(
bucket_name,
multipart.key_name,
value, etag=etag,
multipart=multipart
)
key.set_metadata(multipart.metadata) key.set_metadata(multipart.metadata)
return key return key
@ -17,7 +17,7 @@ from moto.s3bucket_path.utils import bucket_name_from_url as bucketpath_bucket_n
parse_key_name as bucketpath_parse_key_name, is_delete_keys as bucketpath_is_delete_keys parse_key_name as bucketpath_parse_key_name, is_delete_keys as bucketpath_is_delete_keys
from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, MissingKey, InvalidPartOrder, MalformedXML, \ from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, MissingKey, InvalidPartOrder, MalformedXML, \
MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent, ObjectNotInActiveTierError
from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \ from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \
FakeTag FakeTag
from .utils import bucket_name_from_url, clean_key_name, metadata_from_headers, parse_region_from_url from .utils import bucket_name_from_url, clean_key_name, metadata_from_headers, parse_region_from_url
@ -686,6 +686,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
keys = minidom.parseString(body).getElementsByTagName('Key') keys = minidom.parseString(body).getElementsByTagName('Key')
deleted_names = [] deleted_names = []
error_names = [] error_names = []
if len(keys) == 0:
raise MalformedXML()
for k in keys: for k in keys:
key_name = k.firstChild.nodeValue key_name = k.firstChild.nodeValue
@ -900,7 +902,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
src_version_id = parse_qs(src_key_parsed.query).get( src_version_id = parse_qs(src_key_parsed.query).get(
'versionId', [None])[0] 'versionId', [None])[0]
if self.backend.get_key(src_bucket, src_key, version_id=src_version_id): key = self.backend.get_key(src_bucket, src_key, version_id=src_version_id)
if key is not None:
if key.storage_class in ["GLACIER", "DEEP_ARCHIVE"]:
raise ObjectNotInActiveTierError(key)
self.backend.copy_key(src_bucket, src_key, bucket_name, key_name, self.backend.copy_key(src_bucket, src_key, bucket_name, key_name,
storage=storage_class, acl=acl, src_version_id=src_version_id) storage=storage_class, acl=acl, src_version_id=src_version_id)
else: else:
@ -940,13 +946,20 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def _key_response_head(self, bucket_name, query, key_name, headers): def _key_response_head(self, bucket_name, query, key_name, headers):
response_headers = {} response_headers = {}
version_id = query.get('versionId', [None])[0] version_id = query.get('versionId', [None])[0]
part_number = query.get('partNumber', [None])[0]
if part_number:
part_number = int(part_number)
if_modified_since = headers.get('If-Modified-Since', None) if_modified_since = headers.get('If-Modified-Since', None)
if if_modified_since: if if_modified_since:
if_modified_since = str_to_rfc_1123_datetime(if_modified_since) if_modified_since = str_to_rfc_1123_datetime(if_modified_since)
key = self.backend.get_key( key = self.backend.get_key(
bucket_name, key_name, version_id=version_id) bucket_name,
key_name,
version_id=version_id,
part_number=part_number
)
if key: if key:
response_headers.update(key.metadata) response_headers.update(key.metadata)
response_headers.update(key.response_dict) response_headers.update(key.response_dict)
@ -21,6 +21,16 @@ from moto.core.utils import convert_flask_to_httpretty_response
HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"] HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"]
DEFAULT_SERVICE_REGION = ('s3', 'us-east-1')
# Map of unsigned calls to service-region as per AWS API docs
# https://docs.aws.amazon.com/cognito/latest/developerguide/resource-permissions.html#amazon-cognito-signed-versus-unsigned-apis
UNSIGNED_REQUESTS = {
'AWSCognitoIdentityService': ('cognito-identity', 'us-east-1'),
'AWSCognitoIdentityProviderService': ('cognito-idp', 'us-east-1'),
}
class DomainDispatcherApplication(object): class DomainDispatcherApplication(object):
""" """
Dispatch requests to different applications based on the "Host:" header Dispatch requests to different applications based on the "Host:" header
@ -48,7 +58,45 @@ class DomainDispatcherApplication(object):
if re.match(url_base, 'http://%s' % host): if re.match(url_base, 'http://%s' % host):
return backend_name return backend_name
raise RuntimeError('Invalid host: "%s"' % host) def infer_service_region_host(self, environ):
auth = environ.get('HTTP_AUTHORIZATION')
if auth:
# Signed request
# Parse auth header to find service assuming a SigV4 request
# https://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html
# ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request']
try:
credential_scope = auth.split(",")[0].split()[1]
_, _, region, service, _ = credential_scope.split("/")
except ValueError:
# Signature format does not match, this is exceptional and we can't
# infer a service-region. A reduced set of services still use
# the deprecated SigV2, ergo prefer S3 as most likely default.
# https://docs.aws.amazon.com/general/latest/gr/signature-version-2.html
service, region = DEFAULT_SERVICE_REGION
else:
# Unsigned request
target = environ.get('HTTP_X_AMZ_TARGET')
if target:
service, _ = target.split('.', 1)
service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION)
else:
# S3 is the last resort when the target is also unknown
service, region = DEFAULT_SERVICE_REGION
if service == 'dynamodb':
if environ['HTTP_X_AMZ_TARGET'].startswith('DynamoDBStreams'):
host = 'dynamodbstreams'
else:
dynamo_api_version = environ['HTTP_X_AMZ_TARGET'].split("_")[1].split(".")[0]
# If Newer API version, use dynamodb2
if dynamo_api_version > "20111205":
host = "dynamodb2"
else:
host = "{service}.{region}.amazonaws.com".format(
service=service, region=region)
return host
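The core of `infer_service_region_host` above is parsing the SigV4 `Authorization` header, whose credential scope has the form `Credential=<key>/<date>/<region>/<service>/aws4_request` (per the AWS SigV4 signing docs). A standalone sketch of that parse, including the S3 fallback for malformed or non-SigV4 headers:

```python
DEFAULT_SERVICE_REGION = ("s3", "us-east-1")


def infer_service_region(authorization):
    # Split the credential scope out of a SigV4 Authorization header;
    # anything that doesn't match the expected shape falls back to S3.
    try:
        credential_scope = authorization.split(",")[0].split()[1]
        _, _, region, service, _ = credential_scope.split("/")
        return service, region
    except (AttributeError, IndexError, ValueError):
        return DEFAULT_SERVICE_REGION


auth = ("AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20190821/eu-west-1/sns/aws4_request,"
        " SignedHeaders=host;x-amz-date, Signature=abc123")
```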
def get_application(self, environ): def get_application(self, environ):
path_info = environ.get('PATH_INFO', '') path_info = environ.get('PATH_INFO', '')
@ -65,34 +113,14 @@ class DomainDispatcherApplication(object):
host = "instance_metadata" host = "instance_metadata"
else: else:
host = environ['HTTP_HOST'].split(':')[0] host = environ['HTTP_HOST'].split(':')[0]
if host in {'localhost', 'motoserver'} or host.startswith("192.168."):
# Fall back to parsing auth header to find service
# ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request']
try:
_, _, region, service, _ = environ['HTTP_AUTHORIZATION'].split(",")[0].split()[
1].split("/")
except (KeyError, ValueError):
# Some cognito-idp endpoints (e.g. change password) do not receive an auth header.
if environ.get('HTTP_X_AMZ_TARGET', '').startswith('AWSCognitoIdentityProviderService'):
service = 'cognito-idp'
else:
service = 's3'
region = 'us-east-1'
if service == 'dynamodb':
if environ['HTTP_X_AMZ_TARGET'].startswith('DynamoDBStreams'):
host = 'dynamodbstreams'
else:
dynamo_api_version = environ['HTTP_X_AMZ_TARGET'].split("_")[1].split(".")[0]
# If Newer API version, use dynamodb2
if dynamo_api_version > "20111205":
host = "dynamodb2"
else:
host = "{service}.{region}.amazonaws.com".format(
service=service, region=region)
        with self.lock:
            backend = self.get_backend_for_host(host)
if not backend:
# No regular backend found; try parsing other headers
host = self.infer_service_region_host(environ)
backend = self.get_backend_for_host(host)
            app = self.app_instances.get(backend, None)
            if app is None:
                app = self.create_app(backend)
@@ -379,6 +379,7 @@ class SQSBackend(BaseBackend):
    def reset(self):
        region_name = self.region_name
        self._reset_model_refs()
        self.__dict__ = {}
        self.__init__(region_name)
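The reset pattern above is worth spelling out: wiping `__dict__` and re-running `__init__` resets a backend's state *in place*, so any module-level references to the backend object keep pointing at a fresh instance. A self-contained sketch of the idiom (simplified `Backend` class, not moto's actual `SQSBackend`):

```python
# In-place reset: clear all instance state, then re-initialize the same
# object so external references remain valid.
class Backend:
    def __init__(self, region_name):
        self.region_name = region_name
        self.queues = {}

    def reset(self):
        region_name = self.region_name  # preserve construction arguments
        self.__dict__ = {}              # drop every attribute
        self.__init__(region_name)      # rebuild pristine state

backend = Backend('us-east-1')
alias = backend                 # e.g. a module-level singleton reference
backend.queues['q'] = object()  # accumulate some state
backend.reset()                 # alias still sees the freshly reset backend
```

The `_reset_model_refs()` call added in the diff runs before the wipe so that model objects holding back-references to the backend are detached first.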
moto/sts/exceptions.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from __future__ import unicode_literals
from moto.core.exceptions import RESTError
class STSClientError(RESTError):
code = 400
class STSValidationError(STSClientError):
def __init__(self, *args, **kwargs):
super(STSValidationError, self).__init__(
"ValidationError",
*args, **kwargs
)
@@ -65,5 +65,8 @@ class STSBackend(BaseBackend):
            return assumed_role
        return None

    def assume_role_with_web_identity(self, **kwargs):
        return self.assume_role(**kwargs)

sts_backend = STSBackend()
@@ -3,8 +3,11 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.iam.models import ACCOUNT_ID
from moto.iam import iam_backend
from .exceptions import STSValidationError
from .models import sts_backend

MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048


class TokenResponse(BaseResponse):
@@ -17,6 +20,15 @@ class TokenResponse(BaseResponse):
    def get_federation_token(self):
        duration = int(self.querystring.get('DurationSeconds', [43200])[0])
        policy = self.querystring.get('Policy', [None])[0]

        if policy is not None and len(policy) > MAX_FEDERATION_TOKEN_POLICY_LENGTH:
            raise STSValidationError(
                "1 validation error detected: Value "
                "'{\"Version\": \"2012-10-17\", \"Statement\": [...]}' "
                "at 'policy' failed to satisfy constraint: Member must have length less than or "
                "equal to %s" % MAX_FEDERATION_TOKEN_POLICY_LENGTH
            )

        name = self.querystring.get('Name')[0]
        token = sts_backend.get_federation_token(
            duration=duration, name=name, policy=policy)
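The guard added above rejects an over-long policy document before any token is issued. A standalone sketch of that validation (plain `ValueError` standing in for the `STSValidationError` wrapper):

```python
# Validate a federation-token policy's length before issuing credentials,
# mirroring the 2048-character limit enforced in the diff above.
MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048

def validate_federation_token_policy(policy):
    if policy is not None and len(policy) > MAX_FEDERATION_TOKEN_POLICY_LENGTH:
        raise ValueError(
            "1 validation error detected: Value at 'policy' failed to satisfy "
            "constraint: Member must have length less than or equal to %s"
            % MAX_FEDERATION_TOKEN_POLICY_LENGTH
        )
    return policy
```

Note the check runs only when a policy was actually supplied; `Policy` is optional, so `None` passes through untouched.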
@@ -41,6 +53,24 @@ class TokenResponse(BaseResponse):
        template = self.response_template(ASSUME_ROLE_RESPONSE)
        return template.render(role=role)
def assume_role_with_web_identity(self):
role_session_name = self.querystring.get('RoleSessionName')[0]
role_arn = self.querystring.get('RoleArn')[0]
policy = self.querystring.get('Policy', [None])[0]
duration = int(self.querystring.get('DurationSeconds', [3600])[0])
external_id = self.querystring.get('ExternalId', [None])[0]
role = sts_backend.assume_role_with_web_identity(
role_session_name=role_session_name,
role_arn=role_arn,
policy=policy,
duration=duration,
external_id=external_id,
)
template = self.response_template(ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE)
return template.render(role=role)
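The handler above leans on a quirk worth noting: a parsed query string maps every key to a *list* of values, which is why each lookup defaults to a one-element list and then indexes `[0]`. A sketch with the standard library's `parse_qs` (standalone, not moto's `BaseResponse`):

```python
# parse_qs maps every key to a list of values, so a one-element default
# list keeps the trailing [0] uniform for present and absent parameters.
from urllib.parse import parse_qs

querystring = parse_qs(
    "RoleArn=arn%3Aaws%3Aiam%3A%3A123%3Arole%2Fdemo&RoleSessionName=s1"
)

role_arn = querystring.get('RoleArn')[0]                          # required
duration = int(querystring.get('DurationSeconds', [3600])[0])     # defaulted
policy = querystring.get('Policy', [None])[0]                     # optional
```

This is the same shape as `self.querystring.get('DurationSeconds', [3600])[0]` in the diff: required parameters index directly, optional ones wrap their default in a list.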
    def get_caller_identity(self):
        template = self.response_template(GET_CALLER_IDENTITY_RESPONSE)
@@ -118,6 +148,27 @@ ASSUME_ROLE_RESPONSE = """<AssumeRoleResponse xmlns="https://sts.amazonaws.com/d
  </ResponseMetadata>
</AssumeRoleResponse>"""
ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE = """<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<Credentials>
<SessionToken>{{ role.session_token }}</SessionToken>
<SecretAccessKey>{{ role.secret_access_key }}</SecretAccessKey>
<Expiration>{{ role.expiration_ISO8601 }}</Expiration>
<AccessKeyId>{{ role.access_key_id }}</AccessKeyId>
</Credentials>
<AssumedRoleUser>
<Arn>{{ role.arn }}</Arn>
<AssumedRoleId>ARO123EXAMPLE123:{{ role.session_name }}</AssumedRoleId>
</AssumedRoleUser>
<PackedPolicySize>6</PackedPolicySize>
</AssumeRoleWithWebIdentityResult>
<ResponseMetadata>
<RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId>
</ResponseMetadata>
</AssumeRoleWithWebIdentityResponse>"""
GET_CALLER_IDENTITY_RESPONSE = """<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
  <GetCallerIdentityResult>
    <Arn>{{ arn }}</Arn>
@@ -30,10 +30,9 @@ def get_version():
install_requires = [
    "Jinja2>=2.10.1",
    "boto>=2.36.0",
-   "boto3>=1.9.86",
+   "boto3>=1.9.201",
-   "botocore>=1.12.86",
+   "botocore>=1.12.201",
    "cryptography>=2.3.0",
-   "datetime",
    "requests>=2.5",
    "xmltodict",
    "six>1.9",
@@ -48,7 +47,7 @@ install_requires = [
    "aws-xray-sdk!=0.96,>=0.93",
    "responses>=0.9.0",
    "idna<2.9,>=2.5",
-   "cfn-lint",
+   "cfn-lint>=0.4.0",
    "sshpubkeys>=3.1.0,<4.0"
]
@@ -89,7 +88,6 @@ setup(
    "Programming Language :: Python :: 2",
    "Programming Language :: Python :: 2.7",
    "Programming Language :: Python :: 3",
-   "Programming Language :: Python :: 3.4",
    "Programming Language :: Python :: 3.5",
    "Programming Language :: Python :: 3.6",
    "Programming Language :: Python :: 3.7",
@@ -988,13 +988,30 @@ def test_api_keys():
    apikey['name'].should.equal(apikey_name)
    len(apikey['value']).should.equal(40)

    apikey_name = 'TESTKEY3'
    payload = {'name': apikey_name}
    response = client.create_api_key(**payload)
    apikey_id = response['id']

    patch_operations = [
        {'op': 'replace', 'path': '/name', 'value': 'TESTKEY3_CHANGE'},
        {'op': 'replace', 'path': '/customerId', 'value': '12345'},
        {'op': 'replace', 'path': '/description', 'value': 'APIKEY UPDATE TEST'},
        {'op': 'replace', 'path': '/enabled', 'value': 'false'},
    ]
    response = client.update_api_key(apiKey=apikey_id, patchOperations=patch_operations)
    response['name'].should.equal('TESTKEY3_CHANGE')
    response['customerId'].should.equal('12345')
    response['description'].should.equal('APIKEY UPDATE TEST')
    response['enabled'].should.equal(False)

    response = client.get_api_keys()
-   len(response['items']).should.equal(2)
+   len(response['items']).should.equal(3)

    client.delete_api_key(apiKey=apikey_id)

    response = client.get_api_keys()
-   len(response['items']).should.equal(1)
+   len(response['items']).should.equal(2)
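The `patchOperations` list in the test above follows API Gateway's JSON-Patch-like contract: each `replace` op names a `/path` into the resource and a string `value`. A minimal sketch of how such operations map onto an API key record (hypothetical helper, not moto's actual `update_api_key` implementation):

```python
# Apply 'replace' patch operations to a flat resource dict, coercing the
# string 'true'/'false' that API Gateway uses for the 'enabled' flag.
def apply_patch_operations(resource, patch_operations):
    for op in patch_operations:
        if op['op'] == 'replace':
            key = op['path'].lstrip('/')
            value = op['value']
            if key == 'enabled':
                value = value == 'true'  # booleans arrive as strings
            resource[key] = value
    return resource

api_key = {'name': 'TESTKEY3', 'enabled': True}
apply_patch_operations(api_key, [
    {'op': 'replace', 'path': '/name', 'value': 'TESTKEY3_CHANGE'},
    {'op': 'replace', 'path': '/enabled', 'value': 'false'},
])
```

The string-to-boolean coercion is why the test sends `'value': 'false'` but asserts `response['enabled'].should.equal(False)`.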
@mock_apigateway
def test_usage_plans():
@@ -7,11 +7,13 @@ from boto.ec2.autoscale.group import AutoScalingGroup
from boto.ec2.autoscale import Tag
import boto.ec2.elb
import sure  # noqa
from botocore.exceptions import ClientError
from nose.tools import assert_raises

from moto import mock_autoscaling, mock_ec2_deprecated, mock_elb_deprecated, mock_elb, mock_autoscaling_deprecated, mock_ec2
from tests.helpers import requires_boto_gte
-from utils import setup_networking, setup_networking_deprecated
+from utils import setup_networking, setup_networking_deprecated, setup_instance_with_networking
@@ -724,6 +726,67 @@ def test_create_autoscaling_group_boto3():
    response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
@mock_autoscaling
def test_create_autoscaling_group_from_instance():
autoscaling_group_name = 'test_asg'
image_id = 'ami-0cc293023f983ed53'
instance_type = 't2.micro'
mocked_instance_with_networking = setup_instance_with_networking(image_id, instance_type)
client = boto3.client('autoscaling', region_name='us-east-1')
response = client.create_auto_scaling_group(
AutoScalingGroupName=autoscaling_group_name,
InstanceId=mocked_instance_with_networking['instance'],
MinSize=1,
MaxSize=3,
DesiredCapacity=2,
Tags=[
{'ResourceId': 'test_asg',
'ResourceType': 'auto-scaling-group',
'Key': 'propogated-tag-key',
'Value': 'propogate-tag-value',
'PropagateAtLaunch': True
},
{'ResourceId': 'test_asg',
'ResourceType': 'auto-scaling-group',
'Key': 'not-propogated-tag-key',
'Value': 'not-propogate-tag-value',
'PropagateAtLaunch': False
}],
VPCZoneIdentifier=mocked_instance_with_networking['subnet1'],
NewInstancesProtectedFromScaleIn=False,
)
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
describe_launch_configurations_response = client.describe_launch_configurations()
describe_launch_configurations_response['LaunchConfigurations'].should.have.length_of(1)
launch_configuration_from_instance = describe_launch_configurations_response['LaunchConfigurations'][0]
launch_configuration_from_instance['LaunchConfigurationName'].should.equal('test_asg')
launch_configuration_from_instance['ImageId'].should.equal(image_id)
launch_configuration_from_instance['InstanceType'].should.equal(instance_type)
@mock_autoscaling
def test_create_autoscaling_group_from_invalid_instance_id():
invalid_instance_id = 'invalid_instance'
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
with assert_raises(ClientError) as ex:
client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
InstanceId=invalid_instance_id,
MinSize=9,
MaxSize=15,
DesiredCapacity=12,
VPCZoneIdentifier=mocked_networking['subnet1'],
NewInstancesProtectedFromScaleIn=False,
)
ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400)
ex.exception.response['Error']['Code'].should.equal('ValidationError')
ex.exception.response['Error']['Message'].should.equal('Instance [{0}] is invalid.'.format(invalid_instance_id))
@mock_autoscaling
def test_describe_autoscaling_groups_boto3():
    mocked_networking = setup_networking()
@@ -823,6 +886,62 @@ def test_update_autoscaling_group_boto3():
    group['NewInstancesProtectedFromScaleIn'].should.equal(False)
@mock_autoscaling
def test_update_autoscaling_group_min_size_desired_capacity_change():
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
client.create_launch_configuration(
LaunchConfigurationName='test_launch_configuration'
)
client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
LaunchConfigurationName='test_launch_configuration',
MinSize=2,
MaxSize=20,
DesiredCapacity=3,
VPCZoneIdentifier=mocked_networking['subnet1'],
)
client.update_auto_scaling_group(
AutoScalingGroupName='test_asg',
MinSize=5,
)
response = client.describe_auto_scaling_groups(
AutoScalingGroupNames=['test_asg'])
group = response['AutoScalingGroups'][0]
group['DesiredCapacity'].should.equal(5)
group['MinSize'].should.equal(5)
group['Instances'].should.have.length_of(5)
@mock_autoscaling
def test_update_autoscaling_group_max_size_desired_capacity_change():
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
client.create_launch_configuration(
LaunchConfigurationName='test_launch_configuration'
)
client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
LaunchConfigurationName='test_launch_configuration',
MinSize=2,
MaxSize=20,
DesiredCapacity=10,
VPCZoneIdentifier=mocked_networking['subnet1'],
)
client.update_auto_scaling_group(
AutoScalingGroupName='test_asg',
MaxSize=5,
)
response = client.describe_auto_scaling_groups(
AutoScalingGroupNames=['test_asg'])
group = response['AutoScalingGroups'][0]
group['DesiredCapacity'].should.equal(5)
group['MaxSize'].should.equal(5)
group['Instances'].should.have.length_of(5)
@mock_autoscaling
def test_autoscaling_taqs_update_boto3():
    mocked_networking = setup_networking()
@@ -1269,3 +1388,36 @@ def test_set_desired_capacity_down_boto3():
    instance_ids = {instance['InstanceId'] for instance in group['Instances']}
    set(protected).should.equal(instance_ids)
    set(unprotected).should_not.be.within(instance_ids)  # only unprotected killed
@mock_autoscaling
@mock_ec2
def test_terminate_instance_in_autoscaling_group():
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
_ = client.create_launch_configuration(
LaunchConfigurationName='test_launch_configuration'
)
_ = client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
LaunchConfigurationName='test_launch_configuration',
MinSize=1,
MaxSize=20,
VPCZoneIdentifier=mocked_networking['subnet1'],
NewInstancesProtectedFromScaleIn=False
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg'])
original_instance_id = next(
instance['InstanceId']
for instance in response['AutoScalingGroups'][0]['Instances']
)
ec2_client = boto3.client('ec2', region_name='us-east-1')
ec2_client.terminate_instances(InstanceIds=[original_instance_id])
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg'])
replaced_instance_id = next(
instance['InstanceId']
for instance in response['AutoScalingGroups'][0]['Instances']
)
replaced_instance_id.should_not.equal(original_instance_id)
@@ -31,3 +31,18 @@ def setup_networking_deprecated():
        "10.11.2.0/24",
        availability_zone='us-east-1b')
    return {'vpc': vpc.id, 'subnet1': subnet1.id, 'subnet2': subnet2.id}
@mock_ec2
def setup_instance_with_networking(image_id, instance_type):
mock_data = setup_networking()
ec2 = boto3.resource('ec2', region_name='us-east-1')
instances = ec2.create_instances(
ImageId=image_id,
InstanceType=instance_type,
MaxCount=1,
MinCount=1,
SubnetId=mock_data['subnet1']
)
mock_data['instance'] = instances[0].id
return mock_data
@@ -642,6 +642,87 @@ def test_describe_task_definition():
    len(resp['jobDefinitions']).should.equal(3)
@mock_logs
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_submit_job_by_name():
ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
compute_name = 'test_compute_env'
resp = batch_client.create_compute_environment(
computeEnvironmentName=compute_name,
type='UNMANAGED',
state='ENABLED',
serviceRole=iam_arn
)
arn = resp['computeEnvironmentArn']
resp = batch_client.create_job_queue(
jobQueueName='test_job_queue',
state='ENABLED',
priority=123,
computeEnvironmentOrder=[
{
'order': 123,
'computeEnvironment': arn
},
]
)
queue_arn = resp['jobQueueArn']
job_definition_name = 'sleep10'
batch_client.register_job_definition(
jobDefinitionName=job_definition_name,
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 128,
'command': ['sleep', '10']
}
)
batch_client.register_job_definition(
jobDefinitionName=job_definition_name,
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 256,
'command': ['sleep', '10']
}
)
resp = batch_client.register_job_definition(
jobDefinitionName=job_definition_name,
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 512,
'command': ['sleep', '10']
}
)
job_definition_arn = resp['jobDefinitionArn']
resp = batch_client.submit_job(
jobName='test1',
jobQueue=queue_arn,
jobDefinition=job_definition_name
)
job_id = resp['jobId']
resp_jobs = batch_client.describe_jobs(jobs=[job_id])
# batch_client.terminate_job(jobId=job_id)
len(resp_jobs['jobs']).should.equal(1)
resp_jobs['jobs'][0]['jobId'].should.equal(job_id)
resp_jobs['jobs'][0]['jobQueue'].should.equal(queue_arn)
resp_jobs['jobs'][0]['jobDefinition'].should.equal(job_definition_arn)
# SLOW TESTS

@expected_failure
@mock_logs
@@ -68,7 +68,7 @@ def test_get_open_id_token_for_developer_identity():
        },
        TokenDuration=123
    )
-   assert len(result['Token'])
+   assert len(result['Token']) > 0
    assert result['IdentityId'] == '12345'

@mock_cognitoidentity
@@ -83,3 +83,15 @@ def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id()
    )
    assert len(result['Token']) > 0
    assert len(result['IdentityId']) > 0
@mock_cognitoidentity
def test_get_open_id_token():
conn = boto3.client('cognito-identity', 'us-west-2')
result = conn.get_open_id_token(
IdentityId='12345',
Logins={
'someurl': '12345'
}
)
assert len(result['Token']) > 0
assert result['IdentityId'] == '12345'
@@ -133,6 +133,22 @@ def test_create_user_pool_domain():
    result["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
@mock_cognitoidp
def test_create_user_pool_domain_custom_domain_config():
conn = boto3.client("cognito-idp", "us-west-2")
domain = str(uuid.uuid4())
custom_domain_config = {
"CertificateArn": "arn:aws:acm:us-east-1:123456789012:certificate/123456789012",
}
user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"]
result = conn.create_user_pool_domain(
UserPoolId=user_pool_id, Domain=domain, CustomDomainConfig=custom_domain_config
)
result["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
result["CloudFrontDomain"].should.equal("e2c343b3293ee505.cloudfront.net")
@mock_cognitoidp
def test_describe_user_pool_domain():
    conn = boto3.client("cognito-idp", "us-west-2")
@@ -162,6 +178,23 @@ def test_delete_user_pool_domain():
    result["DomainDescription"].keys().should.have.length_of(0)
@mock_cognitoidp
def test_update_user_pool_domain():
conn = boto3.client("cognito-idp", "us-west-2")
domain = str(uuid.uuid4())
custom_domain_config = {
"CertificateArn": "arn:aws:acm:us-east-1:123456789012:certificate/123456789012",
}
user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"]
conn.create_user_pool_domain(UserPoolId=user_pool_id, Domain=domain)
result = conn.update_user_pool_domain(
UserPoolId=user_pool_id, Domain=domain, CustomDomainConfig=custom_domain_config
)
result["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
result["CloudFrontDomain"].should.equal("e2c343b3293ee505.cloudfront.net")
@mock_cognitoidp
def test_create_user_pool_client():
    conn = boto3.client("cognito-idp", "us-west-2")
@@ -123,6 +123,526 @@ def test_put_configuration_recorder():
    assert "maximum number of configuration recorders: 1 is reached." in ce.exception.response['Error']['Message']
@mock_config
def test_put_configuration_aggregator():
client = boto3.client('config', region_name='us-west-2')
# With too many aggregation sources:
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
'111111111111',
'222222222222'
],
'AwsRegions': [
'us-east-1',
'us-west-2'
]
},
{
'AccountIds': [
'012345678910',
'111111111111',
'222222222222'
],
'AwsRegions': [
'us-east-1',
'us-west-2'
]
}
]
)
assert 'Member must have length less than or equal to 1' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# With an invalid region config (no regions defined):
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
'111111111111',
'222222222222'
],
'AllAwsRegions': False
}
]
)
assert 'Your request does not specify any regions' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException'
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
OrganizationAggregationSource={
'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole'
}
)
assert 'Your request does not specify any regions' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException'
# With both region flags defined:
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
'111111111111',
'222222222222'
],
'AwsRegions': [
'us-east-1',
'us-west-2'
],
'AllAwsRegions': True
}
]
)
assert 'You must choose one of these options' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException'
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
OrganizationAggregationSource={
'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole',
'AwsRegions': [
'us-east-1',
'us-west-2'
],
'AllAwsRegions': True
}
)
assert 'You must choose one of these options' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException'
# Name too long:
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='a' * 257,
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
],
'AllAwsRegions': True
}
]
)
assert 'configurationAggregatorName' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# Too many tags (>50):
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
],
'AllAwsRegions': True
}
],
Tags=[{'Key': '{}'.format(x), 'Value': '{}'.format(x)} for x in range(0, 51)]
)
assert 'Member must have length less than or equal to 50' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# Tag key is too big (>128 chars):
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
],
'AllAwsRegions': True
}
],
Tags=[{'Key': 'a' * 129, 'Value': 'a'}]
)
assert 'Member must have length less than or equal to 128' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# Tag value is too big (>256 chars):
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
],
'AllAwsRegions': True
}
],
Tags=[{'Key': 'tag', 'Value': 'a' * 257}]
)
assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# Duplicate Tags:
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
],
'AllAwsRegions': True
}
],
Tags=[{'Key': 'a', 'Value': 'a'}, {'Key': 'a', 'Value': 'a'}]
)
assert 'Duplicate tag keys found.' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidInput'
# Invalid characters in the tag key:
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
],
'AllAwsRegions': True
}
],
Tags=[{'Key': '!', 'Value': 'a'}]
)
assert 'Member must satisfy regular expression pattern:' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# If it contains both the AccountAggregationSources and the OrganizationAggregationSource
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
],
'AllAwsRegions': False
}
],
OrganizationAggregationSource={
'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole',
'AllAwsRegions': False
}
)
assert 'AccountAggregationSource and the OrganizationAggregationSource' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException'
# If it contains neither:
with assert_raises(ClientError) as ce:
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
)
assert 'AccountAggregationSource or the OrganizationAggregationSource' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException'
# Just make one:
account_aggregation_source = {
'AccountIds': [
'012345678910',
'111111111111',
'222222222222'
],
'AwsRegions': [
'us-east-1',
'us-west-2'
],
'AllAwsRegions': False
}
result = client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[account_aggregation_source],
)
assert result['ConfigurationAggregator']['ConfigurationAggregatorName'] == 'testing'
assert result['ConfigurationAggregator']['AccountAggregationSources'] == [account_aggregation_source]
assert 'arn:aws:config:us-west-2:123456789012:config-aggregator/config-aggregator-' in \
result['ConfigurationAggregator']['ConfigurationAggregatorArn']
assert result['ConfigurationAggregator']['CreationTime'] == result['ConfigurationAggregator']['LastUpdatedTime']
# Update the existing one:
original_arn = result['ConfigurationAggregator']['ConfigurationAggregatorArn']
account_aggregation_source.pop('AwsRegions')
account_aggregation_source['AllAwsRegions'] = True
result = client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[account_aggregation_source]
)
assert result['ConfigurationAggregator']['ConfigurationAggregatorName'] == 'testing'
assert result['ConfigurationAggregator']['AccountAggregationSources'] == [account_aggregation_source]
assert result['ConfigurationAggregator']['ConfigurationAggregatorArn'] == original_arn
# Make an org one:
result = client.put_configuration_aggregator(
ConfigurationAggregatorName='testingOrg',
OrganizationAggregationSource={
'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole',
'AwsRegions': ['us-east-1', 'us-west-2']
}
)
assert result['ConfigurationAggregator']['ConfigurationAggregatorName'] == 'testingOrg'
assert result['ConfigurationAggregator']['OrganizationAggregationSource'] == {
'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole',
'AwsRegions': [
'us-east-1',
'us-west-2'
],
'AllAwsRegions': False
}
@mock_config
def test_describe_configuration_aggregators():
client = boto3.client('config', region_name='us-west-2')
# Without any config aggregators:
assert not client.describe_configuration_aggregators()['ConfigurationAggregators']
# Make 10 config aggregators:
for x in range(0, 10):
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing{}'.format(x),
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
],
'AllAwsRegions': True
}
]
)
# Describe with an incorrect name:
with assert_raises(ClientError) as ce:
client.describe_configuration_aggregators(ConfigurationAggregatorNames=['DoesNotExist'])
assert 'The configuration aggregator does not exist.' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationAggregatorException'
# Error describe with more than 1 item in the list:
with assert_raises(ClientError) as ce:
client.describe_configuration_aggregators(ConfigurationAggregatorNames=['testing0', 'DoesNotExist'])
assert 'At least one of the configuration aggregators does not exist.' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationAggregatorException'
# Get the normal list:
result = client.describe_configuration_aggregators()
assert not result.get('NextToken')
assert len(result['ConfigurationAggregators']) == 10
# Test filtered list:
agg_names = ['testing0', 'testing1', 'testing2']
result = client.describe_configuration_aggregators(ConfigurationAggregatorNames=agg_names)
assert not result.get('NextToken')
assert len(result['ConfigurationAggregators']) == 3
assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == agg_names
# Test Pagination:
result = client.describe_configuration_aggregators(Limit=4)
assert len(result['ConfigurationAggregators']) == 4
assert result['NextToken'] == 'testing4'
assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == \
['testing{}'.format(x) for x in range(0, 4)]
result = client.describe_configuration_aggregators(Limit=4, NextToken='testing4')
assert len(result['ConfigurationAggregators']) == 4
assert result['NextToken'] == 'testing8'
assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == \
['testing{}'.format(x) for x in range(4, 8)]
result = client.describe_configuration_aggregators(Limit=4, NextToken='testing8')
assert len(result['ConfigurationAggregators']) == 2
assert not result.get('NextToken')
assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == \
['testing{}'.format(x) for x in range(8, 10)]
# Test Pagination with Filtering:
result = client.describe_configuration_aggregators(ConfigurationAggregatorNames=['testing2', 'testing4'], Limit=1)
assert len(result['ConfigurationAggregators']) == 1
assert result['NextToken'] == 'testing4'
assert result['ConfigurationAggregators'][0]['ConfigurationAggregatorName'] == 'testing2'
result = client.describe_configuration_aggregators(ConfigurationAggregatorNames=['testing2', 'testing4'], Limit=1, NextToken='testing4')
assert not result.get('NextToken')
assert result['ConfigurationAggregators'][0]['ConfigurationAggregatorName'] == 'testing4'
# Test with an invalid filter:
with assert_raises(ClientError) as ce:
client.describe_configuration_aggregators(NextToken='WRONG')
assert 'The nextToken provided is invalid' == ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidNextTokenException'
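The token scheme these assertions exercise — the `NextToken` is simply the name of the first aggregator on the next page, and an unknown token is rejected — can be sketched as a small standalone helper. This is a hypothetical sketch of the behaviour the test asserts, not moto's actual backend code:

```python
def paginate(names, limit=None, next_token=None):
    # Token-based pagination: the token is the name of the first item on
    # the next page; an unrecognized token is rejected outright.
    # (Hypothetical helper, not moto's actual backend implementation.)
    sorted_names = sorted(names)
    if next_token is not None and next_token not in sorted_names:
        raise ValueError('The nextToken provided is invalid')
    start = sorted_names.index(next_token) if next_token else 0
    limit = limit or len(sorted_names)
    page = sorted_names[start:start + limit]
    if start + limit < len(sorted_names):
        new_token = sorted_names[start + limit]
    else:
        new_token = None
    return page, new_token
```

With ten names `testing0`..`testing9` and a limit of 4, this yields pages of four with tokens `testing4` and `testing8`, matching the pagination assertions above.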
@mock_config
def test_put_aggregation_authorization():
client = boto3.client('config', region_name='us-west-2')
# Too many tags (>50):
with assert_raises(ClientError) as ce:
client.put_aggregation_authorization(
AuthorizedAccountId='012345678910',
AuthorizedAwsRegion='us-west-2',
Tags=[{'Key': '{}'.format(x), 'Value': '{}'.format(x)} for x in range(0, 51)]
)
assert 'Member must have length less than or equal to 50' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# Tag key is too big (>128 chars):
with assert_raises(ClientError) as ce:
client.put_aggregation_authorization(
AuthorizedAccountId='012345678910',
AuthorizedAwsRegion='us-west-2',
Tags=[{'Key': 'a' * 129, 'Value': 'a'}]
)
assert 'Member must have length less than or equal to 128' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# Tag value is too big (>256 chars):
with assert_raises(ClientError) as ce:
client.put_aggregation_authorization(
AuthorizedAccountId='012345678910',
AuthorizedAwsRegion='us-west-2',
Tags=[{'Key': 'tag', 'Value': 'a' * 257}]
)
assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# Duplicate Tags:
with assert_raises(ClientError) as ce:
client.put_aggregation_authorization(
AuthorizedAccountId='012345678910',
AuthorizedAwsRegion='us-west-2',
Tags=[{'Key': 'a', 'Value': 'a'}, {'Key': 'a', 'Value': 'a'}]
)
assert 'Duplicate tag keys found.' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidInput'
# Invalid characters in the tag key:
with assert_raises(ClientError) as ce:
client.put_aggregation_authorization(
AuthorizedAccountId='012345678910',
AuthorizedAwsRegion='us-west-2',
Tags=[{'Key': '!', 'Value': 'a'}]
)
assert 'Member must satisfy regular expression pattern:' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'ValidationException'
# Put a normal one there:
result = client.put_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-east-1',
Tags=[{'Key': 'tag', 'Value': 'a'}])
assert result['AggregationAuthorization']['AggregationAuthorizationArn'] == 'arn:aws:config:us-west-2:123456789012:' \
'aggregation-authorization/012345678910/us-east-1'
assert result['AggregationAuthorization']['AuthorizedAccountId'] == '012345678910'
assert result['AggregationAuthorization']['AuthorizedAwsRegion'] == 'us-east-1'
assert isinstance(result['AggregationAuthorization']['CreationTime'], datetime)
creation_date = result['AggregationAuthorization']['CreationTime']
# And again:
result = client.put_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-east-1')
assert result['AggregationAuthorization']['AggregationAuthorizationArn'] == 'arn:aws:config:us-west-2:123456789012:' \
'aggregation-authorization/012345678910/us-east-1'
assert result['AggregationAuthorization']['AuthorizedAccountId'] == '012345678910'
assert result['AggregationAuthorization']['AuthorizedAwsRegion'] == 'us-east-1'
assert result['AggregationAuthorization']['CreationTime'] == creation_date
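The repeat-put behaviour asserted at the end — same ARN and an unchanged `CreationTime` — amounts to an idempotent upsert keyed by (account, region). A minimal sketch with hypothetical names, not moto's actual backend code:

```python
from datetime import datetime

_authorizations = {}


def put_aggregation_authorization(account_id, region):
    # Idempotent put keyed by (account, region): a repeated call returns
    # the existing record, so CreationTime is preserved.
    # (Sketch only, not moto's actual backend implementation.)
    key = (account_id, region)
    if key not in _authorizations:
        _authorizations[key] = {
            'AuthorizedAccountId': account_id,
            'AuthorizedAwsRegion': region,
            'CreationTime': datetime.now(),
        }
    return _authorizations[key]
```

Calling it twice with the same account and region returns the same record both times, which is exactly what the `creation_date` comparison above verifies.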
@mock_config
def test_describe_aggregation_authorizations():
client = boto3.client('config', region_name='us-west-2')
# With no aggregation authorizations:
assert not client.describe_aggregation_authorizations()['AggregationAuthorizations']
# Make 10 account authorizations:
for i in range(0, 10):
client.put_aggregation_authorization(AuthorizedAccountId='{}'.format(str(i) * 12), AuthorizedAwsRegion='us-west-2')
result = client.describe_aggregation_authorizations()
assert len(result['AggregationAuthorizations']) == 10
assert not result.get('NextToken')
for i in range(0, 10):
assert result['AggregationAuthorizations'][i]['AuthorizedAccountId'] == str(i) * 12
# Test Pagination:
result = client.describe_aggregation_authorizations(Limit=4)
assert len(result['AggregationAuthorizations']) == 4
assert result['NextToken'] == ('4' * 12) + '/us-west-2'
assert [auth['AuthorizedAccountId'] for auth in result['AggregationAuthorizations']] == ['{}'.format(str(x) * 12) for x in range(0, 4)]
result = client.describe_aggregation_authorizations(Limit=4, NextToken=('4' * 12) + '/us-west-2')
assert len(result['AggregationAuthorizations']) == 4
assert result['NextToken'] == ('8' * 12) + '/us-west-2'
assert [auth['AuthorizedAccountId'] for auth in result['AggregationAuthorizations']] == ['{}'.format(str(x) * 12) for x in range(4, 8)]
result = client.describe_aggregation_authorizations(Limit=4, NextToken=('8' * 12) + '/us-west-2')
assert len(result['AggregationAuthorizations']) == 2
assert not result.get('NextToken')
assert [auth['AuthorizedAccountId'] for auth in result['AggregationAuthorizations']] == ['{}'.format(str(x) * 12) for x in range(8, 10)]
# Test with an invalid filter:
with assert_raises(ClientError) as ce:
client.describe_aggregation_authorizations(NextToken='WRONG')
assert 'The nextToken provided is invalid' == ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'InvalidNextTokenException'
@mock_config
def test_delete_aggregation_authorization():
client = boto3.client('config', region_name='us-west-2')
client.put_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-west-2')
# Delete it:
client.delete_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-west-2')
# Verify that none are there:
assert not client.describe_aggregation_authorizations()['AggregationAuthorizations']
# Try it again -- nothing should happen:
client.delete_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-west-2')
@mock_config
def test_delete_configuration_aggregator():
client = boto3.client('config', region_name='us-west-2')
client.put_configuration_aggregator(
ConfigurationAggregatorName='testing',
AccountAggregationSources=[
{
'AccountIds': [
'012345678910',
],
'AllAwsRegions': True
}
]
)
client.delete_configuration_aggregator(ConfigurationAggregatorName='testing')
# And again to confirm that it's deleted:
with assert_raises(ClientError) as ce:
client.delete_configuration_aggregator(ConfigurationAggregatorName='testing')
assert 'The configuration aggregator does not exist.' in ce.exception.response['Error']['Message']
assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationAggregatorException'
@mock_config
def test_describe_configurations():
client = boto3.client('config', region_name='us-west-2')
@@ -0,0 +1,12 @@
import sure # noqa
import boto3
from moto import mock_sqs, settings
def test_context_manager_returns_mock():
with mock_sqs() as sqs_mock:
conn = boto3.client("sqs", region_name='us-west-1')
conn.create_queue(QueueName="queue1")
if not settings.TEST_SERVER_MODE:
list(sqs_mock.backends['us-west-1'].queues.keys()).should.equal(['queue1'])
@@ -38,12 +38,6 @@ def test_domain_dispatched():
keys[0].should.equal('EmailResponse.dispatch')
def test_domain_without_matches():
dispatcher = DomainDispatcherApplication(create_backend_app)
dispatcher.get_application.when.called_with(
{"HTTP_HOST": "not-matching-anything.com"}).should.throw(RuntimeError)
def test_domain_dispatched_with_service():
# If we pass a particular service, always return that.
dispatcher = DomainDispatcherApplication(create_backend_app, service="s3")
@@ -1342,6 +1342,46 @@ def test_query_missing_expr_names():
resp['Items'][0]['client']['S'].should.equal('test2')
# https://github.com/spulec/moto/issues/2328
@mock_dynamodb2
def test_update_item_with_list():
dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
# Create the DynamoDB table.
dynamodb.create_table(
TableName='Table',
KeySchema=[
{
'AttributeName': 'key',
'KeyType': 'HASH'
}
],
AttributeDefinitions=[
{
'AttributeName': 'key',
'AttributeType': 'S'
},
],
ProvisionedThroughput={
'ReadCapacityUnits': 1,
'WriteCapacityUnits': 1
}
)
table = dynamodb.Table('Table')
table.update_item(
Key={'key': 'the-key'},
AttributeUpdates={
'list': {'Value': [1, 2], 'Action': 'PUT'}
}
)
resp = table.get_item(Key={'key': 'the-key'})
resp['Item'].should.equal({
'key': 'the-key',
'list': [1, 2]
})
# https://github.com/spulec/moto/issues/1342
@mock_dynamodb2
def test_update_item_on_map():
@@ -1964,6 +2004,36 @@ def test_condition_expression__attr_doesnt_exist():
update_if_attr_doesnt_exist()
@mock_dynamodb2
def test_condition_expression__or_order():
client = boto3.client('dynamodb', region_name='us-east-1')
client.create_table(
TableName='test',
KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}],
AttributeDefinitions=[
{'AttributeName': 'forum_name', 'AttributeType': 'S'},
],
ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1},
)
# ensure that the RHS of the OR expression is not evaluated if the LHS
# returns true (evaluating it would result in an error)
client.update_item(
TableName='test',
Key={
'forum_name': {'S': 'the-key'},
},
UpdateExpression='set #ttl=:ttl',
ConditionExpression='attribute_not_exists(#ttl) OR #ttl <= :old_ttl',
ExpressionAttributeNames={'#ttl': 'ttl'},
ExpressionAttributeValues={
':ttl': {'N': '6'},
':old_ttl': {'N': '5'},
}
)
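The short-circuit behaviour this test depends on — the right-hand side of `OR` is skipped once the left-hand side is true — mirrors Python's own `or`. A minimal illustration with a hypothetical helper, not moto's condition-expression parser:

```python
def evaluate_or(lhs, rhs):
    # Short-circuiting OR: rhs() only runs when lhs() is falsy.
    # (Sketch of the behaviour, not moto's expression evaluator.)
    return bool(lhs() or rhs())


def missing_attribute_comparison():
    # Stand-in for `#ttl <= :old_ttl` on an item that has no `ttl`
    # attribute, which would error if it were actually evaluated.
    raise KeyError('ttl')


# attribute_not_exists(#ttl) is true for a new item, so the comparison on
# the right-hand side is never evaluated and no error is raised:
assert evaluate_or(lambda: True, missing_attribute_comparison) is True
```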
@mock_dynamodb2
def test_query_gsi_with_range_key():
dynamodb = boto3.client('dynamodb', region_name='us-east-1')
@@ -10,7 +10,7 @@ from nose.tools import assert_raises
import sure # noqa
from moto import mock_ec2_deprecated, mock_ec2
from moto.ec2.models import AMIS, OWNER_ID
from tests.helpers import requires_boto_gte
@@ -152,6 +152,29 @@ def test_ami_copy():
cm.exception.request_id.should_not.be.none
@mock_ec2
def test_copy_image_changes_owner_id():
conn = boto3.client('ec2', region_name='us-east-1')
# this source AMI ID is from moto/ec2/resources/amis.json
source_ami_id = "ami-03cf127a"
# confirm the source ami owner id is different from the default owner id.
# if they're ever the same it means this test is invalid.
check_resp = conn.describe_images(ImageIds=[source_ami_id])
check_resp["Images"][0]["OwnerId"].should_not.equal(OWNER_ID)
copy_resp = conn.copy_image(
SourceImageId=source_ami_id,
Name="new-image",
Description="a copy of an image",
SourceRegion="us-east-1")
describe_resp = conn.describe_images(Owners=["self"])
describe_resp["Images"][0]["OwnerId"].should.equal(OWNER_ID)
describe_resp["Images"][0]["ImageId"].should.equal(copy_resp["ImageId"])
@mock_ec2_deprecated
def test_ami_tagging():
conn = boto.connect_vpc('the_key', 'the_secret')
@@ -12,6 +12,7 @@ from freezegun import freeze_time
import sure # noqa
from moto import mock_ec2_deprecated, mock_ec2
from moto.ec2.models import OWNER_ID
@mock_ec2_deprecated
@@ -395,7 +396,7 @@ def test_snapshot_filters():
).should.equal({snapshot3.id})
snapshots_by_owner_id = conn.get_all_snapshots(
filters={'owner-id': OWNER_ID})
set([snap.id for snap in snapshots_by_owner_id]
).should.equal({snapshot1.id, snapshot2.id, snapshot3.id})
@@ -161,7 +161,7 @@ def test_elastic_network_interfaces_filtering():
subnet.id, groups=[security_group1.id, security_group2.id])
eni2 = conn.create_network_interface(
subnet.id, groups=[security_group1.id])
eni3 = conn.create_network_interface(subnet.id, description='test description')
all_enis = conn.get_all_network_interfaces()
all_enis.should.have.length_of(3)
@@ -189,6 +189,12 @@ def test_elastic_network_interfaces_filtering():
enis_by_group.should.have.length_of(1)
set([eni.id for eni in enis_by_group]).should.equal(set([eni1.id]))
# Filter by Description
enis_by_description = conn.get_all_network_interfaces(
filters={'description': eni3.description })
enis_by_description.should.have.length_of(1)
enis_by_description[0].description.should.equal(eni3.description)
# Unsupported filter
conn.get_all_network_interfaces.when.called_with(
filters={'not-implemented-filter': 'foobar'}).should.throw(NotImplementedError)
@@ -343,6 +349,106 @@ def test_elastic_network_interfaces_get_by_subnet_id():
enis.should.have.length_of(0)
@mock_ec2
def test_elastic_network_interfaces_get_by_description():
ec2 = boto3.resource('ec2', region_name='us-west-2')
ec2_client = boto3.client('ec2', region_name='us-west-2')
vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16')
subnet = ec2.create_subnet(
VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a')
eni1 = ec2.create_network_interface(
SubnetId=subnet.id, PrivateIpAddress='10.0.10.5', Description='test interface')
# The status of the new interface should be 'available'
waiter = ec2_client.get_waiter('network_interface_available')
waiter.wait(NetworkInterfaceIds=[eni1.id])
filters = [{'Name': 'description', 'Values': [eni1.description]}]
enis = list(ec2.network_interfaces.filter(Filters=filters))
enis.should.have.length_of(1)
filters = [{'Name': 'description', 'Values': ['bad description']}]
enis = list(ec2.network_interfaces.filter(Filters=filters))
enis.should.have.length_of(0)
@mock_ec2
def test_elastic_network_interfaces_describe_network_interfaces_with_filter():
ec2 = boto3.resource('ec2', region_name='us-west-2')
ec2_client = boto3.client('ec2', region_name='us-west-2')
vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16')
subnet = ec2.create_subnet(
VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a')
eni1 = ec2.create_network_interface(
SubnetId=subnet.id, PrivateIpAddress='10.0.10.5', Description='test interface')
# The status of the new interface should be 'available'
waiter = ec2_client.get_waiter('network_interface_available')
waiter.wait(NetworkInterfaceIds=[eni1.id])
# Filter by network-interface-id
response = ec2_client.describe_network_interfaces(
Filters=[{'Name': 'network-interface-id', 'Values': [eni1.id]}])
response['NetworkInterfaces'].should.have.length_of(1)
response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id)
response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address)
response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description)
response = ec2_client.describe_network_interfaces(
Filters=[{'Name': 'network-interface-id', 'Values': ['bad-id']}])
response['NetworkInterfaces'].should.have.length_of(0)
# Filter by private-ip-address
response = ec2_client.describe_network_interfaces(
Filters=[{'Name': 'private-ip-address', 'Values': [eni1.private_ip_address]}])
response['NetworkInterfaces'].should.have.length_of(1)
response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id)
response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address)
response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description)
response = ec2_client.describe_network_interfaces(
Filters=[{'Name': 'private-ip-address', 'Values': ['11.11.11.11']}])
response['NetworkInterfaces'].should.have.length_of(0)
# Filter by subnet-id
response = ec2_client.describe_network_interfaces(
Filters=[{'Name': 'subnet-id', 'Values': [eni1.subnet.id]}])
response['NetworkInterfaces'].should.have.length_of(1)
response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id)
response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address)
response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description)
response = ec2_client.describe_network_interfaces(
Filters=[{'Name': 'subnet-id', 'Values': ['sn-bad-id']}])
response['NetworkInterfaces'].should.have.length_of(0)
# Filter by description
response = ec2_client.describe_network_interfaces(
Filters=[{'Name': 'description', 'Values': [eni1.description]}])
response['NetworkInterfaces'].should.have.length_of(1)
response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id)
response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address)
response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description)
response = ec2_client.describe_network_interfaces(
Filters=[{'Name': 'description', 'Values': ['bad description']}])
response['NetworkInterfaces'].should.have.length_of(0)
# Filter by multiple filters
response = ec2_client.describe_network_interfaces(
Filters=[{'Name': 'private-ip-address', 'Values': [eni1.private_ip_address]},
{'Name': 'network-interface-id', 'Values': [eni1.id]},
{'Name': 'subnet-id', 'Values': [eni1.subnet.id]}])
response['NetworkInterfaces'].should.have.length_of(1)
response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id)
response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address)
response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description)
@mock_ec2_deprecated
@mock_cloudformation_deprecated
def test_elastic_network_interfaces_cloudformation():
@@ -1,4 +1,5 @@
from __future__ import unicode_literals
from datetime import datetime
from copy import deepcopy
@@ -94,6 +95,10 @@ def test_register_task_definition():
}],
'logConfiguration': {'logDriver': 'json-file'}
}
],
tags=[
{'key': 'createdBy', 'value': 'moto-unittest'},
{'key': 'foo', 'value': 'bar'},
]
)
type(response['taskDefinition']).should.be(dict)
@@ -473,6 +478,8 @@ def test_describe_services():
response['services'][0]['deployments'][0]['pendingCount'].should.equal(2)
response['services'][0]['deployments'][0]['runningCount'].should.equal(0)
response['services'][0]['deployments'][0]['status'].should.equal('PRIMARY')
(datetime.now() - response['services'][0]['deployments'][0]["createdAt"].replace(tzinfo=None)).seconds.should.be.within(0, 10)
(datetime.now() - response['services'][0]['deployments'][0]["updatedAt"].replace(tzinfo=None)).seconds.should.be.within(0, 10)
@mock_ecs
@@ -2304,3 +2311,52 @@ def test_create_service_load_balancing():
response['service']['status'].should.equal('ACTIVE')
response['service']['taskDefinition'].should.equal(
'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1')
@mock_ecs
def test_list_tags_for_resource():
client = boto3.client('ecs', region_name='us-east-1')
response = client.register_task_definition(
family='test_ecs_task',
containerDefinitions=[
{
'name': 'hello_world',
'image': 'docker/hello-world:latest',
'cpu': 1024,
'memory': 400,
'essential': True,
'environment': [{
'name': 'AWS_ACCESS_KEY_ID',
'value': 'SOME_ACCESS_KEY'
}],
'logConfiguration': {'logDriver': 'json-file'}
}
],
tags=[
{'key': 'createdBy', 'value': 'moto-unittest'},
{'key': 'foo', 'value': 'bar'},
]
)
type(response['taskDefinition']).should.be(dict)
response['taskDefinition']['revision'].should.equal(1)
response['taskDefinition']['taskDefinitionArn'].should.equal(
'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1')
task_definition_arn = response['taskDefinition']['taskDefinitionArn']
response = client.list_tags_for_resource(resourceArn=task_definition_arn)
type(response['tags']).should.be(list)
response['tags'].should.equal([
{'key': 'createdBy', 'value': 'moto-unittest'},
{'key': 'foo', 'value': 'bar'},
])
@mock_ecs
def test_list_tags_for_resource_unknown():
client = boto3.client('ecs', region_name='us-east-1')
task_definition_arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/unknown:1'
try:
client.list_tags_for_resource(resourceArn=task_definition_arn)
raise AssertionError('Expected ClientError for unknown resource')
except ClientError as err:
err.response['Error']['Code'].should.equal('ClientException')
@@ -667,6 +667,91 @@ def test_register_targets():
response.get('TargetHealthDescriptions').should.have.length_of(1)
@mock_ec2
@mock_elbv2
def test_stopped_instance_target():
target_group_port = 8080
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.0/26',
AvailabilityZone='us-east-1b')
conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
response = conn.create_target_group(
Name='a-target',
Protocol='HTTP',
Port=target_group_port,
VpcId=vpc.id,
HealthCheckProtocol='HTTP',
HealthCheckPath='/',
HealthCheckIntervalSeconds=5,
HealthCheckTimeoutSeconds=5,
HealthyThresholdCount=5,
UnhealthyThresholdCount=2,
Matcher={'HttpCode': '200'})
target_group = response.get('TargetGroups')[0]
# No targets registered yet
response = conn.describe_target_health(
TargetGroupArn=target_group.get('TargetGroupArn'))
response.get('TargetHealthDescriptions').should.have.length_of(0)
response = ec2.create_instances(
ImageId='ami-1234abcd', MinCount=1, MaxCount=1)
instance = response[0]
target_dict = {
'Id': instance.id,
'Port': 500
}
response = conn.register_targets(
TargetGroupArn=target_group.get('TargetGroupArn'),
Targets=[target_dict])
response = conn.describe_target_health(
TargetGroupArn=target_group.get('TargetGroupArn'))
response.get('TargetHealthDescriptions').should.have.length_of(1)
target_health_description = response.get('TargetHealthDescriptions')[0]
target_health_description['Target'].should.equal(target_dict)
target_health_description['HealthCheckPort'].should.equal(str(target_group_port))
target_health_description['TargetHealth'].should.equal({
'State': 'healthy'
})
instance.stop()
response = conn.describe_target_health(
TargetGroupArn=target_group.get('TargetGroupArn'))
response.get('TargetHealthDescriptions').should.have.length_of(1)
target_health_description = response.get('TargetHealthDescriptions')[0]
target_health_description['Target'].should.equal(target_dict)
target_health_description['HealthCheckPort'].should.equal(str(target_group_port))
target_health_description['TargetHealth'].should.equal({
'State': 'unused',
'Reason': 'Target.InvalidState',
'Description': 'Target is in the stopped state'
})
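The two health states asserted above suggest a simple mapping from instance state to target health. A hypothetical sketch of that mapping, grounded only in the dicts this test expects, not elbv2's real health-check logic:

```python
def target_health_for_instance(instance_state):
    # Map an EC2 instance state to the target-health dict the test above
    # asserts. (Hypothetical sketch, not moto's elbv2 implementation.)
    if instance_state == 'stopped':
        return {
            'State': 'unused',
            'Reason': 'Target.InvalidState',
            'Description': 'Target is in the stopped state',
        }
    return {'State': 'healthy'}
```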
@mock_ec2
@mock_elbv2
def test_target_group_attributes():
@@ -1726,3 +1811,132 @@ def test_redirect_action_listener_rule_cloudformation():
'Port': '443', 'Protocol': 'HTTPS', 'StatusCode': 'HTTP_301',
}
},])
@mock_elbv2
@mock_ec2
def test_cognito_action_listener_rule():
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.128/26',
AvailabilityZone='us-east-1b')
response = conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn')
action = {
'Type': 'authenticate-cognito',
'AuthenticateCognitoConfig': {
'UserPoolArn': 'arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234',
'UserPoolClientId': 'abcd1234abcd',
'UserPoolDomain': 'testpool',
}
}
response = conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[action])
listener = response.get('Listeners')[0]
listener.get('DefaultActions')[0].should.equal(action)
listener_arn = listener.get('ListenerArn')
describe_rules_response = conn.describe_rules(ListenerArn=listener_arn)
describe_rules_response['Rules'][0]['Actions'][0].should.equal(action)
describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ])
describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'][0]
describe_listener_actions.should.equal(action)
@mock_elbv2
@mock_cloudformation
def test_cognito_action_listener_rule_cloudformation():
cnf_conn = boto3.client('cloudformation', region_name='us-east-1')
elbv2_client = boto3.client('elbv2', region_name='us-east-1')
template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Description": "ECS Cluster Test CloudFormation",
"Resources": {
"testVPC": {
"Type": "AWS::EC2::VPC",
"Properties": {
"CidrBlock": "10.0.0.0/16",
},
},
"subnet1": {
"Type": "AWS::EC2::Subnet",
"Properties": {
"CidrBlock": "10.0.0.0/24",
"VpcId": {"Ref": "testVPC"},
"AvailabilityZone": "us-east-1b",
},
},
"subnet2": {
"Type": "AWS::EC2::Subnet",
"Properties": {
"CidrBlock": "10.0.1.0/24",
"VpcId": {"Ref": "testVPC"},
"AvailabilityZone": "us-east-1b",
},
},
"testLb": {
"Type": "AWS::ElasticLoadBalancingV2::LoadBalancer",
"Properties": {
"Name": "my-lb",
"Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}],
"Type": "application",
"SecurityGroups": [],
}
},
"testListener": {
"Type": "AWS::ElasticLoadBalancingV2::Listener",
"Properties": {
"LoadBalancerArn": {"Ref": "testLb"},
"Port": 80,
"Protocol": "HTTP",
"DefaultActions": [{
"Type": "authenticate-cognito",
"AuthenticateCognitoConfig": {
'UserPoolArn': 'arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234',
'UserPoolClientId': 'abcd1234abcd',
'UserPoolDomain': 'testpool',
}
}]
}
}
}
}
template_json = json.dumps(template)
cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json)
describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',])
load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn']
describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn)
describe_listeners_response['Listeners'].should.have.length_of(1)
describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{
'Type': 'authenticate-cognito',
"AuthenticateCognitoConfig": {
'UserPoolArn': 'arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234',
'UserPoolClientId': 'abcd1234abcd',
'UserPoolDomain': 'testpool',
}
},])
@@ -419,6 +419,63 @@ def test_get_partition():
partition['Values'].should.equal(values[1])
@mock_glue
def test_batch_get_partition():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
values = [['2018-10-01'], ['2018-09-01']]
helpers.create_partition(client, database_name, table_name, values=values[0])
helpers.create_partition(client, database_name, table_name, values=values[1])
partitions_to_get = [
{'Values': values[0]},
{'Values': values[1]},
]
response = client.batch_get_partition(DatabaseName=database_name, TableName=table_name, PartitionsToGet=partitions_to_get)
partitions = response['Partitions']
partitions.should.have.length_of(2)
partition = partitions[1]
partition['TableName'].should.equal(table_name)
partition['Values'].should.equal(values[1])
@mock_glue
def test_batch_get_partition_missing_partition():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
values = [['2018-10-01'], ['2018-09-01'], ['2018-08-01']]
helpers.create_partition(client, database_name, table_name, values=values[0])
helpers.create_partition(client, database_name, table_name, values=values[2])
partitions_to_get = [
{'Values': values[0]},
{'Values': values[1]},
{'Values': values[2]},
]
response = client.batch_get_partition(DatabaseName=database_name, TableName=table_name, PartitionsToGet=partitions_to_get)
partitions = response['Partitions']
partitions.should.have.length_of(2)
partitions[0]['Values'].should.equal(values[0])
partitions[1]['Values'].should.equal(values[2])
@mock_glue
def test_update_partition_not_found_moving():
client = boto3.client('glue', region_name='us-east-1')
@@ -11,21 +11,29 @@ import sure # noqa
from moto import mock_kms, mock_kms_deprecated
from nose.tools import assert_raises
from freezegun import freeze_time
from datetime import date
from datetime import datetime
from dateutil.tz import tzutc
@mock_kms
def test_create_key():
conn = boto3.client('kms', region_name='us-east-1')
with freeze_time("2015-01-01 00:00:00"):
key = conn.create_key(Policy="my policy",
Description="my key",
KeyUsage='ENCRYPT_DECRYPT',
Tags=[
{
'TagKey': 'project',
'TagValue': 'moto',
},
])
key['KeyMetadata']['Description'].should.equal("my key") key['KeyMetadata']['Description'].should.equal("my key")
key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT")
key['KeyMetadata']['Enabled'].should.equal(True) key['KeyMetadata']['Enabled'].should.equal(True)
key['KeyMetadata']['CreationDate'].should.equal("1420070400") key['KeyMetadata']['CreationDate'].should.be.a(date)
@mock_kms_deprecated @mock_kms_deprecated
@ -183,6 +191,7 @@ def test_decrypt():
conn = boto.kms.connect_to_region('us-west-2') conn = boto.kms.connect_to_region('us-west-2')
response = conn.decrypt('ZW5jcnlwdG1l'.encode('utf-8')) response = conn.decrypt('ZW5jcnlwdG1l'.encode('utf-8'))
response['Plaintext'].should.equal(b'encryptme') response['Plaintext'].should.equal(b'encryptme')
response['KeyId'].should.equal('key_id')
@mock_kms_deprecated @mock_kms_deprecated


@@ -162,3 +162,63 @@ def test_delete_retention_policy():
     response = conn.delete_log_group(logGroupName=log_group_name)
@mock_logs
def test_get_log_events():
conn = boto3.client('logs', 'us-west-2')
log_group_name = 'test'
log_stream_name = 'stream'
conn.create_log_group(logGroupName=log_group_name)
conn.create_log_stream(
logGroupName=log_group_name,
logStreamName=log_stream_name
)
events = [{'timestamp': x, 'message': str(x)} for x in range(20)]
conn.put_log_events(
logGroupName=log_group_name,
logStreamName=log_stream_name,
logEvents=events
)
resp = conn.get_log_events(
logGroupName=log_group_name,
logStreamName=log_stream_name,
limit=10)
resp['events'].should.have.length_of(10)
resp.should.have.key('nextForwardToken')
resp.should.have.key('nextBackwardToken')
for i in range(10):
resp['events'][i]['timestamp'].should.equal(i)
resp['events'][i]['message'].should.equal(str(i))
next_token = resp['nextForwardToken']
resp = conn.get_log_events(
logGroupName=log_group_name,
logStreamName=log_stream_name,
nextToken=next_token,
limit=10)
resp['events'].should.have.length_of(10)
resp.should.have.key('nextForwardToken')
resp.should.have.key('nextBackwardToken')
resp['nextForwardToken'].should.equal(next_token)
for i in range(10):
resp['events'][i]['timestamp'].should.equal(i+10)
resp['events'][i]['message'].should.equal(str(i+10))
resp = conn.get_log_events(
logGroupName=log_group_name,
logStreamName=log_stream_name,
nextToken=resp['nextBackwardToken'],
limit=10)
resp['events'].should.have.length_of(10)
resp.should.have.key('nextForwardToken')
resp.should.have.key('nextBackwardToken')
for i in range(10):
resp['events'][i]['timestamp'].should.equal(i)
resp['events'][i]['message'].should.equal(str(i))
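The paging behaviour driven above (a `limit`, plus forward and backward tokens that can be fed back in) can be restated without moto. This is an illustrative sketch of offset-based token paging; the `f/<offset>` token format is an assumption for the example, not CloudWatch Logs' real token encoding:

```python
# Sketch: page through a sorted event list with resumable offset tokens.
def get_log_events(events, limit, token=None):
    start = int(token.split('/')[1]) if token else 0
    page = events[start:start + limit]
    return {
        'events': page,
        'nextForwardToken': 'f/%d' % (start + len(page)),
        'nextBackwardToken': 'b/%d' % start,
    }

events = [{'timestamp': i, 'message': str(i)} for i in range(20)]
first = get_log_events(events, limit=10)
second = get_log_events(events, limit=10, token=first['nextForwardToken'])
```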


@@ -1,7 +1,6 @@
 from __future__ import unicode_literals
 import six
-import sure # noqa
 import datetime
 from moto.organizations import utils


@@ -3,7 +3,6 @@ from __future__ import unicode_literals
 import boto3
 import json
 import six
-import sure # noqa
 from botocore.exceptions import ClientError
 from nose.tools import assert_raises

@@ -27,6 +26,25 @@ def test_create_organization():
     validate_organization(response)
     response['Organization']['FeatureSet'].should.equal('ALL')
+    response = client.list_accounts()
+    len(response['Accounts']).should.equal(1)
+    response['Accounts'][0]['Name'].should.equal('master')
+    response['Accounts'][0]['Id'].should.equal(utils.MASTER_ACCOUNT_ID)
+    response['Accounts'][0]['Email'].should.equal(utils.MASTER_ACCOUNT_EMAIL)
+    response = client.list_policies(Filter='SERVICE_CONTROL_POLICY')
+    len(response['Policies']).should.equal(1)
+    response['Policies'][0]['Name'].should.equal('FullAWSAccess')
+    response['Policies'][0]['Id'].should.equal(utils.DEFAULT_POLICY_ID)
+    response['Policies'][0]['AwsManaged'].should.equal(True)
+    response = client.list_targets_for_policy(PolicyId=utils.DEFAULT_POLICY_ID)
+    len(response['Targets']).should.equal(2)
+    root_ou = [t for t in response['Targets'] if t['Type'] == 'ROOT'][0]
+    root_ou['Name'].should.equal('Root')
+    master_account = [t for t in response['Targets'] if t['Type'] == 'ACCOUNT'][0]
+    master_account['Name'].should.equal('master')
 @mock_organizations
 def test_describe_organization():

@@ -177,11 +195,11 @@ def test_list_accounts():
     response = client.list_accounts()
     response.should.have.key('Accounts')
     accounts = response['Accounts']
-    len(accounts).should.equal(5)
+    len(accounts).should.equal(6)
     for account in accounts:
         validate_account(org, account)
-    accounts[3]['Name'].should.equal(mockname + '3')
-    accounts[2]['Email'].should.equal(mockname + '2' + '@' + mockdomain)
+    accounts[4]['Name'].should.equal(mockname + '3')
+    accounts[3]['Email'].should.equal(mockname + '2' + '@' + mockdomain)

 @mock_organizations

@@ -291,8 +309,10 @@ def test_list_children():
     response02 = client.list_children(ParentId=root_id, ChildType='ORGANIZATIONAL_UNIT')
     response03 = client.list_children(ParentId=ou01_id, ChildType='ACCOUNT')
     response04 = client.list_children(ParentId=ou01_id, ChildType='ORGANIZATIONAL_UNIT')
-    response01['Children'][0]['Id'].should.equal(account01_id)
+    response01['Children'][0]['Id'].should.equal(utils.MASTER_ACCOUNT_ID)
     response01['Children'][0]['Type'].should.equal('ACCOUNT')
+    response01['Children'][1]['Id'].should.equal(account01_id)
+    response01['Children'][1]['Type'].should.equal('ACCOUNT')
     response02['Children'][0]['Id'].should.equal(ou01_id)
     response02['Children'][0]['Type'].should.equal('ORGANIZATIONAL_UNIT')
     response03['Children'][0]['Id'].should.equal(account02_id)

@@ -591,4 +611,3 @@ def test_list_targets_for_policy_exception():
     ex.operation_name.should.equal('ListTargetsForPolicy')
     ex.response['Error']['Code'].should.equal('400')
     ex.response['Error']['Message'].should.contain('InvalidInputException')


@@ -18,7 +18,8 @@ def test_create_database():
         MasterUsername='root',
         MasterUserPassword='hunter2',
         Port=1234,
-        DBSecurityGroups=["my_sg"])
+        DBSecurityGroups=["my_sg"],
+        VpcSecurityGroupIds=['sg-123456'])
     db_instance = database['DBInstance']
     db_instance['AllocatedStorage'].should.equal(10)
     db_instance['DBInstanceClass'].should.equal("db.m1.small")

@@ -35,6 +36,7 @@ def test_create_database():
     db_instance['DbiResourceId'].should.contain("db-")
     db_instance['CopyTagsToSnapshot'].should.equal(False)
     db_instance['InstanceCreateTime'].should.be.a("datetime.datetime")
+    db_instance['VpcSecurityGroups'][0]['VpcSecurityGroupId'].should.equal('sg-123456')

 @mock_rds2

@@ -260,9 +262,11 @@ def test_modify_db_instance():
     instances['DBInstances'][0]['AllocatedStorage'].should.equal(10)
     conn.modify_db_instance(DBInstanceIdentifier='db-master-1',
                             AllocatedStorage=20,
-                            ApplyImmediately=True)
+                            ApplyImmediately=True,
+                            VpcSecurityGroupIds=['sg-123456'])
     instances = conn.describe_db_instances(DBInstanceIdentifier='db-master-1')
     instances['DBInstances'][0]['AllocatedStorage'].should.equal(20)
+    instances['DBInstances'][0]['VpcSecurityGroups'][0]['VpcSecurityGroupId'].should.equal('sg-123456')

 @mock_rds2


@@ -36,6 +36,7 @@ def test_create_cluster_boto3():
     response['Cluster']['NodeType'].should.equal('ds2.xlarge')
     create_time = response['Cluster']['ClusterCreateTime']
     create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo))
+    create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1))

 @mock_redshift


@@ -2,7 +2,11 @@ from __future__ import unicode_literals
 import boto3
 import sure # noqa
-from moto import mock_resourcegroupstaggingapi, mock_s3, mock_ec2, mock_elbv2
+from moto import mock_ec2
+from moto import mock_elbv2
+from moto import mock_kms
+from moto import mock_resourcegroupstaggingapi
+from moto import mock_s3

 @mock_s3

@@ -225,10 +229,12 @@ def test_get_tag_values_ec2():
 @mock_ec2
 @mock_elbv2
+@mock_kms
 @mock_resourcegroupstaggingapi
-def test_get_resources_elbv2():
-    conn = boto3.client('elbv2', region_name='us-east-1')
+def test_get_many_resources():
+    elbv2 = boto3.client('elbv2', region_name='us-east-1')
     ec2 = boto3.resource('ec2', region_name='us-east-1')
+    kms = boto3.client('kms', region_name='us-east-1')
     security_group = ec2.create_security_group(
         GroupName='a-security-group', Description='First One')

@@ -242,7 +248,7 @@ def test_get_resources_elbv2():
         CidrBlock='172.28.7.0/26',
         AvailabilityZone='us-east-1b')
-    conn.create_load_balancer(
+    elbv2.create_load_balancer(
         Name='my-lb',
         Subnets=[subnet1.id, subnet2.id],
         SecurityGroups=[security_group.id],

@@ -259,13 +265,27 @@ def test_get_resources_elbv2():
         ]
     )
-    conn.create_load_balancer(
+    elbv2.create_load_balancer(
         Name='my-other-lb',
         Subnets=[subnet1.id, subnet2.id],
         SecurityGroups=[security_group.id],
         Scheme='internal',
     )
+    kms.create_key(
+        KeyUsage='ENCRYPT_DECRYPT',
+        Tags=[
+            {
+                'TagKey': 'key_name',
+                'TagValue': 'a_value'
+            },
+            {
+                'TagKey': 'key_2',
+                'TagValue': 'val2'
+            }
+        ]
+    )
     rtapi = boto3.client('resourcegroupstaggingapi', region_name='us-east-1')
     resp = rtapi.get_resources(ResourceTypeFilters=['elasticloadbalancer:loadbalancer'])


@@ -652,6 +652,114 @@ def test_change_resource_record_sets_crud_valid():
     response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id)
     len(response['ResourceRecordSets']).should.equal(0)
@mock_route53
def test_change_weighted_resource_record_sets():
conn = boto3.client('route53', region_name='us-east-2')
conn.create_hosted_zone(
Name='test.vpc.internal.',
CallerReference=str(hash('test'))
)
zones = conn.list_hosted_zones_by_name(
DNSName='test.vpc.internal.'
)
hosted_zone_id = zones['HostedZones'][0]['Id']
#Create 2 weighted records
conn.change_resource_record_sets(
HostedZoneId=hosted_zone_id,
ChangeBatch={
'Changes': [
{
'Action': 'CREATE',
'ResourceRecordSet': {
'Name': 'test.vpc.internal',
'Type': 'A',
'SetIdentifier': 'test1',
'Weight': 50,
'AliasTarget': {
'HostedZoneId': 'Z3AADJGX6KTTL2',
'DNSName': 'internal-test1lb-447688172.us-east-2.elb.amazonaws.com.',
'EvaluateTargetHealth': True
}
}
},
{
'Action': 'CREATE',
'ResourceRecordSet': {
'Name': 'test.vpc.internal',
'Type': 'A',
'SetIdentifier': 'test2',
'Weight': 50,
'AliasTarget': {
'HostedZoneId': 'Z3AADJGX6KTTL2',
'DNSName': 'internal-testlb2-1116641781.us-east-2.elb.amazonaws.com.',
'EvaluateTargetHealth': True
}
}
}
]
}
)
response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id)
record = response['ResourceRecordSets'][0]
#Update the first record to have a weight of 90
conn.change_resource_record_sets(
HostedZoneId=hosted_zone_id,
ChangeBatch={
'Changes' : [
{
'Action' : 'UPSERT',
'ResourceRecordSet' : {
'Name' : record['Name'],
'Type' : record['Type'],
'SetIdentifier' : record['SetIdentifier'],
'Weight' : 90,
'AliasTarget' : {
'HostedZoneId' : record['AliasTarget']['HostedZoneId'],
'DNSName' : record['AliasTarget']['DNSName'],
'EvaluateTargetHealth' : record['AliasTarget']['EvaluateTargetHealth']
}
}
},
]
}
)
record = response['ResourceRecordSets'][1]
#Update the second record to have a weight of 10
conn.change_resource_record_sets(
HostedZoneId=hosted_zone_id,
ChangeBatch={
'Changes' : [
{
'Action' : 'UPSERT',
'ResourceRecordSet' : {
'Name' : record['Name'],
'Type' : record['Type'],
'SetIdentifier' : record['SetIdentifier'],
'Weight' : 10,
'AliasTarget' : {
'HostedZoneId' : record['AliasTarget']['HostedZoneId'],
'DNSName' : record['AliasTarget']['DNSName'],
'EvaluateTargetHealth' : record['AliasTarget']['EvaluateTargetHealth']
}
}
},
]
}
)
response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id)
for record in response['ResourceRecordSets']:
if record['SetIdentifier'] == 'test1':
record['Weight'].should.equal(90)
if record['SetIdentifier'] == 'test2':
record['Weight'].should.equal(10)
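After the two UPSERTs above, the records carry weights 90 and 10. Route 53 routes traffic to weighted records in proportion to each record's weight over the sum of all weights for that name and type; a small sketch of that share calculation (the helper is illustrative, not part of the API):

```python
# Sketch: proportional traffic share implied by Route 53 weighted records.
def traffic_share(records):
    total = sum(r['Weight'] for r in records)
    return {r['SetIdentifier']: r['Weight'] / total for r in records}

shares = traffic_share([
    {'SetIdentifier': 'test1', 'Weight': 90},
    {'SetIdentifier': 'test2', 'Weight': 10},
])
```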
@mock_route53
def test_change_resource_record_invalid():


@@ -639,7 +639,7 @@ def test_delete_keys():

 @mock_s3_deprecated
-def test_delete_keys_with_invalid():
+def test_delete_keys_invalid():
     conn = boto.connect_s3('the_key', 'the_secret')
     bucket = conn.create_bucket('foobar')

@@ -648,6 +648,7 @@ def test_delete_keys_with_invalid():
     Key(bucket=bucket, name='file3').set_contents_from_string('abc')
     Key(bucket=bucket, name='file4').set_contents_from_string('abc')

+    # non-existing key case
     result = bucket.delete_keys(['abc', 'file3'])
     result.deleted.should.have.length_of(1)

@@ -656,6 +657,18 @@ def test_delete_keys_with_invalid():
     keys.should.have.length_of(3)
     keys[0].name.should.equal('file1')

+    # empty keys
+    result = bucket.delete_keys([])
+    result.deleted.should.have.length_of(0)
+    result.errors.should.have.length_of(0)
+
+@mock_s3
+def test_boto3_delete_empty_keys_list():
+    with assert_raises(ClientError) as err:
+        boto3.client('s3').delete_objects(Bucket='foobar', Delete={'Objects': []})
+    assert err.exception.response["Error"]["Code"] == "MalformedXML"

 @mock_s3_deprecated
 def test_bucket_name_with_dot():

@@ -1671,6 +1684,42 @@ def test_boto3_multipart_etag():
     resp['ETag'].should.equal(EXPECTED_ETAG)
@mock_s3
@reduced_min_part_size
def test_boto3_multipart_part_size():
s3 = boto3.client('s3', region_name='us-east-1')
s3.create_bucket(Bucket='mybucket')
mpu = s3.create_multipart_upload(Bucket='mybucket', Key='the-key')
mpu_id = mpu["UploadId"]
parts = []
n_parts = 10
for i in range(1, n_parts + 1):
part_size = REDUCED_PART_SIZE + i
body = b'1' * part_size
part = s3.upload_part(
Bucket='mybucket',
Key='the-key',
PartNumber=i,
UploadId=mpu_id,
Body=body,
ContentLength=len(body),
)
parts.append({"PartNumber": i, "ETag": part["ETag"]})
s3.complete_multipart_upload(
Bucket='mybucket',
Key='the-key',
UploadId=mpu_id,
MultipartUpload={"Parts": parts},
)
for i in range(1, n_parts + 1):
obj = s3.head_object(Bucket='mybucket', Key='the-key', PartNumber=i)
assert obj["ContentLength"] == REDUCED_PART_SIZE + i
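The test above sizes each part as `REDUCED_PART_SIZE + i` because real S3 enforces a minimum size on every part except the last one (5 MiB by default; the `@reduced_min_part_size` decorator patches moto's threshold so the test stays small). A standalone sketch of that validation rule, under the assumption that part sizes are checked only at `complete_multipart_upload` time:

```python
# Sketch: multipart upload part-size rule — every part except the final
# one must meet the minimum part size.
MIN_PART_SIZE = 5 * 1024 * 1024  # 5 MiB, S3's documented default

def validate_parts(part_sizes, min_part_size=MIN_PART_SIZE):
    return all(size >= min_part_size for size in part_sizes[:-1])

ok = validate_parts([6 * 1024 * 1024, 6 * 1024 * 1024, 100])  # small last part is fine
bad = validate_parts([100, 6 * 1024 * 1024])                  # small first part is not
```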
@mock_s3
def test_boto3_put_object_with_tagging():
    s3 = boto3.client('s3', region_name='us-east-1')


@@ -1,16 +1,12 @@
 from __future__ import unicode_literals
-import boto
 import boto3
-from boto.exception import S3CreateError, S3ResponseError
-from boto.s3.lifecycle import Lifecycle, Transition, Expiration, Rule
 import sure # noqa
 from botocore.exceptions import ClientError
-from datetime import datetime
 from nose.tools import assert_raises
-from moto import mock_s3_deprecated, mock_s3
+from moto import mock_s3

 @mock_s3

@@ -41,6 +37,18 @@ def test_s3_storage_class_infrequent_access():
     D['Contents'][0]["StorageClass"].should.equal("STANDARD_IA")

+@mock_s3
+def test_s3_storage_class_intelligent_tiering():
+    s3 = boto3.client("s3")
+    s3.create_bucket(Bucket="Bucket")
+    s3.put_object(Bucket="Bucket", Key="my_key_infrequent", Body="my_value_infrequent", StorageClass="INTELLIGENT_TIERING")
+    objects = s3.list_objects(Bucket="Bucket")
+    objects['Contents'][0]["StorageClass"].should.equal("INTELLIGENT_TIERING")

 @mock_s3
 def test_s3_storage_class_copy():
     s3 = boto3.client("s3")

@@ -90,6 +98,7 @@ def test_s3_invalid_storage_class():
     e.response["Error"]["Code"].should.equal("InvalidStorageClass")
     e.response["Error"]["Message"].should.equal("The storage class you specified is not valid")

 @mock_s3
 def test_s3_default_storage_class():
     s3 = boto3.client("s3")

@@ -103,4 +112,27 @@ def test_s3_default_storage_class():
     list_of_objects["Contents"][0]["StorageClass"].should.equal("STANDARD")
@mock_s3
def test_s3_copy_object_error_for_glacier_storage_class():
s3 = boto3.client("s3")
s3.create_bucket(Bucket="Bucket")
s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="GLACIER")
with assert_raises(ClientError) as exc:
s3.copy_object(CopySource={"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket", Key="Second_Object")
exc.exception.response["Error"]["Code"].should.equal("ObjectNotInActiveTierError")
@mock_s3
def test_s3_copy_object_error_for_deep_archive_storage_class():
s3 = boto3.client("s3")
s3.create_bucket(Bucket="Bucket")
s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="DEEP_ARCHIVE")
with assert_raises(ClientError) as exc:
s3.copy_object(CopySource={"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket", Key="Second_Object")
exc.exception.response["Error"]["Code"].should.equal("ObjectNotInActiveTierError")
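Both tests expect `ObjectNotInActiveTierError` because S3 refuses to copy from objects sitting in archival storage tiers until they have been restored. A sketch of that gate (the storage-class names come from the tests above; the helper itself is illustrative):

```python
# Sketch: copy is only allowed when the source object is in an active tier.
ARCHIVAL_CLASSES = {"GLACIER", "DEEP_ARCHIVE"}

def can_copy(storage_class):
    """Archived objects must be restored before they can serve as a copy source."""
    return storage_class not in ARCHIVAL_CLASSES
```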


@@ -3,11 +3,15 @@ import json
 import boto
 import boto3
+from botocore.client import ClientError
 from freezegun import freeze_time
+from nose.tools import assert_raises
 import sure # noqa

 from moto import mock_sts, mock_sts_deprecated, mock_iam
 from moto.iam.models import ACCOUNT_ID
+from moto.sts.responses import MAX_FEDERATION_TOKEN_POLICY_LENGTH

 @freeze_time("2012-01-01 12:00:00")

@@ -80,6 +84,41 @@ def test_assume_role():
     assume_role_response['AssumedRoleUser']['AssumedRoleId'].should.have.length_of(21 + 1 + len(session_name))
@freeze_time("2012-01-01 12:00:00")
@mock_sts_deprecated
def test_assume_role_with_web_identity():
conn = boto.connect_sts()
policy = json.dumps({
"Statement": [
{
"Sid": "Stmt13690092345534",
"Action": [
"S3:ListBucket"
],
"Effect": "Allow",
"Resource": [
"arn:aws:s3:::foobar-tester"
]
},
]
})
s3_role = "arn:aws:iam::123456789012:role/test-role"
role = conn.assume_role_with_web_identity(
s3_role, "session-name", policy, duration_seconds=123)
credentials = role.credentials
credentials.expiration.should.equal('2012-01-01T12:02:03.000Z')
credentials.session_token.should.have.length_of(356)
assert credentials.session_token.startswith("FQoGZXIvYXdzE")
credentials.access_key.should.have.length_of(20)
assert credentials.access_key.startswith("ASIA")
credentials.secret_key.should.have.length_of(40)
role.user.arn.should.equal("arn:aws:iam::123456789012:role/test-role")
role.user.assume_role_id.should.contain("session-name")
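The assertions above pin down the shape of the mocked temporary credentials (20-character `ASIA…` access key, 40-character secret, `FQoGZXIvYXdzE…` session token). A mock-free restatement of those shape rules, with sample values that are purely illustrative:

```python
# Sketch: the structural checks the test applies to STS temporary credentials.
def looks_like_temporary_credentials(access_key, secret_key, session_token):
    return (len(access_key) == 20 and access_key.startswith("ASIA")
            and len(secret_key) == 40
            and session_token.startswith("FQoGZXIvYXdzE"))

sample_ok = looks_like_temporary_credentials(
    "ASIA" + "A" * 16, "S" * 40, "FQoGZXIvYXdzE" + "x" * 343)
sample_bad = looks_like_temporary_credentials(
    "AKIA" + "A" * 16, "S" * 40, "FQoGZXIvYXdzE")  # long-term key prefix
```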
 @mock_sts
 def test_get_caller_identity_with_default_credentials():
     identity = boto3.client(

@@ -137,3 +176,32 @@ def test_get_caller_identity_with_assumed_role_credentials():
     identity['Arn'].should.equal(assumed_role['AssumedRoleUser']['Arn'])
     identity['UserId'].should.equal(assumed_role['AssumedRoleUser']['AssumedRoleId'])
     identity['Account'].should.equal(str(ACCOUNT_ID))
@mock_sts
def test_federation_token_with_too_long_policy():
"Trying to get a federation token with a policy longer than 2048 characters should fail"
cli = boto3.client("sts", region_name='us-east-1')
resource_tmpl = 'arn:aws:s3:::yyyy-xxxxx-cloud-default/my_default_folder/folder-name-%s/*'
statements = []
for num in range(30):
statements.append(
{
'Effect': 'Allow',
'Action': ['s3:*'],
'Resource': resource_tmpl % str(num)
}
)
policy = {
'Version': '2012-10-17',
'Statement': statements
}
json_policy = json.dumps(policy)
assert len(json_policy) > MAX_FEDERATION_TOKEN_POLICY_LENGTH
with assert_raises(ClientError) as exc:
cli.get_federation_token(Name='foo', DurationSeconds=3600, Policy=json_policy)
exc.exception.response['Error']['Code'].should.equal('ValidationError')
exc.exception.response['Error']['Message'].should.contain(
str(MAX_FEDERATION_TOKEN_POLICY_LENGTH)
)
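The validation this test triggers can be sketched standalone: the serialized policy document may not exceed the limit the test imports as `MAX_FEDERATION_TOKEN_POLICY_LENGTH` (2048 characters per the GetFederationToken docs). The `validate_policy` helper and its error message are illustrative, not moto's actual implementation:

```python
import json

# Sketch: reject federation-token policies whose JSON exceeds the limit.
MAX_POLICY_LENGTH = 2048  # documented GetFederationToken Policy limit

def validate_policy(json_policy):
    if len(json_policy) > MAX_POLICY_LENGTH:
        raise ValueError("ValidationError: policy exceeds %d characters" % MAX_POLICY_LENGTH)

statements = [{"Effect": "Allow", "Action": ["s3:*"],
               "Resource": "arn:aws:s3:::bucket/folder-%d/*" % n}
              for n in range(30)]
policy = json.dumps({"Version": "2012-10-17", "Statement": statements})

try:
    validate_policy(policy)
    raised = False
except ValueError:
    raised = True
```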


@@ -1,5 +1,5 @@
 [tox]
-envlist = py27, py36
+envlist = py27, py36, py37

 [testenv]
 setenv =


@@ -74,9 +74,9 @@ def prerelease_version():
     ver, commits_since, githash = get_git_version_info()
     initpy_ver = get_version()

-    assert len(initpy_ver.split('.')) in [3, 4], 'moto/__init__.py version should be like 0.0.2 or 0.0.2.dev'
+    assert len(initpy_ver.split('.')) in [3, 4], 'moto/__init__.py version should be like 0.0.2.dev'
     assert initpy_ver > ver, 'the moto/__init__.py version should be newer than the last tagged release.'
-    return '{initpy_ver}.dev{commits_since}'.format(initpy_ver=initpy_ver, commits_since=commits_since)
+    return '{initpy_ver}.{commits_since}'.format(initpy_ver=initpy_ver, commits_since=commits_since)

 def read(*parts):
     """ Reads in file from *parts.

@@ -108,8 +108,10 @@ def release_version_correct():
         new_version = prerelease_version()
         print('updating version in __init__.py to {new_version}'.format(new_version=new_version))
+        assert len(new_version.split('.')) >= 4, 'moto/__init__.py version should be like 0.0.2.dev'
         migrate_version(initpy, new_version)
     else:
+        assert False, "No non-master deployments yet"
         # check that we are a tag with the same version as in __init__.py
         assert get_version() == git_tag_name(), 'git tag/branch name not the same as moto/__init__.py __verion__'