Merge branch 'master' of https://github.com/spulec/moto into spulec-master

This commit is contained in:
Stephan Huber 2018-10-16 15:29:56 +02:00
parent 1c5c5036e3
commit 0ba213ffcc
76 changed files with 7929 additions and 4971 deletions

View File

@ -8,6 +8,19 @@ python:
env: env:
- TEST_SERVER_MODE=false - TEST_SERVER_MODE=false
- TEST_SERVER_MODE=true - TEST_SERVER_MODE=true
# Due to incomplete Python 3.7 support on Travis CI (
# https://github.com/travis-ci/travis-ci/issues/9815),
# using a matrix is necessary
matrix:
include:
- python: 3.7
env: TEST_SERVER_MODE=false
dist: xenial
sudo: true
- python: 3.7
env: TEST_SERVER_MODE=true
dist: xenial
sudo: true
before_install: before_install:
- export BOTO_CONFIG=/dev/null - export BOTO_CONFIG=/dev/null
install: install:

View File

@ -53,3 +53,4 @@ Moto is written by Steve Pulec with contributions from:
* [Jim Shields](https://github.com/jimjshields) * [Jim Shields](https://github.com/jimjshields)
* [William Richard](https://github.com/william-richard) * [William Richard](https://github.com/william-richard)
* [Alex Casalboni](https://github.com/alexcasalboni) * [Alex Casalboni](https://github.com/alexcasalboni)
* [Jon Beilke](https://github.com/jrbeilke)

View File

@ -1,6 +1,11 @@
Moto Changelog Moto Changelog
=================== ===================
1.3.6
-----
* Fix boto3 pinning.
1.3.5 1.3.5
----- -----

File diff suppressed because it is too large Load Diff

View File

@ -112,6 +112,8 @@ It gets even better! Moto isn't just for Python code and it isn't just for S3. L
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| KMS | @mock_kms | basic endpoints done | | KMS | @mock_kms | basic endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| Organizations | @mock_organizations | some core endpoints done |
|------------------------------------------------------------------------------|
| Polly | @mock_polly | all endpoints done | | Polly | @mock_polly | all endpoints done |
|------------------------------------------------------------------------------| |------------------------------------------------------------------------------|
| RDS | @mock_rds | core endpoints done | | RDS | @mock_rds | core endpoints done |

View File

@ -34,11 +34,11 @@ Currently implemented Services:
| - DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes| | - DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes|
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| EC2 | @mock_ec2 | core endpoints done | | EC2 | @mock_ec2 | core endpoints done |
| - AMI | | core endpoints done | | - AMI | | - core endpoints done |
| - EBS | | core endpoints done | | - EBS | | - core endpoints done |
| - Instances | | all endpoints done | | - Instances | | - all endpoints done |
| - Security Groups | | core endpoints done | | - Security Groups | | - core endpoints done |
| - Tags | | all endpoints done | | - Tags | | - all endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+
| ECS | @mock_ecs | basic endpoints done | | ECS | @mock_ecs | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +-----------------------+---------------------+-----------------------------------+

View File

@ -3,7 +3,7 @@ import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL) # logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto' __title__ = 'moto'
__version__ = '1.3.5' __version__ = '1.3.6'
from .acm import mock_acm # flake8: noqa from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa
@ -28,6 +28,7 @@ from .glue import mock_glue # flake8: noqa
from .iam import mock_iam, mock_iam_deprecated # flake8: noqa from .iam import mock_iam, mock_iam_deprecated # flake8: noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa
from .kms import mock_kms, mock_kms_deprecated # flake8: noqa from .kms import mock_kms, mock_kms_deprecated # flake8: noqa
from .organizations import mock_organizations # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .polly import mock_polly # flake8: noqa from .polly import mock_polly # flake8: noqa
from .rds import mock_rds, mock_rds_deprecated # flake8: noqa from .rds import mock_rds, mock_rds_deprecated # flake8: noqa

View File

@ -27,6 +27,7 @@ from moto.kinesis import kinesis_backends
from moto.kms import kms_backends from moto.kms import kms_backends
from moto.logs import logs_backends from moto.logs import logs_backends
from moto.opsworks import opsworks_backends from moto.opsworks import opsworks_backends
from moto.organizations import organizations_backends
from moto.polly import polly_backends from moto.polly import polly_backends
from moto.rds2 import rds2_backends from moto.rds2 import rds2_backends
from moto.redshift import redshift_backends from moto.redshift import redshift_backends
@ -74,6 +75,7 @@ BACKENDS = {
'kinesis': kinesis_backends, 'kinesis': kinesis_backends,
'kms': kms_backends, 'kms': kms_backends,
'opsworks': opsworks_backends, 'opsworks': opsworks_backends,
'organizations': organizations_backends,
'polly': polly_backends, 'polly': polly_backends,
'redshift': redshift_backends, 'redshift': redshift_backends,
'rds': rds2_backends, 'rds': rds2_backends,

View File

@ -387,6 +387,7 @@ class ResourceMap(collections.Mapping):
"AWS::StackName": stack_name, "AWS::StackName": stack_name,
"AWS::URLSuffix": "amazonaws.com", "AWS::URLSuffix": "amazonaws.com",
"AWS::NoValue": None, "AWS::NoValue": None,
"AWS::Partition": "aws",
} }
def __getitem__(self, key): def __getitem__(self, key):

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import cognitoidentity_backends from .models import cognitoidentity_backends
from .utils import get_random_identity_id
class CognitoIdentityResponse(BaseResponse): class CognitoIdentityResponse(BaseResponse):
@ -31,4 +32,6 @@ class CognitoIdentityResponse(BaseResponse):
return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId')) return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId'))
def get_open_id_token_for_developer_identity(self): def get_open_id_token_for_developer_identity(self):
return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity(self._get_param('IdentityId')) return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity(
self._get_param('IdentityId') or get_random_identity_id(self.region)
)

View File

@ -2,4 +2,4 @@ from moto.core.utils import get_random_hex
def get_random_identity_id(region): def get_random_identity_id(region):
return "{0}:{0}".format(region, get_random_hex(length=19)) return "{0}:{1}".format(region, get_random_hex(length=19))

View File

@ -24,7 +24,7 @@ class CognitoIdpUserPool(BaseModel):
def __init__(self, region, name, extended_config): def __init__(self, region, name, extended_config):
self.region = region self.region = region
self.id = str(uuid.uuid4()) self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
self.name = name self.name = name
self.status = None self.status = None
self.extended_config = extended_config or {} self.extended_config = extended_config or {}
@ -84,7 +84,11 @@ class CognitoIdpUserPool(BaseModel):
return refresh_token return refresh_token
def create_access_token(self, client_id, username): def create_access_token(self, client_id, username):
access_token, expires_in = self.create_jwt(client_id, username) extra_data = self.get_user_extra_data_by_client_id(
client_id, username
)
access_token, expires_in = self.create_jwt(client_id, username,
extra_data=extra_data)
self.access_tokens[access_token] = (client_id, username) self.access_tokens[access_token] = (client_id, username)
return access_token, expires_in return access_token, expires_in
@ -97,6 +101,21 @@ class CognitoIdpUserPool(BaseModel):
id_token, _ = self.create_id_token(client_id, username) id_token, _ = self.create_id_token(client_id, username)
return access_token, id_token, expires_in return access_token, id_token, expires_in
def get_user_extra_data_by_client_id(self, client_id, username):
extra_data = {}
current_client = self.clients.get(client_id, None)
if current_client:
for readable_field in current_client.get_readable_fields():
attribute = list(filter(
lambda f: f['Name'] == readable_field,
self.users.get(username).attributes
))
if len(attribute) > 0:
extra_data.update({
attribute[0]['Name']: attribute[0]['Value']
})
return extra_data
class CognitoIdpUserPoolDomain(BaseModel): class CognitoIdpUserPoolDomain(BaseModel):
@ -138,6 +157,9 @@ class CognitoIdpUserPoolClient(BaseModel):
return user_pool_client_json return user_pool_client_json
def get_readable_fields(self):
return self.extended_config.get('ReadAttributes', [])
class CognitoIdpIdentityProvider(BaseModel): class CognitoIdpIdentityProvider(BaseModel):

View File

@ -89,6 +89,17 @@ class BaseMockAWS(object):
if inspect.ismethod(attr_value) and attr_value.__self__ is klass: if inspect.ismethod(attr_value) and attr_value.__self__ is klass:
continue continue
# Check if this is a staticmethod. If so, skip patching
for cls in inspect.getmro(klass):
if attr_value.__name__ not in cls.__dict__:
continue
bound_attr_value = cls.__dict__[attr_value.__name__]
if not isinstance(bound_attr_value, staticmethod):
break
else:
# It is a staticmethod, skip patching
continue
try: try:
setattr(klass, attr, self(attr_value, reset=False)) setattr(klass, attr, self(attr_value, reset=False))
except TypeError: except TypeError:

View File

@ -154,7 +154,7 @@ class Item(BaseModel):
# If not exists, changes value to a default if needed, else its the same as it was # If not exists, changes value to a default if needed, else its the same as it was
if value.startswith('if_not_exists'): if value.startswith('if_not_exists'):
# Function signature # Function signature
match = re.match(r'.*if_not_exists\((?P<path>.+),\s*(?P<default>.+)\).*', value) match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value)
if not match: if not match:
raise TypeError raise TypeError
@ -162,8 +162,9 @@ class Item(BaseModel):
# If it already exists, get its value so we dont overwrite it # If it already exists, get its value so we dont overwrite it
if path in self.attrs: if path in self.attrs:
value = self.attrs[path].cast_value value = self.attrs[path]
if type(value) != DynamoType:
if value in expression_attribute_values: if value in expression_attribute_values:
value = DynamoType(expression_attribute_values[value]) value = DynamoType(expression_attribute_values[value])
else: else:

View File

@ -20,6 +20,17 @@ def has_empty_keys_or_values(_dict):
) )
def get_empty_str_error():
er = 'com.amazonaws.dynamodb.v20111205#ValidationException'
return (400,
{'server': 'amazon.com'},
dynamo_json_dump({'__type': er,
'message': ('One or more parameter values were '
'invalid: An AttributeValue may not '
'contain an empty string')}
))
class DynamoHandler(BaseResponse): class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers): def get_endpoint_name(self, headers):
@ -174,14 +185,7 @@ class DynamoHandler(BaseResponse):
item = self.body['Item'] item = self.body['Item']
if has_empty_keys_or_values(item): if has_empty_keys_or_values(item):
er = 'com.amazonaws.dynamodb.v20111205#ValidationException' return get_empty_str_error()
return (400,
{'server': 'amazon.com'},
dynamo_json_dump({'__type': er,
'message': ('One or more parameter values were '
'invalid: An AttributeValue may not '
'contain an empty string')}
))
overwrite = 'Expected' not in self.body overwrite = 'Expected' not in self.body
if not overwrite: if not overwrite:
@ -200,9 +204,9 @@ class DynamoHandler(BaseResponse):
if cond_items: if cond_items:
expected = {} expected = {}
overwrite = False overwrite = False
exists_re = re.compile('^attribute_exists\((.*)\)$') exists_re = re.compile('^attribute_exists\s*\((.*)\)$')
not_exists_re = re.compile( not_exists_re = re.compile(
'^attribute_not_exists\((.*)\)$') '^attribute_not_exists\s*\((.*)\)$')
for cond in cond_items: for cond in cond_items:
exists_m = exists_re.match(cond) exists_m = exists_re.match(cond)
@ -523,6 +527,7 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
def update_item(self): def update_item(self):
name = self.body['TableName'] name = self.body['TableName']
key = self.body['Key'] key = self.body['Key']
update_expression = self.body.get('UpdateExpression') update_expression = self.body.get('UpdateExpression')
@ -533,6 +538,9 @@ class DynamoHandler(BaseResponse):
'ExpressionAttributeValues', {}) 'ExpressionAttributeValues', {})
existing_item = self.dynamodb_backend.get_item(name, key) existing_item = self.dynamodb_backend.get_item(name, key)
if has_empty_keys_or_values(expression_attribute_values):
return get_empty_str_error()
if 'Expected' in self.body: if 'Expected' in self.body:
expected = self.body['Expected'] expected = self.body['Expected']
else: else:
@ -548,9 +556,9 @@ class DynamoHandler(BaseResponse):
if cond_items: if cond_items:
expected = {} expected = {}
exists_re = re.compile('^attribute_exists\((.*)\)$') exists_re = re.compile('^attribute_exists\s*\((.*)\)$')
not_exists_re = re.compile( not_exists_re = re.compile(
'^attribute_not_exists\((.*)\)$') '^attribute_not_exists\s*\((.*)\)$')
for cond in cond_items: for cond in cond_items:
exists_m = exists_re.match(cond) exists_m = exists_re.match(cond)

View File

@ -13,6 +13,7 @@ from pkg_resources import resource_filename
import boto.ec2 import boto.ec2
from collections import defaultdict from collections import defaultdict
import weakref
from datetime import datetime from datetime import datetime
from boto.ec2.instance import Instance as BotoInstance, Reservation from boto.ec2.instance import Instance as BotoInstance, Reservation
from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType
@ -2115,10 +2116,20 @@ class VPC(TaggedEC2Resource):
class VPCBackend(object): class VPCBackend(object):
__refs__ = defaultdict(list)
def __init__(self): def __init__(self):
self.vpcs = {} self.vpcs = {}
self.__refs__[self.__class__].append(weakref.ref(self))
super(VPCBackend, self).__init__() super(VPCBackend, self).__init__()
@classmethod
def get_instances(cls):
for inst_ref in cls.__refs__[cls]:
inst = inst_ref()
if inst is not None:
yield inst
def create_vpc(self, cidr_block, instance_tenancy='default', amazon_provided_ipv6_cidr_block=False): def create_vpc(self, cidr_block, instance_tenancy='default', amazon_provided_ipv6_cidr_block=False):
vpc_id = random_vpc_id() vpc_id = random_vpc_id()
vpc = VPC(self, vpc_id, cidr_block, len(self.vpcs) == 0, instance_tenancy, amazon_provided_ipv6_cidr_block) vpc = VPC(self, vpc_id, cidr_block, len(self.vpcs) == 0, instance_tenancy, amazon_provided_ipv6_cidr_block)
@ -2142,6 +2153,13 @@ class VPCBackend(object):
raise InvalidVPCIdError(vpc_id) raise InvalidVPCIdError(vpc_id)
return self.vpcs.get(vpc_id) return self.vpcs.get(vpc_id)
# get vpc by vpc id and aws region
def get_cross_vpc(self, vpc_id, peer_region):
for vpcs in self.get_instances():
if vpcs.region_name == peer_region:
match_vpc = vpcs.get_vpc(vpc_id)
return match_vpc
def get_all_vpcs(self, vpc_ids=None, filters=None): def get_all_vpcs(self, vpc_ids=None, filters=None):
matches = self.vpcs.values() matches = self.vpcs.values()
if vpc_ids: if vpc_ids:

View File

@ -5,8 +5,12 @@ from moto.core.responses import BaseResponse
class VPCPeeringConnections(BaseResponse): class VPCPeeringConnections(BaseResponse):
def create_vpc_peering_connection(self): def create_vpc_peering_connection(self):
vpc = self.ec2_backend.get_vpc(self._get_param('VpcId')) peer_region = self._get_param('PeerRegion')
if peer_region == self.region or peer_region is None:
peer_vpc = self.ec2_backend.get_vpc(self._get_param('PeerVpcId')) peer_vpc = self.ec2_backend.get_vpc(self._get_param('PeerVpcId'))
else:
peer_vpc = self.ec2_backend.get_cross_vpc(self._get_param('PeerVpcId'), peer_region)
vpc = self.ec2_backend.get_vpc(self._get_param('VpcId'))
vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc) vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc)
template = self.response_template( template = self.response_template(
CREATE_VPC_PEERING_CONNECTION_RESPONSE) CREATE_VPC_PEERING_CONNECTION_RESPONSE)
@ -41,7 +45,7 @@ class VPCPeeringConnections(BaseResponse):
CREATE_VPC_PEERING_CONNECTION_RESPONSE = """ CREATE_VPC_PEERING_CONNECTION_RESPONSE = """
<CreateVpcPeeringConnectionResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <CreateVpcPeeringConnectionResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId> <requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<vpcPeeringConnection> <vpcPeeringConnection>
<vpcPeeringConnectionId>{{ vpc_pcx.id }}</vpcPeeringConnectionId> <vpcPeeringConnectionId>{{ vpc_pcx.id }}</vpcPeeringConnectionId>
@ -49,6 +53,11 @@ CREATE_VPC_PEERING_CONNECTION_RESPONSE = """
<ownerId>777788889999</ownerId> <ownerId>777788889999</ownerId>
<vpcId>{{ vpc_pcx.vpc.id }}</vpcId> <vpcId>{{ vpc_pcx.vpc.id }}</vpcId>
<cidrBlock>{{ vpc_pcx.vpc.cidr_block }}</cidrBlock> <cidrBlock>{{ vpc_pcx.vpc.cidr_block }}</cidrBlock>
<peeringOptions>
<allowEgressFromLocalClassicLinkToRemoteVpc>false</allowEgressFromLocalClassicLinkToRemoteVpc>
<allowEgressFromLocalVpcToRemoteClassicLink>false</allowEgressFromLocalVpcToRemoteClassicLink>
<allowDnsResolutionFromRemoteVpc>false</allowDnsResolutionFromRemoteVpc>
</peeringOptions>
</requesterVpcInfo> </requesterVpcInfo>
<accepterVpcInfo> <accepterVpcInfo>
<ownerId>123456789012</ownerId> <ownerId>123456789012</ownerId>
@ -56,7 +65,7 @@ CREATE_VPC_PEERING_CONNECTION_RESPONSE = """
</accepterVpcInfo> </accepterVpcInfo>
<status> <status>
<code>initiating-request</code> <code>initiating-request</code>
<message>Initiating request to {accepter ID}.</message> <message>Initiating Request to {accepter ID}</message>
</status> </status>
<expirationTime>2014-02-18T14:37:25.000Z</expirationTime> <expirationTime>2014-02-18T14:37:25.000Z</expirationTime>
<tagSet/> <tagSet/>

View File

@ -179,7 +179,7 @@ class Task(BaseObject):
class Service(BaseObject): class Service(BaseObject):
def __init__(self, cluster, service_name, task_definition, desired_count, load_balancers=None): def __init__(self, cluster, service_name, task_definition, desired_count, load_balancers=None, scheduling_strategy=None):
self.cluster_arn = cluster.arn self.cluster_arn = cluster.arn
self.arn = 'arn:aws:ecs:us-east-1:012345678910:service/{0}'.format( self.arn = 'arn:aws:ecs:us-east-1:012345678910:service/{0}'.format(
service_name) service_name)
@ -202,6 +202,7 @@ class Service(BaseObject):
} }
] ]
self.load_balancers = load_balancers if load_balancers is not None else [] self.load_balancers = load_balancers if load_balancers is not None else []
self.scheduling_strategy = scheduling_strategy if scheduling_strategy is not None else 'REPLICA'
self.pending_count = 0 self.pending_count = 0
@property @property
@ -214,6 +215,7 @@ class Service(BaseObject):
del response_object['name'], response_object['arn'] del response_object['name'], response_object['arn']
response_object['serviceName'] = self.name response_object['serviceName'] = self.name
response_object['serviceArn'] = self.arn response_object['serviceArn'] = self.arn
response_object['schedulingStrategy'] = self.scheduling_strategy
for deployment in response_object['deployments']: for deployment in response_object['deployments']:
if isinstance(deployment['createdAt'], datetime): if isinstance(deployment['createdAt'], datetime):
@ -655,7 +657,7 @@ class EC2ContainerServiceBackend(BaseBackend):
raise Exception("Could not find task {} on cluster {}".format( raise Exception("Could not find task {} on cluster {}".format(
task_str, cluster_name)) task_str, cluster_name))
def create_service(self, cluster_str, service_name, task_definition_str, desired_count, load_balancers=None): def create_service(self, cluster_str, service_name, task_definition_str, desired_count, load_balancers=None, scheduling_strategy=None):
cluster_name = cluster_str.split('/')[-1] cluster_name = cluster_str.split('/')[-1]
if cluster_name in self.clusters: if cluster_name in self.clusters:
cluster = self.clusters[cluster_name] cluster = self.clusters[cluster_name]
@ -665,7 +667,7 @@ class EC2ContainerServiceBackend(BaseBackend):
desired_count = desired_count if desired_count is not None else 0 desired_count = desired_count if desired_count is not None else 0
service = Service(cluster, service_name, service = Service(cluster, service_name,
task_definition, desired_count, load_balancers) task_definition, desired_count, load_balancers, scheduling_strategy)
cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name)
self.services[cluster_service_pair] = service self.services[cluster_service_pair] = service

View File

@ -154,8 +154,9 @@ class EC2ContainerServiceResponse(BaseResponse):
task_definition_str = self._get_param('taskDefinition') task_definition_str = self._get_param('taskDefinition')
desired_count = self._get_int_param('desiredCount') desired_count = self._get_int_param('desiredCount')
load_balancers = self._get_param('loadBalancers') load_balancers = self._get_param('loadBalancers')
scheduling_strategy = self._get_param('schedulingStrategy')
service = self.ecs_backend.create_service( service = self.ecs_backend.create_service(
cluster_str, service_name, task_definition_str, desired_count, load_balancers) cluster_str, service_name, task_definition_str, desired_count, load_balancers, scheduling_strategy)
return json.dumps({ return json.dumps({
'service': service.response_object 'service': service.response_object
}) })

View File

@ -259,12 +259,22 @@ class ELBResponse(BaseResponse):
def describe_instance_health(self): def describe_instance_health(self):
load_balancer_name = self._get_param('LoadBalancerName') load_balancer_name = self._get_param('LoadBalancerName')
instance_ids = [list(param.values())[0] for param in self._get_list_prefix('Instances.member')] provided_instance_ids = [
if len(instance_ids) == 0: list(param.values())[0]
instance_ids = self.elb_backend.get_load_balancer( for param in self._get_list_prefix('Instances.member')
]
registered_instances_id = self.elb_backend.get_load_balancer(
load_balancer_name).instance_ids load_balancer_name).instance_ids
if len(provided_instance_ids) == 0:
provided_instance_ids = registered_instances_id
template = self.response_template(DESCRIBE_INSTANCE_HEALTH_TEMPLATE) template = self.response_template(DESCRIBE_INSTANCE_HEALTH_TEMPLATE)
return template.render(instance_ids=instance_ids) instances = []
for instance_id in provided_instance_ids:
state = "InService" \
if instance_id in registered_instances_id\
else "Unknown"
instances.append({"InstanceId": instance_id, "State": state})
return template.render(instances=instances)
def add_tags(self): def add_tags(self):
@ -689,11 +699,11 @@ SET_LOAD_BALANCER_POLICIES_FOR_BACKEND_SERVER_TEMPLATE = """<SetLoadBalancerPoli
DESCRIBE_INSTANCE_HEALTH_TEMPLATE = """<DescribeInstanceHealthResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2012-06-01/"> DESCRIBE_INSTANCE_HEALTH_TEMPLATE = """<DescribeInstanceHealthResponse xmlns="http://elasticloadbalancing.amazonaws.com/doc/2012-06-01/">
<DescribeInstanceHealthResult> <DescribeInstanceHealthResult>
<InstanceStates> <InstanceStates>
{% for instance_id in instance_ids %} {% for instance in instances %}
<member> <member>
<Description>N/A</Description> <Description>N/A</Description>
<InstanceId>{{ instance_id }}</InstanceId> <InstanceId>{{ instance['InstanceId'] }}</InstanceId>
<State>InService</State> <State>{{ instance['State'] }}</State>
<ReasonCode>N/A</ReasonCode> <ReasonCode>N/A</ReasonCode>
</member> </member>
{% endfor %} {% endfor %}

View File

@ -6,19 +6,56 @@ class GlueClientError(JsonRESTError):
code = 400 code = 400
class DatabaseAlreadyExistsException(GlueClientError): class AlreadyExistsException(GlueClientError):
def __init__(self): def __init__(self, typ):
self.code = 400 super(GlueClientError, self).__init__(
super(DatabaseAlreadyExistsException, self).__init__( 'AlreadyExistsException',
'DatabaseAlreadyExistsException', '%s already exists.' % (typ),
'Database already exists.'
) )
class TableAlreadyExistsException(GlueClientError): class DatabaseAlreadyExistsException(AlreadyExistsException):
def __init__(self): def __init__(self):
self.code = 400 super(DatabaseAlreadyExistsException, self).__init__('Database')
super(TableAlreadyExistsException, self).__init__(
'TableAlreadyExistsException',
'Table already exists.' class TableAlreadyExistsException(AlreadyExistsException):
def __init__(self):
super(TableAlreadyExistsException, self).__init__('Table')
class PartitionAlreadyExistsException(AlreadyExistsException):
def __init__(self):
super(PartitionAlreadyExistsException, self).__init__('Partition')
class EntityNotFoundException(GlueClientError):
def __init__(self, msg):
super(GlueClientError, self).__init__(
'EntityNotFoundException',
msg,
) )
class DatabaseNotFoundException(EntityNotFoundException):
def __init__(self, db):
super(DatabaseNotFoundException, self).__init__(
'Database %s not found.' % db,
)
class TableNotFoundException(EntityNotFoundException):
def __init__(self, tbl):
super(TableNotFoundException, self).__init__(
'Table %s not found.' % tbl,
)
class PartitionNotFoundException(EntityNotFoundException):
def __init__(self):
super(PartitionNotFoundException, self).__init__("Cannot find partition.")
class VersionNotFoundException(EntityNotFoundException):
def __init__(self):
super(VersionNotFoundException, self).__init__("Version not found.")

View File

@ -1,8 +1,19 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import time
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.compat import OrderedDict from moto.compat import OrderedDict
from.exceptions import DatabaseAlreadyExistsException, TableAlreadyExistsException from.exceptions import (
JsonRESTError,
DatabaseAlreadyExistsException,
DatabaseNotFoundException,
TableAlreadyExistsException,
TableNotFoundException,
PartitionAlreadyExistsException,
PartitionNotFoundException,
VersionNotFoundException,
)
class GlueBackend(BaseBackend): class GlueBackend(BaseBackend):
@ -19,7 +30,10 @@ class GlueBackend(BaseBackend):
return database return database
def get_database(self, database_name): def get_database(self, database_name):
try:
return self.databases[database_name] return self.databases[database_name]
except KeyError:
raise DatabaseNotFoundException(database_name)
def create_table(self, database_name, table_name, table_input): def create_table(self, database_name, table_name, table_input):
database = self.get_database(database_name) database = self.get_database(database_name)
@ -33,7 +47,10 @@ class GlueBackend(BaseBackend):
def get_table(self, database_name, table_name): def get_table(self, database_name, table_name):
database = self.get_database(database_name) database = self.get_database(database_name)
try:
return database.tables[table_name] return database.tables[table_name]
except KeyError:
raise TableNotFoundException(table_name)
def get_tables(self, database_name): def get_tables(self, database_name):
database = self.get_database(database_name) database = self.get_database(database_name)
@ -52,9 +69,84 @@ class FakeTable(BaseModel):
def __init__(self, database_name, table_name, table_input): def __init__(self, database_name, table_name, table_input):
self.database_name = database_name self.database_name = database_name
self.name = table_name self.name = table_name
self.table_input = table_input self.partitions = OrderedDict()
self.storage_descriptor = self.table_input.get('StorageDescriptor', {}) self.versions = []
self.partition_keys = self.table_input.get('PartitionKeys', []) self.update(table_input)
def update(self, table_input):
self.versions.append(table_input)
def get_version(self, ver):
try:
if not isinstance(ver, int):
# "1" goes to [0]
ver = int(ver) - 1
except ValueError as e:
raise JsonRESTError("InvalidInputException", str(e))
try:
return self.versions[ver]
except IndexError:
raise VersionNotFoundException()
def as_dict(self, version=-1):
obj = {
'DatabaseName': self.database_name,
'Name': self.name,
}
obj.update(self.get_version(version))
return obj
def create_partition(self, partiton_input):
partition = FakePartition(self.database_name, self.name, partiton_input)
key = str(partition.values)
if key in self.partitions:
raise PartitionAlreadyExistsException()
self.partitions[str(partition.values)] = partition
def get_partitions(self):
return [p for str_part_values, p in self.partitions.items()]
def get_partition(self, values):
try:
return self.partitions[str(values)]
except KeyError:
raise PartitionNotFoundException()
def update_partition(self, old_values, partiton_input):
partition = FakePartition(self.database_name, self.name, partiton_input)
key = str(partition.values)
if old_values == partiton_input['Values']:
# Altering a partition in place. Don't remove it so the order of
# returned partitions doesn't change
if key not in self.partitions:
raise PartitionNotFoundException()
else:
removed = self.partitions.pop(str(old_values), None)
if removed is None:
raise PartitionNotFoundException()
if key in self.partitions:
# Trying to update to overwrite a partition that exists
raise PartitionAlreadyExistsException()
self.partitions[key] = partition
class FakePartition(BaseModel):
def __init__(self, database_name, table_name, partiton_input):
self.creation_time = time.time()
self.database_name = database_name
self.table_name = table_name
self.partition_input = partiton_input
self.values = self.partition_input.get('Values', [])
def as_dict(self):
obj = {
'DatabaseName': self.database_name,
'TableName': self.table_name,
'CreationTime': self.creation_time,
}
obj.update(self.partition_input)
return obj
glue_backend = GlueBackend() glue_backend = GlueBackend()

View File

@ -37,27 +37,94 @@ class GlueResponse(BaseResponse):
database_name = self.parameters.get('DatabaseName') database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('Name') table_name = self.parameters.get('Name')
table = self.glue_backend.get_table(database_name, table_name) table = self.glue_backend.get_table(database_name, table_name)
return json.dumps({'Table': table.as_dict()})
def update_table(self):
database_name = self.parameters.get('DatabaseName')
table_input = self.parameters.get('TableInput')
table_name = table_input.get('Name')
table = self.glue_backend.get_table(database_name, table_name)
table.update(table_input)
return ""
def get_table_versions(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
table = self.glue_backend.get_table(database_name, table_name)
return json.dumps({ return json.dumps({
'Table': { "TableVersions": [
'DatabaseName': table.database_name, {
'Name': table.name, "Table": table.as_dict(version=n),
'PartitionKeys': table.partition_keys, "VersionId": str(n + 1),
'StorageDescriptor': table.storage_descriptor } for n in range(len(table.versions))
} ],
})
def get_table_version(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
table = self.glue_backend.get_table(database_name, table_name)
ver_id = self.parameters.get('VersionId')
return json.dumps({
"TableVersion": {
"Table": table.as_dict(version=ver_id),
"VersionId": ver_id,
},
}) })
def get_tables(self): def get_tables(self):
database_name = self.parameters.get('DatabaseName') database_name = self.parameters.get('DatabaseName')
tables = self.glue_backend.get_tables(database_name) tables = self.glue_backend.get_tables(database_name)
return json.dumps( return json.dumps({
{
'TableList': [ 'TableList': [
{ table.as_dict() for table in tables
'DatabaseName': table.database_name,
'Name': table.name,
'PartitionKeys': table.partition_keys,
'StorageDescriptor': table.storage_descriptor
} for table in tables
] ]
} })
)
def get_partitions(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
if 'Expression' in self.parameters:
raise NotImplementedError("Expression filtering in get_partitions is not implemented in moto")
table = self.glue_backend.get_table(database_name, table_name)
return json.dumps({
'Partitions': [
p.as_dict() for p in table.get_partitions()
]
})
def get_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
values = self.parameters.get('PartitionValues')
table = self.glue_backend.get_table(database_name, table_name)
p = table.get_partition(values)
return json.dumps({'Partition': p.as_dict()})
def create_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
part_input = self.parameters.get('PartitionInput')
table = self.glue_backend.get_table(database_name, table_name)
table.create_partition(part_input)
return ""
def update_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
part_input = self.parameters.get('PartitionInput')
part_to_update = self.parameters.get('PartitionValueList')
table = self.glue_backend.get_table(database_name, table_name)
table.update_partition(part_to_update, part_input)
return ""

View File

@ -37,7 +37,6 @@ class Policy(BaseModel):
description=None, description=None,
document=None, document=None,
path=None): path=None):
self.document = document or {}
self.name = name self.name = name
self.attachment_count = 0 self.attachment_count = 0
@ -45,7 +44,7 @@ class Policy(BaseModel):
self.id = random_policy_id() self.id = random_policy_id()
self.path = path or '/' self.path = path or '/'
self.default_version_id = default_version_id or 'v1' self.default_version_id = default_version_id or 'v1'
self.versions = [] self.versions = [PolicyVersion(self.arn, document, True)]
self.create_datetime = datetime.now(pytz.utc) self.create_datetime = datetime.now(pytz.utc)
self.update_datetime = datetime.now(pytz.utc) self.update_datetime = datetime.now(pytz.utc)
@ -72,11 +71,11 @@ class ManagedPolicy(Policy):
def attach_to(self, obj): def attach_to(self, obj):
self.attachment_count += 1 self.attachment_count += 1
obj.managed_policies[self.name] = self obj.managed_policies[self.arn] = self
def detach_from(self, obj): def detach_from(self, obj):
self.attachment_count -= 1 self.attachment_count -= 1
del obj.managed_policies[self.name] del obj.managed_policies[self.arn]
@property @property
def arn(self): def arn(self):
@ -477,11 +476,13 @@ class IAMBackend(BaseBackend):
document=policy_document, document=policy_document,
path=path, path=path,
) )
self.managed_policies[policy.name] = policy self.managed_policies[policy.arn] = policy
return policy return policy
def get_policy(self, policy_name): def get_policy(self, policy_arn):
return self.managed_policies.get(policy_name) if policy_arn not in self.managed_policies:
raise IAMNotFoundException("Policy {0} not found".format(policy_arn))
return self.managed_policies.get(policy_arn)
def list_attached_role_policies(self, role_name, marker=None, max_items=100, path_prefix='/'): def list_attached_role_policies(self, role_name, marker=None, max_items=100, path_prefix='/'):
policies = self.get_role(role_name).managed_policies.values() policies = self.get_role(role_name).managed_policies.values()
@ -575,21 +576,18 @@ class IAMBackend(BaseBackend):
return role.policies.keys() return role.policies.keys()
def create_policy_version(self, policy_arn, policy_document, set_as_default): def create_policy_version(self, policy_arn, policy_document, set_as_default):
policy_name = policy_arn.split(':')[-1] policy = self.get_policy(policy_arn)
policy_name = policy_name.split('/')[1]
policy = self.get_policy(policy_name)
if not policy: if not policy:
raise IAMNotFoundException("Policy not found") raise IAMNotFoundException("Policy not found")
version = PolicyVersion(policy_arn, policy_document, set_as_default) version = PolicyVersion(policy_arn, policy_document, set_as_default)
policy.versions.append(version) policy.versions.append(version)
version.version_id = 'v{0}'.format(len(policy.versions))
if set_as_default: if set_as_default:
policy.default_version_id = version.version_id policy.default_version_id = version.version_id
return version return version
def get_policy_version(self, policy_arn, version_id): def get_policy_version(self, policy_arn, version_id):
policy_name = policy_arn.split(':')[-1] policy = self.get_policy(policy_arn)
policy_name = policy_name.split('/')[1]
policy = self.get_policy(policy_name)
if not policy: if not policy:
raise IAMNotFoundException("Policy not found") raise IAMNotFoundException("Policy not found")
for version in policy.versions: for version in policy.versions:
@ -598,19 +596,18 @@ class IAMBackend(BaseBackend):
raise IAMNotFoundException("Policy version not found") raise IAMNotFoundException("Policy version not found")
def list_policy_versions(self, policy_arn): def list_policy_versions(self, policy_arn):
policy_name = policy_arn.split(':')[-1] policy = self.get_policy(policy_arn)
policy_name = policy_name.split('/')[1]
policy = self.get_policy(policy_name)
if not policy: if not policy:
raise IAMNotFoundException("Policy not found") raise IAMNotFoundException("Policy not found")
return policy.versions return policy.versions
def delete_policy_version(self, policy_arn, version_id): def delete_policy_version(self, policy_arn, version_id):
policy_name = policy_arn.split(':')[-1] policy = self.get_policy(policy_arn)
policy_name = policy_name.split('/')[1]
policy = self.get_policy(policy_name)
if not policy: if not policy:
raise IAMNotFoundException("Policy not found") raise IAMNotFoundException("Policy not found")
if version_id == policy.default_version_id:
raise IAMConflictException(
"Cannot delete the default version of a policy")
for i, v in enumerate(policy.versions): for i, v in enumerate(policy.versions):
if v.version_id == version_id: if v.version_id == version_id:
del policy.versions[i] del policy.versions[i]

View File

@ -58,6 +58,12 @@ class IamResponse(BaseResponse):
template = self.response_template(CREATE_POLICY_TEMPLATE) template = self.response_template(CREATE_POLICY_TEMPLATE)
return template.render(policy=policy) return template.render(policy=policy)
def get_policy(self):
policy_arn = self._get_param('PolicyArn')
policy = iam_backend.get_policy(policy_arn)
template = self.response_template(GET_POLICY_TEMPLATE)
return template.render(policy=policy)
def list_attached_role_policies(self): def list_attached_role_policies(self):
marker = self._get_param('Marker') marker = self._get_param('Marker')
max_items = self._get_int_param('MaxItems', 100) max_items = self._get_int_param('MaxItems', 100)
@ -601,6 +607,25 @@ CREATE_POLICY_TEMPLATE = """<CreatePolicyResponse>
</ResponseMetadata> </ResponseMetadata>
</CreatePolicyResponse>""" </CreatePolicyResponse>"""
GET_POLICY_TEMPLATE = """<GetPolicyResponse>
<GetPolicyResult>
<Policy>
<PolicyName>{{ policy.name }}</PolicyName>
<Description>{{ policy.description }}</Description>
<DefaultVersionId>{{ policy.default_version_id }}</DefaultVersionId>
<PolicyId>{{ policy.id }}</PolicyId>
<Path>{{ policy.path }}</Path>
<Arn>{{ policy.arn }}</Arn>
<AttachmentCount>{{ policy.attachment_count }}</AttachmentCount>
<CreateDate>{{ policy.create_datetime.isoformat() }}</CreateDate>
<UpdateDate>{{ policy.update_datetime.isoformat() }}</UpdateDate>
</Policy>
</GetPolicyResult>
<ResponseMetadata>
<RequestId>684f0917-3d22-11e4-a4a0-cffb9EXAMPLE</RequestId>
</ResponseMetadata>
</GetPolicyResponse>"""
LIST_ATTACHED_ROLE_POLICIES_TEMPLATE = """<ListAttachedRolePoliciesResponse> LIST_ATTACHED_ROLE_POLICIES_TEMPLATE = """<ListAttachedRolePoliciesResponse>
<ListAttachedRolePoliciesResult> <ListAttachedRolePoliciesResult>
{% if marker is none %} {% if marker is none %}

View File

@ -2,8 +2,10 @@ from __future__ import unicode_literals
import boto.kms import boto.kms
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import generate_key_id from .utils import generate_key_id
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta
class Key(BaseModel): class Key(BaseModel):
@ -12,11 +14,13 @@ class Key(BaseModel):
self.id = generate_key_id() self.id = generate_key_id()
self.policy = policy self.policy = policy
self.key_usage = key_usage self.key_usage = key_usage
self.key_state = "Enabled"
self.description = description self.description = description
self.enabled = True self.enabled = True
self.region = region self.region = region
self.account_id = "0123456789012" self.account_id = "0123456789012"
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None
@property @property
def physical_resource_id(self): def physical_resource_id(self):
@ -27,7 +31,7 @@ class Key(BaseModel):
return "arn:aws:kms:{0}:{1}:key/{2}".format(self.region, self.account_id, self.id) return "arn:aws:kms:{0}:{1}:key/{2}".format(self.region, self.account_id, self.id)
def to_dict(self): def to_dict(self):
return { key_dict = {
"KeyMetadata": { "KeyMetadata": {
"AWSAccountId": self.account_id, "AWSAccountId": self.account_id,
"Arn": self.arn, "Arn": self.arn,
@ -36,8 +40,12 @@ class Key(BaseModel):
"Enabled": self.enabled, "Enabled": self.enabled,
"KeyId": self.id, "KeyId": self.id,
"KeyUsage": self.key_usage, "KeyUsage": self.key_usage,
"KeyState": self.key_state,
} }
} }
if self.key_state == 'PendingDeletion':
key_dict['KeyMetadata']['DeletionDate'] = iso_8601_datetime_without_milliseconds(self.deletion_date)
return key_dict
def delete(self, region_name): def delete(self, region_name):
kms_backends[region_name].delete_key(self.id) kms_backends[region_name].delete_key(self.id)
@ -138,6 +146,29 @@ class KmsBackend(BaseBackend):
def get_key_policy(self, key_id): def get_key_policy(self, key_id):
return self.keys[self.get_key_id(key_id)].policy return self.keys[self.get_key_id(key_id)].policy
def disable_key(self, key_id):
if key_id in self.keys:
self.keys[key_id].enabled = False
self.keys[key_id].key_state = 'Disabled'
def enable_key(self, key_id):
if key_id in self.keys:
self.keys[key_id].enabled = True
self.keys[key_id].key_state = 'Enabled'
def cancel_key_deletion(self, key_id):
if key_id in self.keys:
self.keys[key_id].key_state = 'Disabled'
self.keys[key_id].deletion_date = None
def schedule_key_deletion(self, key_id, pending_window_in_days):
if key_id in self.keys:
if 7 <= pending_window_in_days <= 30:
self.keys[key_id].enabled = False
self.keys[key_id].key_state = 'PendingDeletion'
self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days)
return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date)
kms_backends = {} kms_backends = {}
for region in boto.kms.regions(): for region in boto.kms.regions():

View File

@ -233,6 +233,56 @@ class KmsResponse(BaseResponse):
value = self.parameters.get("CiphertextBlob") value = self.parameters.get("CiphertextBlob")
return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8")}) return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8")})
def disable_key(self):
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self.kms_backend.disable_key(key_id)
except KeyError:
raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'})
return json.dumps(None)
def enable_key(self):
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self.kms_backend.enable_key(key_id)
except KeyError:
raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'})
return json.dumps(None)
def cancel_key_deletion(self):
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self.kms_backend.cancel_key_deletion(key_id)
except KeyError:
raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'})
return json.dumps({'KeyId': key_id})
def schedule_key_deletion(self):
key_id = self.parameters.get('KeyId')
if self.parameters.get('PendingWindowInDays') is None:
pending_window_in_days = 30
else:
pending_window_in_days = self.parameters.get('PendingWindowInDays')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
return json.dumps({
'KeyId': key_id,
'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days)
})
except KeyError:
raise JSONResponseError(404, 'Not Found', body={
'message': "Key 'arn:aws:kms:{region}:012345678912:key/{key_id}' does not exist".format(region=self.region, key_id=key_id),
'__type': 'NotFoundException'})
def _assert_valid_key_id(key_id): def _assert_valid_key_id(key_id):
if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE): if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE):

View File

@ -29,5 +29,5 @@ class ResourceAlreadyExistsException(LogsClientError):
self.code = 400 self.code = 400
super(ResourceAlreadyExistsException, self).__init__( super(ResourceAlreadyExistsException, self).__init__(
'ResourceAlreadyExistsException', 'ResourceAlreadyExistsException',
'The specified resource already exists.' 'The specified log group already exists'
) )

View File

@ -19,7 +19,7 @@ class LogEvent:
def to_filter_dict(self): def to_filter_dict(self):
return { return {
"eventId": self.eventId, "eventId": str(self.eventId),
"ingestionTime": self.ingestionTime, "ingestionTime": self.ingestionTime,
# "logStreamName": # "logStreamName":
"message": self.message, "message": self.message,
@ -86,7 +86,7 @@ class LogStream:
self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events] self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events]
self.uploadSequenceToken += 1 self.uploadSequenceToken += 1
return self.uploadSequenceToken return '{:056d}'.format(self.uploadSequenceToken)
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head): def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
def filter_func(event): def filter_func(event):

View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import organizations_backend
from ..core.models import base_decorator
organizations_backends = {"global": organizations_backend}
mock_organizations = base_decorator(organizations_backends)

View File

@ -0,0 +1,296 @@
from __future__ import unicode_literals
import datetime
import re
from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError
from moto.core.utils import unix_time
from moto.organizations import utils
class FakeOrganization(BaseModel):
def __init__(self, feature_set):
self.id = utils.make_random_org_id()
self.root_id = utils.make_random_root_id()
self.feature_set = feature_set
self.master_account_id = utils.MASTER_ACCOUNT_ID
self.master_account_email = utils.MASTER_ACCOUNT_EMAIL
self.available_policy_types = [{
'Type': 'SERVICE_CONTROL_POLICY',
'Status': 'ENABLED'
}]
@property
def arn(self):
return utils.ORGANIZATION_ARN_FORMAT.format(self.master_account_id, self.id)
@property
def master_account_arn(self):
return utils.MASTER_ACCOUNT_ARN_FORMAT.format(self.master_account_id, self.id)
def describe(self):
return {
'Organization': {
'Id': self.id,
'Arn': self.arn,
'FeatureSet': self.feature_set,
'MasterAccountArn': self.master_account_arn,
'MasterAccountId': self.master_account_id,
'MasterAccountEmail': self.master_account_email,
'AvailablePolicyTypes': self.available_policy_types,
}
}
class FakeAccount(BaseModel):
def __init__(self, organization, **kwargs):
self.organization_id = organization.id
self.master_account_id = organization.master_account_id
self.create_account_status_id = utils.make_random_create_account_status_id()
self.id = utils.make_random_account_id()
self.name = kwargs['AccountName']
self.email = kwargs['Email']
self.create_time = datetime.datetime.utcnow()
self.status = 'ACTIVE'
self.joined_method = 'CREATED'
self.parent_id = organization.root_id
@property
def arn(self):
return utils.ACCOUNT_ARN_FORMAT.format(
self.master_account_id,
self.organization_id,
self.id
)
@property
def create_account_status(self):
return {
'CreateAccountStatus': {
'Id': self.create_account_status_id,
'AccountName': self.name,
'State': 'SUCCEEDED',
'RequestedTimestamp': unix_time(self.create_time),
'CompletedTimestamp': unix_time(self.create_time),
'AccountId': self.id,
}
}
def describe(self):
return {
'Account': {
'Id': self.id,
'Arn': self.arn,
'Email': self.email,
'Name': self.name,
'Status': self.status,
'JoinedMethod': self.joined_method,
'JoinedTimestamp': unix_time(self.create_time),
}
}
class FakeOrganizationalUnit(BaseModel):
def __init__(self, organization, **kwargs):
self.type = 'ORGANIZATIONAL_UNIT'
self.organization_id = organization.id
self.master_account_id = organization.master_account_id
self.id = utils.make_random_ou_id(organization.root_id)
self.name = kwargs.get('Name')
self.parent_id = kwargs.get('ParentId')
self._arn_format = utils.OU_ARN_FORMAT
@property
def arn(self):
return self._arn_format.format(
self.master_account_id,
self.organization_id,
self.id
)
def describe(self):
return {
'OrganizationalUnit': {
'Id': self.id,
'Arn': self.arn,
'Name': self.name,
}
}
class FakeRoot(FakeOrganizationalUnit):
def __init__(self, organization, **kwargs):
super(FakeRoot, self).__init__(organization, **kwargs)
self.type = 'ROOT'
self.id = organization.root_id
self.name = 'Root'
self.policy_types = [{
'Type': 'SERVICE_CONTROL_POLICY',
'Status': 'ENABLED'
}]
self._arn_format = utils.ROOT_ARN_FORMAT
def describe(self):
return {
'Id': self.id,
'Arn': self.arn,
'Name': self.name,
'PolicyTypes': self.policy_types
}
class OrganizationsBackend(BaseBackend):
def __init__(self):
self.org = None
self.accounts = []
self.ou = []
def create_organization(self, **kwargs):
self.org = FakeOrganization(kwargs['FeatureSet'])
self.ou.append(FakeRoot(self.org))
return self.org.describe()
def describe_organization(self):
if not self.org:
raise RESTError(
'AWSOrganizationsNotInUseException',
"Your account is not a member of an organization."
)
return self.org.describe()
def list_roots(self):
return dict(
Roots=[ou.describe() for ou in self.ou if isinstance(ou, FakeRoot)]
)
def create_organizational_unit(self, **kwargs):
new_ou = FakeOrganizationalUnit(self.org, **kwargs)
self.ou.append(new_ou)
return new_ou.describe()
def get_organizational_unit_by_id(self, ou_id):
ou = next((ou for ou in self.ou if ou.id == ou_id), None)
if ou is None:
raise RESTError(
'OrganizationalUnitNotFoundException',
"You specified an organizational unit that doesn't exist."
)
return ou
def validate_parent_id(self, parent_id):
try:
self.get_organizational_unit_by_id(parent_id)
except RESTError:
raise RESTError(
'ParentNotFoundException',
"You specified parent that doesn't exist."
)
return parent_id
def describe_organizational_unit(self, **kwargs):
ou = self.get_organizational_unit_by_id(kwargs['OrganizationalUnitId'])
return ou.describe()
def list_organizational_units_for_parent(self, **kwargs):
parent_id = self.validate_parent_id(kwargs['ParentId'])
return dict(
OrganizationalUnits=[
{
'Id': ou.id,
'Arn': ou.arn,
'Name': ou.name,
}
for ou in self.ou
if ou.parent_id == parent_id
]
)
def create_account(self, **kwargs):
new_account = FakeAccount(self.org, **kwargs)
self.accounts.append(new_account)
return new_account.create_account_status
def get_account_by_id(self, account_id):
account = next((
account for account in self.accounts
if account.id == account_id
), None)
if account is None:
raise RESTError(
'AccountNotFoundException',
"You specified an account that doesn't exist."
)
return account
def describe_account(self, **kwargs):
account = self.get_account_by_id(kwargs['AccountId'])
return account.describe()
def list_accounts(self):
return dict(
Accounts=[account.describe()['Account'] for account in self.accounts]
)
def list_accounts_for_parent(self, **kwargs):
parent_id = self.validate_parent_id(kwargs['ParentId'])
return dict(
Accounts=[
account.describe()['Account']
for account in self.accounts
if account.parent_id == parent_id
]
)
def move_account(self, **kwargs):
new_parent_id = self.validate_parent_id(kwargs['DestinationParentId'])
self.validate_parent_id(kwargs['SourceParentId'])
account = self.get_account_by_id(kwargs['AccountId'])
index = self.accounts.index(account)
self.accounts[index].parent_id = new_parent_id
def list_parents(self, **kwargs):
if re.compile(r'[0-9]{12}').match(kwargs['ChildId']):
child_object = self.get_account_by_id(kwargs['ChildId'])
else:
child_object = self.get_organizational_unit_by_id(kwargs['ChildId'])
return dict(
Parents=[
{
'Id': ou.id,
'Type': ou.type,
}
for ou in self.ou
if ou.id == child_object.parent_id
]
)
def list_children(self, **kwargs):
parent_id = self.validate_parent_id(kwargs['ParentId'])
if kwargs['ChildType'] == 'ACCOUNT':
obj_list = self.accounts
elif kwargs['ChildType'] == 'ORGANIZATIONAL_UNIT':
obj_list = self.ou
else:
raise RESTError(
'InvalidInputException',
'You specified an invalid value.'
)
return dict(
Children=[
{
'Id': obj.id,
'Type': kwargs['ChildType'],
}
for obj in obj_list
if obj.parent_id == parent_id
]
)
organizations_backend = OrganizationsBackend()

View File

@ -0,0 +1,87 @@
from __future__ import unicode_literals
import json
from moto.core.responses import BaseResponse
from .models import organizations_backend
class OrganizationsResponse(BaseResponse):
@property
def organizations_backend(self):
return organizations_backend
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param, default=None):
return self.request_params.get(param, default)
def create_organization(self):
return json.dumps(
self.organizations_backend.create_organization(**self.request_params)
)
def describe_organization(self):
return json.dumps(
self.organizations_backend.describe_organization()
)
def list_roots(self):
return json.dumps(
self.organizations_backend.list_roots()
)
def create_organizational_unit(self):
return json.dumps(
self.organizations_backend.create_organizational_unit(**self.request_params)
)
def describe_organizational_unit(self):
return json.dumps(
self.organizations_backend.describe_organizational_unit(**self.request_params)
)
def list_organizational_units_for_parent(self):
return json.dumps(
self.organizations_backend.list_organizational_units_for_parent(**self.request_params)
)
def list_parents(self):
return json.dumps(
self.organizations_backend.list_parents(**self.request_params)
)
def create_account(self):
return json.dumps(
self.organizations_backend.create_account(**self.request_params)
)
def describe_account(self):
return json.dumps(
self.organizations_backend.describe_account(**self.request_params)
)
def list_accounts(self):
return json.dumps(
self.organizations_backend.list_accounts()
)
def list_accounts_for_parent(self):
return json.dumps(
self.organizations_backend.list_accounts_for_parent(**self.request_params)
)
def move_account(self):
return json.dumps(
self.organizations_backend.move_account(**self.request_params)
)
def list_children(self):
return json.dumps(
self.organizations_backend.list_children(**self.request_params)
)

View File

@ -0,0 +1,10 @@
from __future__ import unicode_literals
from .responses import OrganizationsResponse
url_bases = [
"https?://organizations.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': OrganizationsResponse.dispatch,
}

View File

@ -0,0 +1,59 @@
from __future__ import unicode_literals
import random
import string
MASTER_ACCOUNT_ID = '123456789012'
MASTER_ACCOUNT_EMAIL = 'fakeorg@moto-example.com'
ORGANIZATION_ARN_FORMAT = 'arn:aws:organizations::{0}:organization/{1}'
MASTER_ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{0}'
ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{2}'
ROOT_ARN_FORMAT = 'arn:aws:organizations::{0}:root/{1}/{2}'
OU_ARN_FORMAT = 'arn:aws:organizations::{0}:ou/{1}/{2}'
CHARSET = string.ascii_lowercase + string.digits
ORG_ID_SIZE = 10
ROOT_ID_SIZE = 4
ACCOUNT_ID_SIZE = 12
OU_ID_SUFFIX_SIZE = 8
CREATE_ACCOUNT_STATUS_ID_SIZE = 8
def make_random_org_id():
# The regex pattern for an organization ID string requires "o-"
# followed by from 10 to 32 lower-case letters or digits.
# e.g. 'o-vipjnq5z86'
return 'o-' + ''.join(random.choice(CHARSET) for x in range(ORG_ID_SIZE))
def make_random_root_id():
# The regex pattern for a root ID string requires "r-" followed by
# from 4 to 32 lower-case letters or digits.
# e.g. 'r-3zwx'
return 'r-' + ''.join(random.choice(CHARSET) for x in range(ROOT_ID_SIZE))
def make_random_ou_id(root_id):
# The regex pattern for an organizational unit ID string requires "ou-"
# followed by from 4 to 32 lower-case letters or digits (the ID of the root
# that contains the OU) followed by a second "-" dash and from 8 to 32
# additional lower-case letters or digits.
# e.g. ou-g8sd-5oe3bjaw
return '-'.join([
'ou',
root_id.partition('-')[2],
''.join(random.choice(CHARSET) for x in range(OU_ID_SUFFIX_SIZE)),
])
def make_random_account_id():
# The regex pattern for an account ID string requires exactly 12 digits.
# e.g. '488633172133'
return ''.join([random.choice(string.digits) for n in range(ACCOUNT_ID_SIZE)])
def make_random_create_account_status_id():
# The regex pattern for an create account request ID string requires
# "car-" followed by from 8 to 32 lower-case letters or digits.
# e.g. 'car-35gxzwrp'
return 'car-' + ''.join(random.choice(CHARSET) for x in range(CREATE_ACCOUNT_STATUS_ID_SIZE))

View File

@ -85,6 +85,7 @@ old_socksocket = None
old_ssl_wrap_socket = None old_ssl_wrap_socket = None
old_sslwrap_simple = None old_sslwrap_simple = None
old_sslsocket = None old_sslsocket = None
old_sslcontext_wrap_socket = None
if PY3: # pragma: no cover if PY3: # pragma: no cover
basestring = (bytes, str) basestring = (bytes, str)
@ -100,6 +101,10 @@ try: # pragma: no cover
if not PY3: if not PY3:
old_sslwrap_simple = ssl.sslwrap_simple old_sslwrap_simple = ssl.sslwrap_simple
old_sslsocket = ssl.SSLSocket old_sslsocket = ssl.SSLSocket
try:
old_sslcontext_wrap_socket = ssl.SSLContext.wrap_socket
except AttributeError:
pass
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
ssl = None ssl = None
@ -281,7 +286,7 @@ class fakesock(object):
return { return {
'notAfter': shift.strftime('%b %d %H:%M:%S GMT'), 'notAfter': shift.strftime('%b %d %H:%M:%S GMT'),
'subjectAltName': ( 'subjectAltName': (
('DNS', '*%s' % self._host), ('DNS', '*.%s' % self._host),
('DNS', self._host), ('DNS', self._host),
('DNS', '*'), ('DNS', '*'),
), ),
@ -772,7 +777,7 @@ class URIMatcher(object):
def __init__(self, uri, entries, match_querystring=False): def __init__(self, uri, entries, match_querystring=False):
self._match_querystring = match_querystring self._match_querystring = match_querystring
if type(uri).__name__ == 'SRE_Pattern': if type(uri).__name__ in ('SRE_Pattern', 'Pattern'):
self.regex = uri self.regex = uri
result = urlsplit(uri.pattern) result = urlsplit(uri.pattern)
if result.scheme == 'https': if result.scheme == 'https':
@ -1012,6 +1017,10 @@ class httpretty(HttpBaseClass):
if ssl: if ssl:
ssl.wrap_socket = old_ssl_wrap_socket ssl.wrap_socket = old_ssl_wrap_socket
ssl.SSLSocket = old_sslsocket ssl.SSLSocket = old_sslsocket
try:
ssl.SSLContext.wrap_socket = old_sslcontext_wrap_socket
except AttributeError:
pass
ssl.__dict__['wrap_socket'] = old_ssl_wrap_socket ssl.__dict__['wrap_socket'] = old_ssl_wrap_socket
ssl.__dict__['SSLSocket'] = old_sslsocket ssl.__dict__['SSLSocket'] = old_sslsocket
@ -1058,6 +1067,14 @@ class httpretty(HttpBaseClass):
ssl.wrap_socket = fake_wrap_socket ssl.wrap_socket = fake_wrap_socket
ssl.SSLSocket = FakeSSLSocket ssl.SSLSocket = FakeSSLSocket
try:
def fake_sslcontext_wrap_socket(cls, *args, **kwargs):
return fake_wrap_socket(*args, **kwargs)
ssl.SSLContext.wrap_socket = fake_sslcontext_wrap_socket
except AttributeError:
pass
ssl.__dict__['wrap_socket'] = fake_wrap_socket ssl.__dict__['wrap_socket'] = fake_wrap_socket
ssl.__dict__['SSLSocket'] = FakeSSLSocket ssl.__dict__['SSLSocket'] = FakeSSLSocket

View File

@ -48,6 +48,10 @@ class Database(BaseModel):
if self.publicly_accessible is None: if self.publicly_accessible is None:
self.publicly_accessible = True self.publicly_accessible = True
self.copy_tags_to_snapshot = kwargs.get("copy_tags_to_snapshot")
if self.copy_tags_to_snapshot is None:
self.copy_tags_to_snapshot = False
self.backup_retention_period = kwargs.get("backup_retention_period") self.backup_retention_period = kwargs.get("backup_retention_period")
if self.backup_retention_period is None: if self.backup_retention_period is None:
self.backup_retention_period = 1 self.backup_retention_period = 1
@ -137,6 +141,7 @@ class Database(BaseModel):
"multi_az": properties.get("MultiAZ"), "multi_az": properties.get("MultiAZ"),
"port": properties.get('Port', 3306), "port": properties.get('Port', 3306),
"publicly_accessible": properties.get("PubliclyAccessible"), "publicly_accessible": properties.get("PubliclyAccessible"),
"copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"),
"region": region_name, "region": region_name,
"security_groups": security_groups, "security_groups": security_groups,
"storage_encrypted": properties.get("StorageEncrypted"), "storage_encrypted": properties.get("StorageEncrypted"),
@ -217,6 +222,7 @@ class Database(BaseModel):
</DBSubnetGroup> </DBSubnetGroup>
{% endif %} {% endif %}
<PubliclyAccessible>{{ database.publicly_accessible }}</PubliclyAccessible> <PubliclyAccessible>{{ database.publicly_accessible }}</PubliclyAccessible>
<CopyTagsToSnapshot>{{ database.copy_tags_to_snapshot }}</CopyTagsToSnapshot>
<AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade> <AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade>
<AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage> <AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage>
<StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted> <StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted>

View File

@ -73,6 +73,9 @@ class Database(BaseModel):
self.publicly_accessible = kwargs.get("publicly_accessible") self.publicly_accessible = kwargs.get("publicly_accessible")
if self.publicly_accessible is None: if self.publicly_accessible is None:
self.publicly_accessible = True self.publicly_accessible = True
self.copy_tags_to_snapshot = kwargs.get("copy_tags_to_snapshot")
if self.copy_tags_to_snapshot is None:
self.copy_tags_to_snapshot = False
self.backup_retention_period = kwargs.get("backup_retention_period") self.backup_retention_period = kwargs.get("backup_retention_period")
if self.backup_retention_period is None: if self.backup_retention_period is None:
self.backup_retention_period = 1 self.backup_retention_period = 1
@ -208,6 +211,7 @@ class Database(BaseModel):
</DBSubnetGroup> </DBSubnetGroup>
{% endif %} {% endif %}
<PubliclyAccessible>{{ database.publicly_accessible }}</PubliclyAccessible> <PubliclyAccessible>{{ database.publicly_accessible }}</PubliclyAccessible>
<CopyTagsToSnapshot>{{ database.copy_tags_to_snapshot }}</CopyTagsToSnapshot>
<AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade> <AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade>
<AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage> <AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage>
<StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted> <StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted>
@ -304,6 +308,7 @@ class Database(BaseModel):
"db_parameter_group_name": properties.get('DBParameterGroupName'), "db_parameter_group_name": properties.get('DBParameterGroupName'),
"port": properties.get('Port', 3306), "port": properties.get('Port', 3306),
"publicly_accessible": properties.get("PubliclyAccessible"), "publicly_accessible": properties.get("PubliclyAccessible"),
"copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"),
"region": region_name, "region": region_name,
"security_groups": security_groups, "security_groups": security_groups,
"storage_encrypted": properties.get("StorageEncrypted"), "storage_encrypted": properties.get("StorageEncrypted"),
@ -362,6 +367,7 @@ class Database(BaseModel):
"PreferredBackupWindow": "{{ database.preferred_backup_window }}", "PreferredBackupWindow": "{{ database.preferred_backup_window }}",
"PreferredMaintenanceWindow": "{{ database.preferred_maintenance_window }}", "PreferredMaintenanceWindow": "{{ database.preferred_maintenance_window }}",
"PubliclyAccessible": "{{ database.publicly_accessible }}", "PubliclyAccessible": "{{ database.publicly_accessible }}",
"CopyTagsToSnapshot": "{{ database.copy_tags_to_snapshot }}",
"AllocatedStorage": "{{ database.allocated_storage }}", "AllocatedStorage": "{{ database.allocated_storage }}",
"Endpoint": { "Endpoint": {
"Address": "{{ database.address }}", "Address": "{{ database.address }}",
@ -411,10 +417,10 @@ class Database(BaseModel):
class Snapshot(BaseModel): class Snapshot(BaseModel):
def __init__(self, database, snapshot_id, tags=None): def __init__(self, database, snapshot_id, tags):
self.database = database self.database = database
self.snapshot_id = snapshot_id self.snapshot_id = snapshot_id
self.tags = tags or [] self.tags = tags
self.created_at = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) self.created_at = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
@property @property
@ -456,6 +462,20 @@ class Snapshot(BaseModel):
</DBSnapshot>""") </DBSnapshot>""")
return template.render(snapshot=self, database=self.database) return template.render(snapshot=self, database=self.database)
def get_tags(self):
return self.tags
def add_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set[
'Key'] not in new_keys]
self.tags.extend(tags)
return self.tags
def remove_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags if tag_set[
'Key'] not in tag_keys]
class SecurityGroup(BaseModel): class SecurityGroup(BaseModel):
@ -691,6 +711,10 @@ class RDS2Backend(BaseBackend):
raise DBSnapshotAlreadyExistsError(db_snapshot_identifier) raise DBSnapshotAlreadyExistsError(db_snapshot_identifier)
if len(self.snapshots) >= int(os.environ.get('MOTO_RDS_SNAPSHOT_LIMIT', '100')): if len(self.snapshots) >= int(os.environ.get('MOTO_RDS_SNAPSHOT_LIMIT', '100')):
raise SnapshotQuotaExceededError() raise SnapshotQuotaExceededError()
if tags is None:
tags = list()
if database.copy_tags_to_snapshot and not tags:
tags = database.get_tags()
snapshot = Snapshot(database, db_snapshot_identifier, tags) snapshot = Snapshot(database, db_snapshot_identifier, tags)
self.snapshots[db_snapshot_identifier] = snapshot self.snapshots[db_snapshot_identifier] = snapshot
return snapshot return snapshot
@ -787,13 +811,13 @@ class RDS2Backend(BaseBackend):
def delete_database(self, db_instance_identifier, db_snapshot_name=None): def delete_database(self, db_instance_identifier, db_snapshot_name=None):
if db_instance_identifier in self.databases: if db_instance_identifier in self.databases:
if db_snapshot_name:
self.create_snapshot(db_instance_identifier, db_snapshot_name)
database = self.databases.pop(db_instance_identifier) database = self.databases.pop(db_instance_identifier)
if database.is_replica: if database.is_replica:
primary = self.find_db_from_id(database.source_db_identifier) primary = self.find_db_from_id(database.source_db_identifier)
primary.remove_replica(database) primary.remove_replica(database)
database.status = 'deleting' database.status = 'deleting'
if db_snapshot_name:
self.snapshots[db_snapshot_name] = Snapshot(database, db_snapshot_name)
return database return database
else: else:
raise DBInstanceNotFoundError(db_instance_identifier) raise DBInstanceNotFoundError(db_instance_identifier)
@ -1028,8 +1052,8 @@ class RDS2Backend(BaseBackend):
if resource_name in self.security_groups: if resource_name in self.security_groups:
return self.security_groups[resource_name].get_tags() return self.security_groups[resource_name].get_tags()
elif resource_type == 'snapshot': # DB Snapshot elif resource_type == 'snapshot': # DB Snapshot
# TODO: Complete call to tags on resource type DB Snapshot if resource_name in self.snapshots:
return [] return self.snapshots[resource_name].get_tags()
elif resource_type == 'subgrp': # DB subnet group elif resource_type == 'subgrp': # DB subnet group
if resource_name in self.subnet_groups: if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].get_tags() return self.subnet_groups[resource_name].get_tags()
@ -1059,7 +1083,8 @@ class RDS2Backend(BaseBackend):
if resource_name in self.security_groups: if resource_name in self.security_groups:
return self.security_groups[resource_name].remove_tags(tag_keys) return self.security_groups[resource_name].remove_tags(tag_keys)
elif resource_type == 'snapshot': # DB Snapshot elif resource_type == 'snapshot': # DB Snapshot
return None if resource_name in self.snapshots:
return self.snapshots[resource_name].remove_tags(tag_keys)
elif resource_type == 'subgrp': # DB subnet group elif resource_type == 'subgrp': # DB subnet group
if resource_name in self.subnet_groups: if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].remove_tags(tag_keys) return self.subnet_groups[resource_name].remove_tags(tag_keys)
@ -1088,7 +1113,8 @@ class RDS2Backend(BaseBackend):
if resource_name in self.security_groups: if resource_name in self.security_groups:
return self.security_groups[resource_name].add_tags(tags) return self.security_groups[resource_name].add_tags(tags)
elif resource_type == 'snapshot': # DB Snapshot elif resource_type == 'snapshot': # DB Snapshot
return [] if resource_name in self.snapshots:
return self.snapshots[resource_name].add_tags(tags)
elif resource_type == 'subgrp': # DB subnet group elif resource_type == 'subgrp': # DB subnet group
if resource_name in self.subnet_groups: if resource_name in self.subnet_groups:
return self.subnet_groups[resource_name].add_tags(tags) return self.subnet_groups[resource_name].add_tags(tags)

View File

@ -19,6 +19,7 @@ class RDS2Response(BaseResponse):
"allocated_storage": self._get_int_param('AllocatedStorage'), "allocated_storage": self._get_int_param('AllocatedStorage'),
"availability_zone": self._get_param("AvailabilityZone"), "availability_zone": self._get_param("AvailabilityZone"),
"backup_retention_period": self._get_param("BackupRetentionPeriod"), "backup_retention_period": self._get_param("BackupRetentionPeriod"),
"copy_tags_to_snapshot": self._get_param("CopyTagsToSnapshot"),
"db_instance_class": self._get_param('DBInstanceClass'), "db_instance_class": self._get_param('DBInstanceClass'),
"db_instance_identifier": self._get_param('DBInstanceIdentifier'), "db_instance_identifier": self._get_param('DBInstanceIdentifier'),
"db_name": self._get_param("DBName"), "db_name": self._get_param("DBName"),
@ -159,7 +160,7 @@ class RDS2Response(BaseResponse):
def create_db_snapshot(self): def create_db_snapshot(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier') db_instance_identifier = self._get_param('DBInstanceIdentifier')
db_snapshot_identifier = self._get_param('DBSnapshotIdentifier') db_snapshot_identifier = self._get_param('DBSnapshotIdentifier')
tags = self._get_param('Tags', []) tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
snapshot = self.backend.create_snapshot(db_instance_identifier, db_snapshot_identifier, tags) snapshot = self.backend.create_snapshot(db_instance_identifier, db_snapshot_identifier, tags)
template = self.response_template(CREATE_SNAPSHOT_TEMPLATE) template = self.response_template(CREATE_SNAPSHOT_TEMPLATE)
return template.render(snapshot=snapshot) return template.render(snapshot=snapshot)

View File

@ -78,6 +78,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
super(Cluster, self).__init__(region_name, tags) super(Cluster, self).__init__(region_name, tags)
self.redshift_backend = redshift_backend self.redshift_backend = redshift_backend
self.cluster_identifier = cluster_identifier self.cluster_identifier = cluster_identifier
self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
self.status = 'available' self.status = 'available'
self.node_type = node_type self.node_type = node_type
self.master_username = master_username self.master_username = master_username
@ -237,6 +238,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
"Address": self.endpoint, "Address": self.endpoint,
"Port": self.port "Port": self.port
}, },
'ClusterCreateTime': self.create_time,
"PendingModifiedValues": [], "PendingModifiedValues": [],
"Tags": self.tags, "Tags": self.tags,
"IamRoles": [{ "IamRoles": [{

View File

@ -27,8 +27,14 @@ class FakeDeleteMarker(BaseModel):
def __init__(self, key): def __init__(self, key):
self.key = key self.key = key
self.name = key.name
self.last_modified = datetime.datetime.utcnow()
self._version_id = key.version_id + 1 self._version_id = key.version_id + 1
@property
def last_modified_ISO8601(self):
return iso_8601_datetime_with_milliseconds(self.last_modified)
@property @property
def version_id(self): def version_id(self):
return self._version_id return self._version_id
@ -335,8 +341,9 @@ class LifecycleAndFilter(BaseModel):
class LifecycleRule(BaseModel): class LifecycleRule(BaseModel):
def __init__(self, id=None, prefix=None, lc_filter=None, status=None, expiration_days=None, def __init__(self, id=None, prefix=None, lc_filter=None, status=None, expiration_days=None,
expiration_date=None, transition_days=None, expired_object_delete_marker=None, expiration_date=None, transition_days=None, transition_date=None, storage_class=None,
transition_date=None, storage_class=None): expired_object_delete_marker=None, nve_noncurrent_days=None, nvt_noncurrent_days=None,
nvt_storage_class=None, aimu_days=None):
self.id = id self.id = id
self.prefix = prefix self.prefix = prefix
self.filter = lc_filter self.filter = lc_filter
@ -345,8 +352,12 @@ class LifecycleRule(BaseModel):
self.expiration_date = expiration_date self.expiration_date = expiration_date
self.transition_days = transition_days self.transition_days = transition_days
self.transition_date = transition_date self.transition_date = transition_date
self.expired_object_delete_marker = expired_object_delete_marker
self.storage_class = storage_class self.storage_class = storage_class
self.expired_object_delete_marker = expired_object_delete_marker
self.nve_noncurrent_days = nve_noncurrent_days
self.nvt_noncurrent_days = nvt_noncurrent_days
self.nvt_storage_class = nvt_storage_class
self.aimu_days = aimu_days
class CorsRule(BaseModel): class CorsRule(BaseModel):
@ -408,9 +419,32 @@ class FakeBucket(BaseModel):
def set_lifecycle(self, rules): def set_lifecycle(self, rules):
self.rules = [] self.rules = []
for rule in rules: for rule in rules:
# Extract and validate actions from Lifecycle rule
expiration = rule.get('Expiration') expiration = rule.get('Expiration')
transition = rule.get('Transition') transition = rule.get('Transition')
nve_noncurrent_days = None
if rule.get('NoncurrentVersionExpiration') is not None:
if rule["NoncurrentVersionExpiration"].get('NoncurrentDays') is None:
raise MalformedXML()
nve_noncurrent_days = rule["NoncurrentVersionExpiration"]["NoncurrentDays"]
nvt_noncurrent_days = None
nvt_storage_class = None
if rule.get('NoncurrentVersionTransition') is not None:
if rule["NoncurrentVersionTransition"].get('NoncurrentDays') is None:
raise MalformedXML()
if rule["NoncurrentVersionTransition"].get('StorageClass') is None:
raise MalformedXML()
nvt_noncurrent_days = rule["NoncurrentVersionTransition"]["NoncurrentDays"]
nvt_storage_class = rule["NoncurrentVersionTransition"]["StorageClass"]
aimu_days = None
if rule.get('AbortIncompleteMultipartUpload') is not None:
if rule["AbortIncompleteMultipartUpload"].get('DaysAfterInitiation') is None:
raise MalformedXML()
aimu_days = rule["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"]
eodm = None eodm = None
if expiration and expiration.get("ExpiredObjectDeleteMarker") is not None: if expiration and expiration.get("ExpiredObjectDeleteMarker") is not None:
# This cannot be set if Date or Days is set: # This cannot be set if Date or Days is set:
@ -453,11 +487,14 @@ class FakeBucket(BaseModel):
status=rule['Status'], status=rule['Status'],
expiration_days=expiration.get('Days') if expiration else None, expiration_days=expiration.get('Days') if expiration else None,
expiration_date=expiration.get('Date') if expiration else None, expiration_date=expiration.get('Date') if expiration else None,
expired_object_delete_marker=eodm,
transition_days=transition.get('Days') if transition else None, transition_days=transition.get('Days') if transition else None,
transition_date=transition.get('Date') if transition else None, transition_date=transition.get('Date') if transition else None,
storage_class=transition[ storage_class=transition.get('StorageClass') if transition else None,
'StorageClass'] if transition else None, expired_object_delete_marker=eodm,
nve_noncurrent_days=nve_noncurrent_days,
nvt_noncurrent_days=nvt_noncurrent_days,
nvt_storage_class=nvt_storage_class,
aimu_days=aimu_days,
)) ))
def delete_lifecycle(self): def delete_lifecycle(self):
@ -630,9 +667,6 @@ class S3Backend(BaseBackend):
latest_versions = {} latest_versions = {}
for version in versions: for version in versions:
if isinstance(version, FakeDeleteMarker):
name = version.key.name
else:
name = version.name name = version.name
version_id = version.version_id version_id = version.version_id
maximum_version_per_key[name] = max( maximum_version_per_key[name] = max(

View File

@ -1228,6 +1228,22 @@ S3_BUCKET_LIFECYCLE_CONFIGURATION = """<?xml version="1.0" encoding="UTF-8"?>
{% endif %} {% endif %}
</Expiration> </Expiration>
{% endif %} {% endif %}
{% if rule.nvt_noncurrent_days and rule.nvt_storage_class %}
<NoncurrentVersionTransition>
<NoncurrentDays>{{ rule.nvt_noncurrent_days }}</NoncurrentDays>
<StorageClass>{{ rule.nvt_storage_class }}</StorageClass>
</NoncurrentVersionTransition>
{% endif %}
{% if rule.nve_noncurrent_days %}
<NoncurrentVersionExpiration>
<NoncurrentDays>{{ rule.nve_noncurrent_days }}</NoncurrentDays>
</NoncurrentVersionExpiration>
{% endif %}
{% if rule.aimu_days %}
<AbortIncompleteMultipartUpload>
<DaysAfterInitiation>{{ rule.aimu_days }}</DaysAfterInitiation>
</AbortIncompleteMultipartUpload>
{% endif %}
</Rule> </Rule>
{% endfor %} {% endfor %}
</LifecycleConfiguration> </LifecycleConfiguration>
@ -1273,10 +1289,10 @@ S3_BUCKET_GET_VERSIONS = """<?xml version="1.0" encoding="UTF-8"?>
{% endfor %} {% endfor %}
{% for marker in delete_marker_list %} {% for marker in delete_marker_list %}
<DeleteMarker> <DeleteMarker>
<Key>{{ marker.key.name }}</Key> <Key>{{ marker.name }}</Key>
<VersionId>{{ marker.version_id }}</VersionId> <VersionId>{{ marker.version_id }}</VersionId>
<IsLatest>{% if latest_versions[marker.key.name] == marker.version_id %}true{% else %}false{% endif %}</IsLatest> <IsLatest>{% if latest_versions[marker.name] == marker.version_id %}true{% else %}false{% endif %}</IsLatest>
<LastModified>{{ marker.key.last_modified_ISO8601 }}</LastModified> <LastModified>{{ marker.last_modified_ISO8601 }}</LastModified>
<Owner> <Owner>
<ID>75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a</ID> <ID>75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a</ID>
<DisplayName>webfile</DisplayName> <DisplayName>webfile</DisplayName>
@ -1433,7 +1449,7 @@ S3_MULTIPART_LIST_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
</Owner> </Owner>
<StorageClass>STANDARD</StorageClass> <StorageClass>STANDARD</StorageClass>
<PartNumberMarker>1</PartNumberMarker> <PartNumberMarker>1</PartNumberMarker>
<NextPartNumberMarker>{{ count }} </NextPartNumberMarker> <NextPartNumberMarker>{{ count }}</NextPartNumberMarker>
<MaxParts>{{ count }}</MaxParts> <MaxParts>{{ count }}</MaxParts>
<IsTruncated>false</IsTruncated> <IsTruncated>false</IsTruncated>
{% for part in parts %} {% for part in parts %}

View File

@ -36,6 +36,7 @@ class SecretsManagerBackend(BaseBackend):
self.rotation_enabled = False self.rotation_enabled = False
self.rotation_lambda_arn = '' self.rotation_lambda_arn = ''
self.auto_rotate_after_days = 0 self.auto_rotate_after_days = 0
self.version_id = ''
def reset(self): def reset(self):
region_name = self.region region_name = self.region
@ -105,6 +106,56 @@ class SecretsManagerBackend(BaseBackend):
return response return response
def rotate_secret(self, secret_id, client_request_token=None,
rotation_lambda_arn=None, rotation_rules=None):
rotation_days = 'AutomaticallyAfterDays'
if not self._is_valid_identifier(secret_id):
raise ResourceNotFoundException
if client_request_token:
token_length = len(client_request_token)
if token_length < 32 or token_length > 64:
msg = (
'ClientRequestToken '
'must be 32-64 characters long.'
)
raise InvalidParameterException(msg)
if rotation_lambda_arn:
if len(rotation_lambda_arn) > 2048:
msg = (
'RotationLambdaARN '
'must <= 2048 characters long.'
)
raise InvalidParameterException(msg)
if rotation_rules:
if rotation_days in rotation_rules:
rotation_period = rotation_rules[rotation_days]
if rotation_period < 1 or rotation_period > 1000:
msg = (
'RotationRules.AutomaticallyAfterDays '
'must be within 1-1000.'
)
raise InvalidParameterException(msg)
self.version_id = client_request_token or ''
self.rotation_lambda_arn = rotation_lambda_arn or ''
if rotation_rules:
self.auto_rotate_after_days = rotation_rules.get(rotation_days, 0)
if self.auto_rotate_after_days > 0:
self.rotation_enabled = True
response = json.dumps({
"ARN": secret_arn(self.region, self.secret_id),
"Name": self.name,
"VersionId": self.version_id
})
return response
def get_random_password(self, password_length, def get_random_password(self, password_length,
exclude_characters, exclude_numbers, exclude_characters, exclude_numbers,
exclude_punctuation, exclude_uppercase, exclude_punctuation, exclude_uppercase,

View File

@ -50,3 +50,15 @@ class SecretsManagerResponse(BaseResponse):
return secretsmanager_backends[self.region].describe_secret( return secretsmanager_backends[self.region].describe_secret(
secret_id=secret_id secret_id=secret_id
) )
def rotate_secret(self):
client_request_token = self._get_param('ClientRequestToken')
rotation_lambda_arn = self._get_param('RotationLambdaARN')
rotation_rules = self._get_param('RotationRules')
secret_id = self._get_param('SecretId')
return secretsmanager_backends[self.region].rotate_secret(
secret_id=secret_id,
client_request_token=client_request_token,
rotation_lambda_arn=rotation_lambda_arn,
rotation_rules=rotation_rules
)

View File

@ -34,6 +34,9 @@ class DomainDispatcherApplication(object):
self.service = service self.service = service
def get_backend_for_host(self, host): def get_backend_for_host(self, host):
if host == 'moto_api':
return host
if self.service: if self.service:
return self.service return self.service

View File

@ -49,7 +49,8 @@ class SESBackend(BaseBackend):
self.sent_messages = [] self.sent_messages = []
self.sent_message_count = 0 self.sent_message_count = 0
def _is_verified_address(self, address): def _is_verified_address(self, source):
_, address = parseaddr(source)
if address in self.addresses: if address in self.addresses:
return True return True
user, host = address.split('@', 1) user, host = address.split('@', 1)

View File

@ -385,9 +385,21 @@ class SQSBackend(BaseBackend):
def create_queue(self, name, **kwargs): def create_queue(self, name, **kwargs):
queue = self.queues.get(name) queue = self.queues.get(name)
if queue: if queue:
# Queue already exist. If attributes don't match, throw error try:
for key, value in kwargs.items(): kwargs.pop('region')
if getattr(queue, camelcase_to_underscores(key)) != value: except KeyError:
pass
new_queue = Queue(name, region=self.region_name, **kwargs)
queue_attributes = queue.attributes
new_queue_attributes = new_queue.attributes
for key in ['CreatedTimestamp', 'LastModifiedTimestamp']:
queue_attributes.pop(key)
new_queue_attributes.pop(key)
if queue_attributes != new_queue_attributes:
raise QueueAlreadyExists("The specified queue already exists.") raise QueueAlreadyExists("The specified queue already exists.")
else: else:
try: try:

View File

@ -336,7 +336,7 @@ class SQSResponse(BaseResponse):
try: try:
wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) wait_time = int(self.querystring.get("WaitTimeSeconds")[0])
except TypeError: except TypeError:
wait_time = queue.receive_message_wait_time_seconds wait_time = int(queue.receive_message_wait_time_seconds)
if wait_time < 0 or wait_time > 20: if wait_time < 0 or wait_time > 20:
return self._error( return self._error(

View File

@ -5,10 +5,12 @@ from collections import defaultdict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.ec2 import ec2_backends from moto.ec2 import ec2_backends
from moto.cloudformation import cloudformation_backends
import datetime import datetime
import time import time
import uuid import uuid
import itertools
class Parameter(BaseModel): class Parameter(BaseModel):
@ -67,7 +69,7 @@ class Command(BaseModel):
instance_ids=None, max_concurrency='', max_errors='', instance_ids=None, max_concurrency='', max_errors='',
notification_config=None, output_s3_bucket_name='', notification_config=None, output_s3_bucket_name='',
output_s3_key_prefix='', output_s3_region='', parameters=None, output_s3_key_prefix='', output_s3_region='', parameters=None,
service_role_arn='', targets=None): service_role_arn='', targets=None, backend_region='us-east-1'):
if instance_ids is None: if instance_ids is None:
instance_ids = [] instance_ids = []
@ -88,9 +90,9 @@ class Command(BaseModel):
self.status = 'Success' self.status = 'Success'
self.status_details = 'Details placeholder' self.status_details = 'Details placeholder'
now = datetime.datetime.now() self.requested_date_time = datetime.datetime.now()
self.requested_date_time = now.isoformat() self.requested_date_time_iso = self.requested_date_time.isoformat()
expires_after = now + datetime.timedelta(0, timeout_seconds) expires_after = self.requested_date_time + datetime.timedelta(0, timeout_seconds)
self.expires_after = expires_after.isoformat() self.expires_after = expires_after.isoformat()
self.comment = comment self.comment = comment
@ -105,6 +107,32 @@ class Command(BaseModel):
self.parameters = parameters self.parameters = parameters
self.service_role_arn = service_role_arn self.service_role_arn = service_role_arn
self.targets = targets self.targets = targets
self.backend_region = backend_region
# Get instance ids from a cloud formation stack target.
stack_instance_ids = [self.get_instance_ids_by_stack_ids(target['Values']) for
target in self.targets if
target['Key'] == 'tag:aws:cloudformation:stack-name']
self.instance_ids += list(itertools.chain.from_iterable(stack_instance_ids))
# Create invocations with a single run command plugin.
self.invocations = []
for instance_id in self.instance_ids:
self.invocations.append(
self.invocation_response(instance_id, "aws:runShellScript"))
def get_instance_ids_by_stack_ids(self, stack_ids):
instance_ids = []
cloudformation_backend = cloudformation_backends[self.backend_region]
for stack_id in stack_ids:
stack_resources = cloudformation_backend.list_stack_resources(stack_id)
instance_resources = [
instance.id for instance in stack_resources
if instance.type == "AWS::EC2::Instance"]
instance_ids.extend(instance_resources)
return instance_ids
def response_object(self): def response_object(self):
r = { r = {
@ -122,7 +150,7 @@ class Command(BaseModel):
'OutputS3BucketName': self.output_s3_bucket_name, 'OutputS3BucketName': self.output_s3_bucket_name,
'OutputS3KeyPrefix': self.output_s3_key_prefix, 'OutputS3KeyPrefix': self.output_s3_key_prefix,
'Parameters': self.parameters, 'Parameters': self.parameters,
'RequestedDateTime': self.requested_date_time, 'RequestedDateTime': self.requested_date_time_iso,
'ServiceRole': self.service_role_arn, 'ServiceRole': self.service_role_arn,
'Status': self.status, 'Status': self.status,
'StatusDetails': self.status_details, 'StatusDetails': self.status_details,
@ -132,6 +160,50 @@ class Command(BaseModel):
return r return r
def invocation_response(self, instance_id, plugin_name):
# Calculate elapsed time from requested time and now. Use a hardcoded
# elapsed time since there is no easy way to convert a timedelta to
# an ISO 8601 duration string.
elapsed_time_iso = "PT5M"
elapsed_time_delta = datetime.timedelta(minutes=5)
end_time = self.requested_date_time + elapsed_time_delta
r = {
'CommandId': self.command_id,
'InstanceId': instance_id,
'Comment': self.comment,
'DocumentName': self.document_name,
'PluginName': plugin_name,
'ResponseCode': 0,
'ExecutionStartDateTime': self.requested_date_time_iso,
'ExecutionElapsedTime': elapsed_time_iso,
'ExecutionEndDateTime': end_time.isoformat(),
'Status': 'Success',
'StatusDetails': 'Success',
'StandardOutputContent': '',
'StandardOutputUrl': '',
'StandardErrorContent': '',
}
return r
def get_invocation(self, instance_id, plugin_name):
invocation = next(
(invocation for invocation in self.invocations
if invocation['InstanceId'] == instance_id), None)
if invocation is None:
raise RESTError(
'InvocationDoesNotExist',
'An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation')
if plugin_name is not None and invocation['PluginName'] != plugin_name:
raise RESTError(
'InvocationDoesNotExist',
'An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation')
return invocation
class SimpleSystemManagerBackend(BaseBackend): class SimpleSystemManagerBackend(BaseBackend):
@ -140,6 +212,11 @@ class SimpleSystemManagerBackend(BaseBackend):
self._resource_tags = defaultdict(lambda: defaultdict(dict)) self._resource_tags = defaultdict(lambda: defaultdict(dict))
self._commands = [] self._commands = []
# figure out what region we're in
for region, backend in ssm_backends.items():
if backend == self:
self._region = region
def delete_parameter(self, name): def delete_parameter(self, name):
try: try:
del self._parameters[name] del self._parameters[name]
@ -260,7 +337,8 @@ class SimpleSystemManagerBackend(BaseBackend):
output_s3_region=kwargs.get('OutputS3Region', ''), output_s3_region=kwargs.get('OutputS3Region', ''),
parameters=kwargs.get('Parameters', {}), parameters=kwargs.get('Parameters', {}),
service_role_arn=kwargs.get('ServiceRoleArn', ''), service_role_arn=kwargs.get('ServiceRoleArn', ''),
targets=kwargs.get('Targets', [])) targets=kwargs.get('Targets', []),
backend_region=self._region)
self._commands.append(command) self._commands.append(command)
return { return {
@ -298,6 +376,18 @@ class SimpleSystemManagerBackend(BaseBackend):
command for command in self._commands command for command in self._commands
if instance_id in command.instance_ids] if instance_id in command.instance_ids]
def get_command_invocation(self, **kwargs):
"""
https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_GetCommandInvocation.html
"""
command_id = kwargs.get('CommandId')
instance_id = kwargs.get('InstanceId')
plugin_name = kwargs.get('PluginName', None)
command = self.get_command_by_id(command_id)
return command.get_invocation(instance_id, plugin_name)
ssm_backends = {} ssm_backends = {}
for region, ec2_backend in ec2_backends.items(): for region, ec2_backend in ec2_backends.items():

View File

@ -210,3 +210,8 @@ class SimpleSystemManagerResponse(BaseResponse):
return json.dumps( return json.dumps(
self.ssm_backend.list_commands(**self.request_params) self.ssm_backend.list_commands(**self.request_params)
) )
def get_command_invocation(self):
return json.dumps(
self.ssm_backend.get_command_invocation(**self.request_params)
)

View File

@ -1,7 +1,7 @@
-r requirements.txt -r requirements.txt
mock mock
nose nose
sure==1.2.24 sure==1.4.11
coverage coverage
flake8==3.5.0 flake8==3.5.0
freezegun freezegun
@ -13,5 +13,5 @@ six>=1.9
prompt-toolkit==1.0.14 prompt-toolkit==1.0.14
click==6.7 click==6.7
inflection==0.3.1 inflection==0.3.1
lxml==4.0.0 lxml==4.2.3
beautifulsoup4==4.6.0 beautifulsoup4==4.6.0

View File

@ -8,10 +8,9 @@ import sys
install_requires = [ install_requires = [
"Jinja2>=2.7.3", "Jinja2>=2.7.3",
"boto>=2.36.0", "boto>=2.36.0",
"boto3>=1.6.16", "boto3>=1.6.16,<1.8",
"botocore>=1.9.16,<1.11", "botocore>=1.9.16,<1.11",
"cookies", "cryptography>=2.3.0",
"cryptography>=2.0.0",
"requests>=2.5", "requests>=2.5",
"xmltodict", "xmltodict",
"six>1.9", "six>1.9",
@ -41,7 +40,7 @@ else:
setup( setup(
name='moto', name='moto',
version='1.3.5', version='1.3.6',
description='A library that allows your python tests to easily' description='A library that allows your python tests to easily'
' mock out the boto library', ' mock out the boto library',
author='Steve Pulec', author='Steve Pulec',

View File

@ -31,6 +31,7 @@ def test_create_identity_pool():
# testing a helper function # testing a helper function
def test_get_random_identity_id(): def test_get_random_identity_id():
assert len(get_random_identity_id('us-west-2')) > 0 assert len(get_random_identity_id('us-west-2')) > 0
assert len(get_random_identity_id('us-west-2').split(':')[1]) == 19
@mock_cognitoidentity @mock_cognitoidentity
@ -69,3 +70,16 @@ def test_get_open_id_token_for_developer_identity():
) )
assert len(result['Token']) assert len(result['Token'])
assert result['IdentityId'] == '12345' assert result['IdentityId'] == '12345'
@mock_cognitoidentity
def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id():
conn = boto3.client('cognito-identity', 'us-west-2')
result = conn.get_open_id_token_for_developer_identity(
IdentityPoolId='us-west-2:12345',
Logins={
'someurl': '12345'
},
TokenDuration=123
)
assert len(result['Token']) > 0
assert len(result['IdentityId']) > 0

View File

@ -6,6 +6,7 @@ import os
import uuid import uuid
from jose import jws from jose import jws
from moto import mock_cognitoidp from moto import mock_cognitoidp
import sure # noqa import sure # noqa
@ -24,6 +25,7 @@ def test_create_user_pool():
) )
result["UserPool"]["Id"].should_not.be.none result["UserPool"]["Id"].should_not.be.none
result["UserPool"]["Id"].should.match(r'[\w-]+_[0-9a-zA-Z]+')
result["UserPool"]["Name"].should.equal(name) result["UserPool"]["Name"].should.equal(name)
result["UserPool"]["LambdaConfig"]["PreSignUp"].should.equal(value) result["UserPool"]["LambdaConfig"]["PreSignUp"].should.equal(value)
@ -399,15 +401,22 @@ def authentication_flow(conn):
username = str(uuid.uuid4()) username = str(uuid.uuid4())
temporary_password = str(uuid.uuid4()) temporary_password = str(uuid.uuid4())
user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"]
user_attribute_name = str(uuid.uuid4())
user_attribute_value = str(uuid.uuid4())
client_id = conn.create_user_pool_client( client_id = conn.create_user_pool_client(
UserPoolId=user_pool_id, UserPoolId=user_pool_id,
ClientName=str(uuid.uuid4()), ClientName=str(uuid.uuid4()),
ReadAttributes=[user_attribute_name]
)["UserPoolClient"]["ClientId"] )["UserPoolClient"]["ClientId"]
conn.admin_create_user( conn.admin_create_user(
UserPoolId=user_pool_id, UserPoolId=user_pool_id,
Username=username, Username=username,
TemporaryPassword=temporary_password, TemporaryPassword=temporary_password,
UserAttributes=[{
'Name': user_attribute_name,
'Value': user_attribute_value
}]
) )
result = conn.admin_initiate_auth( result = conn.admin_initiate_auth(
@ -446,6 +455,9 @@ def authentication_flow(conn):
"access_token": result["AuthenticationResult"]["AccessToken"], "access_token": result["AuthenticationResult"]["AccessToken"],
"username": username, "username": username,
"password": new_password, "password": new_password,
"additional_fields": {
user_attribute_name: user_attribute_value
}
} }
@ -475,6 +487,8 @@ def test_token_legitimacy():
access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256")) access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256"))
access_claims["iss"].should.equal(issuer) access_claims["iss"].should.equal(issuer)
access_claims["aud"].should.equal(client_id) access_claims["aud"].should.equal(client_id)
for k, v in outputs["additional_fields"].items():
access_claims[k].should.equal(v)
@mock_cognitoidp @mock_cognitoidp

View File

@ -85,3 +85,14 @@ class TesterWithSetup(unittest.TestCase):
def test_still_the_same(self): def test_still_the_same(self):
bucket = self.conn.get_bucket('mybucket') bucket = self.conn.get_bucket('mybucket')
bucket.name.should.equal("mybucket") bucket.name.should.equal("mybucket")
@mock_s3_deprecated
class TesterWithStaticmethod(object):
@staticmethod
def static(*args):
assert not args or not isinstance(args[0], TesterWithStaticmethod)
def test_no_instance_sent_to_staticmethod(self):
self.static()

View File

@ -201,6 +201,48 @@ def test_item_add_empty_string_exception():
) )
@requires_boto_gte("2.9")
@mock_dynamodb2
def test_update_item_with_empty_string_exception():
name = 'TestTable'
conn = boto3.client('dynamodb',
region_name='us-west-2',
aws_access_key_id="ak",
aws_secret_access_key="sk")
conn.create_table(TableName=name,
KeySchema=[{'AttributeName':'forum_name','KeyType':'HASH'}],
AttributeDefinitions=[{'AttributeName':'forum_name','AttributeType':'S'}],
ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5})
conn.put_item(
TableName=name,
Item={
'forum_name': { 'S': 'LOLCat Forum' },
'subject': { 'S': 'Check this out!' },
'Body': { 'S': 'http://url_to_lolcat.gif'},
'SentBy': { 'S': "test" },
'ReceivedTime': { 'S': '12/9/2011 11:36:03 PM'},
}
)
with assert_raises(ClientError) as ex:
conn.update_item(
TableName=name,
Key={
'forum_name': { 'S': 'LOLCat Forum'},
},
UpdateExpression='set Body=:Body',
ExpressionAttributeValues={
':Body': {'S': ''}
})
ex.exception.response['Error']['Code'].should.equal('ValidationException')
ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400)
ex.exception.response['Error']['Message'].should.equal(
'One or more parameter values were invalid: An AttributeValue may not contain an empty string'
)
@requires_boto_gte("2.9") @requires_boto_gte("2.9")
@mock_dynamodb2 @mock_dynamodb2
def test_query_invalid_table(): def test_query_invalid_table():
@ -658,8 +700,8 @@ def test_filter_expression():
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN :v0', {}, {':v0': {'NS': [7, 8, 9]}}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN :v0', {}, {':v0': {'NS': [7, 8, 9]}})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
# attribute function tests # attribute function tests (with extra spaces)
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists(User)', {}, {}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {})
filter_expr.expr(row1).should.be(True) filter_expr.expr(row1).should.be(True)
filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, N)', {}, {}) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, N)', {}, {})
@ -1178,7 +1220,8 @@ def test_update_if_not_exists():
'forum_name': 'the-key', 'forum_name': 'the-key',
'subject': '123' 'subject': '123'
}, },
UpdateExpression='SET created_at = if_not_exists(created_at, :created_at)', # if_not_exists without space
UpdateExpression='SET created_at=if_not_exists(created_at,:created_at)',
ExpressionAttributeValues={ ExpressionAttributeValues={
':created_at': 123 ':created_at': 123
} }
@ -1191,7 +1234,8 @@ def test_update_if_not_exists():
'forum_name': 'the-key', 'forum_name': 'the-key',
'subject': '123' 'subject': '123'
}, },
UpdateExpression='SET created_at = if_not_exists(created_at, :created_at)', # if_not_exists with space
UpdateExpression='SET created_at = if_not_exists (created_at, :created_at)',
ExpressionAttributeValues={ ExpressionAttributeValues={
':created_at': 456 ':created_at': 456
} }

View File

@ -615,8 +615,8 @@ def test_copy_snapshot():
dest = dest_ec2.Snapshot(copy_snapshot_response['SnapshotId']) dest = dest_ec2.Snapshot(copy_snapshot_response['SnapshotId'])
attribs = ['data_encryption_key_id', 'encrypted', attribs = ['data_encryption_key_id', 'encrypted',
'kms_key_id', 'owner_alias', 'owner_id', 'progress', 'kms_key_id', 'owner_alias', 'owner_id',
'start_time', 'state', 'state_message', 'progress', 'state', 'state_message',
'tags', 'volume_id', 'volume_size'] 'tags', 'volume_id', 'volume_size']
for attrib in attribs: for attrib in attribs:

View File

@ -2,12 +2,15 @@ from __future__ import unicode_literals
# Ensure 'assert_raises' context manager support for Python 2.6 # Ensure 'assert_raises' context manager support for Python 2.6
import tests.backport_assert_raises import tests.backport_assert_raises
from nose.tools import assert_raises from nose.tools import assert_raises
from moto.ec2.exceptions import EC2ClientError
from botocore.exceptions import ClientError
import boto3
import boto import boto
from boto.exception import EC2ResponseError from boto.exception import EC2ResponseError
import sure # noqa import sure # noqa
from moto import mock_ec2_deprecated from moto import mock_ec2, mock_ec2_deprecated
from tests.helpers import requires_boto_gte from tests.helpers import requires_boto_gte
@ -93,3 +96,37 @@ def test_vpc_peering_connections_delete():
cm.exception.code.should.equal('InvalidVpcPeeringConnectionId.NotFound') cm.exception.code.should.equal('InvalidVpcPeeringConnectionId.NotFound')
cm.exception.status.should.equal(400) cm.exception.status.should.equal(400)
cm.exception.request_id.should_not.be.none cm.exception.request_id.should_not.be.none
@mock_ec2
def test_vpc_peering_connections_cross_region():
# create vpc in us-west-1 and ap-northeast-1
ec2_usw1 = boto3.resource('ec2', region_name='us-west-1')
vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16')
ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1')
vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16')
# create peering
vpc_pcx = ec2_usw1.create_vpc_peering_connection(
VpcId=vpc_usw1.id,
PeerVpcId=vpc_apn1.id,
PeerRegion='ap-northeast-1',
)
vpc_pcx.status['Code'].should.equal('initiating-request')
vpc_pcx.requester_vpc.id.should.equal(vpc_usw1.id)
vpc_pcx.accepter_vpc.id.should.equal(vpc_apn1.id)
@mock_ec2
def test_vpc_peering_connections_cross_region_fail():
# create vpc in us-west-1 and ap-northeast-1
ec2_usw1 = boto3.resource('ec2', region_name='us-west-1')
vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16')
ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1')
vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16')
# create peering wrong region with no vpc
with assert_raises(ClientError) as cm:
ec2_usw1.create_vpc_peering_connection(
VpcId=vpc_usw1.id,
PeerVpcId=vpc_apn1.id,
PeerRegion='ap-northeast-2')
cm.exception.response['Error']['Code'].should.equal('InvalidVpcID.NotFound')

View File

@ -304,6 +304,52 @@ def test_create_service():
response['service']['status'].should.equal('ACTIVE') response['service']['status'].should.equal('ACTIVE')
response['service']['taskDefinition'].should.equal( response['service']['taskDefinition'].should.equal(
'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1')
response['service']['schedulingStrategy'].should.equal('REPLICA')
@mock_ecs
def test_create_service_scheduling_strategy():
client = boto3.client('ecs', region_name='us-east-1')
_ = client.create_cluster(
clusterName='test_ecs_cluster'
)
_ = client.register_task_definition(
family='test_ecs_task',
containerDefinitions=[
{
'name': 'hello_world',
'image': 'docker/hello-world:latest',
'cpu': 1024,
'memory': 400,
'essential': True,
'environment': [{
'name': 'AWS_ACCESS_KEY_ID',
'value': 'SOME_ACCESS_KEY'
}],
'logConfiguration': {'logDriver': 'json-file'}
}
]
)
response = client.create_service(
cluster='test_ecs_cluster',
serviceName='test_ecs_service',
taskDefinition='test_ecs_task',
desiredCount=2,
schedulingStrategy='DAEMON',
)
response['service']['clusterArn'].should.equal(
'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster')
response['service']['desiredCount'].should.equal(2)
len(response['service']['events']).should.equal(0)
len(response['service']['loadBalancers']).should.equal(0)
response['service']['pendingCount'].should.equal(0)
response['service']['runningCount'].should.equal(0)
response['service']['serviceArn'].should.equal(
'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service')
response['service']['serviceName'].should.equal('test_ecs_service')
response['service']['status'].should.equal('ACTIVE')
response['service']['taskDefinition'].should.equal(
'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1')
response['service']['schedulingStrategy'].should.equal('DAEMON')
@mock_ecs @mock_ecs
@ -411,6 +457,72 @@ def test_describe_services():
response['services'][0]['deployments'][0]['status'].should.equal('PRIMARY') response['services'][0]['deployments'][0]['status'].should.equal('PRIMARY')
@mock_ecs
def test_describe_services_scheduling_strategy():
client = boto3.client('ecs', region_name='us-east-1')
_ = client.create_cluster(
clusterName='test_ecs_cluster'
)
_ = client.register_task_definition(
family='test_ecs_task',
containerDefinitions=[
{
'name': 'hello_world',
'image': 'docker/hello-world:latest',
'cpu': 1024,
'memory': 400,
'essential': True,
'environment': [{
'name': 'AWS_ACCESS_KEY_ID',
'value': 'SOME_ACCESS_KEY'
}],
'logConfiguration': {'logDriver': 'json-file'}
}
]
)
_ = client.create_service(
cluster='test_ecs_cluster',
serviceName='test_ecs_service1',
taskDefinition='test_ecs_task',
desiredCount=2
)
_ = client.create_service(
cluster='test_ecs_cluster',
serviceName='test_ecs_service2',
taskDefinition='test_ecs_task',
desiredCount=2,
schedulingStrategy='DAEMON'
)
_ = client.create_service(
cluster='test_ecs_cluster',
serviceName='test_ecs_service3',
taskDefinition='test_ecs_task',
desiredCount=2
)
response = client.describe_services(
cluster='test_ecs_cluster',
services=['test_ecs_service1',
'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2',
'test_ecs_service3']
)
len(response['services']).should.equal(3)
response['services'][0]['serviceArn'].should.equal(
'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1')
response['services'][0]['serviceName'].should.equal('test_ecs_service1')
response['services'][1]['serviceArn'].should.equal(
'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2')
response['services'][1]['serviceName'].should.equal('test_ecs_service2')
response['services'][0]['deployments'][0]['desiredCount'].should.equal(2)
response['services'][0]['deployments'][0]['pendingCount'].should.equal(2)
response['services'][0]['deployments'][0]['runningCount'].should.equal(0)
response['services'][0]['deployments'][0]['status'].should.equal('PRIMARY')
response['services'][0]['schedulingStrategy'].should.equal('REPLICA')
response['services'][1]['schedulingStrategy'].should.equal('DAEMON')
response['services'][2]['schedulingStrategy'].should.equal('REPLICA')
@mock_ecs @mock_ecs
def test_update_service(): def test_update_service():
client = boto3.client('ecs', region_name='us-east-1') client = boto3.client('ecs', region_name='us-east-1')
@ -449,6 +561,7 @@ def test_update_service():
desiredCount=0 desiredCount=0
) )
response['service']['desiredCount'].should.equal(0) response['service']['desiredCount'].should.equal(0)
response['service']['schedulingStrategy'].should.equal('REPLICA')
@mock_ecs @mock_ecs
@ -515,10 +628,12 @@ def test_delete_service():
'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service') 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service')
response['service']['serviceName'].should.equal('test_ecs_service') response['service']['serviceName'].should.equal('test_ecs_service')
response['service']['status'].should.equal('ACTIVE') response['service']['status'].should.equal('ACTIVE')
response['service']['schedulingStrategy'].should.equal('REPLICA')
response['service']['taskDefinition'].should.equal( response['service']['taskDefinition'].should.equal(
'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1')
@mock_ec2 @mock_ec2
@mock_ecs @mock_ecs
def test_register_container_instance(): def test_register_container_instance():

View File

@ -723,6 +723,40 @@ def test_describe_instance_health():
instances_health[0].state.should.equal('InService') instances_health[0].state.should.equal('InService')
@mock_ec2
@mock_elb
def test_describe_instance_health_boto3():
elb = boto3.client('elb', region_name="us-east-1")
ec2 = boto3.client('ec2', region_name="us-east-1")
instances = ec2.run_instances(MinCount=2, MaxCount=2)['Instances']
lb_name = "my_load_balancer"
elb.create_load_balancer(
Listeners=[{
'InstancePort': 80,
'LoadBalancerPort': 8080,
'Protocol': 'HTTP'
}],
LoadBalancerName=lb_name,
)
elb.register_instances_with_load_balancer(
LoadBalancerName=lb_name,
Instances=[{'InstanceId': instances[0]['InstanceId']}]
)
instances_health = elb.describe_instance_health(
LoadBalancerName=lb_name,
Instances=[{'InstanceId': instance['InstanceId']} for instance in instances]
)
instances_health['InstanceStates'].should.have.length_of(2)
instances_health['InstanceStates'][0]['InstanceId'].\
should.equal(instances[0]['InstanceId'])
instances_health['InstanceStates'][0]['State'].\
should.equal('InService')
instances_health['InstanceStates'][1]['InstanceId'].\
should.equal(instances[1]['InstanceId'])
instances_health['InstanceStates'][1]['State'].\
should.equal('Unknown')
@mock_elb @mock_elb
def test_add_remove_tags(): def test_add_remove_tags():
client = boto3.client('elb', region_name='us-east-1') client = boto3.client('elb', region_name='us-east-1')

View File

@ -29,3 +29,28 @@ TABLE_INPUT = {
}, },
'TableType': 'EXTERNAL_TABLE', 'TableType': 'EXTERNAL_TABLE',
} }
PARTITION_INPUT = {
# 'DatabaseName': 'dbname',
'StorageDescriptor': {
'BucketColumns': [],
'Columns': [],
'Compressed': False,
'InputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat',
'Location': 's3://.../partition=value',
'NumberOfBuckets': -1,
'OutputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat',
'Parameters': {},
'SerdeInfo': {
'Parameters': {'path': 's3://...', 'serialization.format': '1'},
'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'},
'SkewedInfo': {'SkewedColumnNames': [],
'SkewedColumnValueLocationMaps': {},
'SkewedColumnValues': []},
'SortColumns': [],
'StoredAsSubDirectories': False,
},
# 'TableName': 'source_table',
# 'Values': ['2018-06-26'],
}

View File

@ -2,7 +2,7 @@ from __future__ import unicode_literals
import copy import copy
from .fixtures.datacatalog import TABLE_INPUT from .fixtures.datacatalog import TABLE_INPUT, PARTITION_INPUT
def create_database(client, database_name): def create_database(client, database_name):
@ -17,22 +17,38 @@ def get_database(client, database_name):
return client.get_database(Name=database_name) return client.get_database(Name=database_name)
def create_table_input(table_name, s3_location, columns=[], partition_keys=[]): def create_table_input(database_name, table_name, columns=[], partition_keys=[]):
table_input = copy.deepcopy(TABLE_INPUT) table_input = copy.deepcopy(TABLE_INPUT)
table_input['Name'] = table_name table_input['Name'] = table_name
table_input['PartitionKeys'] = partition_keys table_input['PartitionKeys'] = partition_keys
table_input['StorageDescriptor']['Columns'] = columns table_input['StorageDescriptor']['Columns'] = columns
table_input['StorageDescriptor']['Location'] = s3_location table_input['StorageDescriptor']['Location'] = 's3://my-bucket/{database_name}/{table_name}'.format(
database_name=database_name,
table_name=table_name
)
return table_input return table_input
def create_table(client, database_name, table_name, table_input): def create_table(client, database_name, table_name, table_input=None, **kwargs):
if table_input is None:
table_input = create_table_input(database_name, table_name, **kwargs)
return client.create_table( return client.create_table(
DatabaseName=database_name, DatabaseName=database_name,
TableInput=table_input TableInput=table_input
) )
def update_table(client, database_name, table_name, table_input=None, **kwargs):
if table_input is None:
table_input = create_table_input(database_name, table_name, **kwargs)
return client.update_table(
DatabaseName=database_name,
TableInput=table_input,
)
def get_table(client, database_name, table_name): def get_table(client, database_name, table_name):
return client.get_table( return client.get_table(
DatabaseName=database_name, DatabaseName=database_name,
@ -44,3 +60,60 @@ def get_tables(client, database_name):
return client.get_tables( return client.get_tables(
DatabaseName=database_name DatabaseName=database_name
) )
def get_table_versions(client, database_name, table_name):
return client.get_table_versions(
DatabaseName=database_name,
TableName=table_name
)
def get_table_version(client, database_name, table_name, version_id):
return client.get_table_version(
DatabaseName=database_name,
TableName=table_name,
VersionId=version_id,
)
def create_partition_input(database_name, table_name, values=[], columns=[]):
root_path = 's3://my-bucket/{database_name}/{table_name}'.format(
database_name=database_name,
table_name=table_name
)
part_input = copy.deepcopy(PARTITION_INPUT)
part_input['Values'] = values
part_input['StorageDescriptor']['Columns'] = columns
part_input['StorageDescriptor']['SerdeInfo']['Parameters']['path'] = root_path
return part_input
def create_partition(client, database_name, table_name, partiton_input=None, **kwargs):
if partiton_input is None:
partiton_input = create_partition_input(database_name, table_name, **kwargs)
return client.create_partition(
DatabaseName=database_name,
TableName=table_name,
PartitionInput=partiton_input
)
def update_partition(client, database_name, table_name, old_values=[], partiton_input=None, **kwargs):
if partiton_input is None:
partiton_input = create_partition_input(database_name, table_name, **kwargs)
return client.update_partition(
DatabaseName=database_name,
TableName=table_name,
PartitionInput=partiton_input,
PartitionValueList=old_values,
)
def get_partition(client, database_name, table_name, values):
return client.get_partition(
DatabaseName=database_name,
TableName=table_name,
PartitionValues=values,
)

View File

@ -1,10 +1,15 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import sure # noqa import sure # noqa
import re
from nose.tools import assert_raises from nose.tools import assert_raises
import boto3 import boto3
from botocore.client import ClientError from botocore.client import ClientError
from datetime import datetime
import pytz
from moto import mock_glue from moto import mock_glue
from . import helpers from . import helpers
@ -30,7 +35,19 @@ def test_create_database_already_exists():
with assert_raises(ClientError) as exc: with assert_raises(ClientError) as exc:
helpers.create_database(client, database_name) helpers.create_database(client, database_name)
exc.exception.response['Error']['Code'].should.equal('DatabaseAlreadyExistsException') exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException')
@mock_glue
def test_get_database_not_exits():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'nosuchdatabase'
with assert_raises(ClientError) as exc:
helpers.get_database(client, database_name)
exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException')
exc.exception.response['Error']['Message'].should.match('Database nosuchdatabase not found')
@mock_glue @mock_glue
@ -40,12 +57,7 @@ def test_create_table():
helpers.create_database(client, database_name) helpers.create_database(client, database_name)
table_name = 'myspecialtable' table_name = 'myspecialtable'
s3_location = 's3://my-bucket/{database_name}/{table_name}'.format( table_input = helpers.create_table_input(database_name, table_name)
database_name=database_name,
table_name=table_name
)
table_input = helpers.create_table_input(table_name, s3_location)
helpers.create_table(client, database_name, table_name, table_input) helpers.create_table(client, database_name, table_name, table_input)
response = helpers.get_table(client, database_name, table_name) response = helpers.get_table(client, database_name, table_name)
@ -63,18 +75,12 @@ def test_create_table_already_exists():
helpers.create_database(client, database_name) helpers.create_database(client, database_name)
table_name = 'cantcreatethistabletwice' table_name = 'cantcreatethistabletwice'
s3_location = 's3://my-bucket/{database_name}/{table_name}'.format( helpers.create_table(client, database_name, table_name)
database_name=database_name,
table_name=table_name
)
table_input = helpers.create_table_input(table_name, s3_location)
helpers.create_table(client, database_name, table_name, table_input)
with assert_raises(ClientError) as exc: with assert_raises(ClientError) as exc:
helpers.create_table(client, database_name, table_name, table_input) helpers.create_table(client, database_name, table_name)
exc.exception.response['Error']['Code'].should.equal('TableAlreadyExistsException') exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException')
@mock_glue @mock_glue
@ -87,11 +93,7 @@ def test_get_tables():
table_inputs = {} table_inputs = {}
for table_name in table_names: for table_name in table_names:
s3_location = 's3://my-bucket/{database_name}/{table_name}'.format( table_input = helpers.create_table_input(database_name, table_name)
database_name=database_name,
table_name=table_name
)
table_input = helpers.create_table_input(table_name, s3_location)
table_inputs[table_name] = table_input table_inputs[table_name] = table_input
helpers.create_table(client, database_name, table_name, table_input) helpers.create_table(client, database_name, table_name, table_input)
@ -99,10 +101,326 @@ def test_get_tables():
tables = response['TableList'] tables = response['TableList']
assert len(tables) == 3 tables.should.have.length_of(3)
for table in tables: for table in tables:
table_name = table['Name'] table_name = table['Name']
table_name.should.equal(table_inputs[table_name]['Name']) table_name.should.equal(table_inputs[table_name]['Name'])
table['StorageDescriptor'].should.equal(table_inputs[table_name]['StorageDescriptor']) table['StorageDescriptor'].should.equal(table_inputs[table_name]['StorageDescriptor'])
table['PartitionKeys'].should.equal(table_inputs[table_name]['PartitionKeys']) table['PartitionKeys'].should.equal(table_inputs[table_name]['PartitionKeys'])
@mock_glue
def test_get_table_versions():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
helpers.create_database(client, database_name)
table_name = 'myfirsttable'
version_inputs = {}
table_input = helpers.create_table_input(database_name, table_name)
helpers.create_table(client, database_name, table_name, table_input)
version_inputs["1"] = table_input
columns = [{'Name': 'country', 'Type': 'string'}]
table_input = helpers.create_table_input(database_name, table_name, columns=columns)
helpers.update_table(client, database_name, table_name, table_input)
version_inputs["2"] = table_input
# Updateing with an indentical input should still create a new version
helpers.update_table(client, database_name, table_name, table_input)
version_inputs["3"] = table_input
response = helpers.get_table_versions(client, database_name, table_name)
vers = response['TableVersions']
vers.should.have.length_of(3)
vers[0]['Table']['StorageDescriptor']['Columns'].should.equal([])
vers[-1]['Table']['StorageDescriptor']['Columns'].should.equal(columns)
for n, ver in enumerate(vers):
n = str(n + 1)
ver['VersionId'].should.equal(n)
ver['Table']['Name'].should.equal(table_name)
ver['Table']['StorageDescriptor'].should.equal(version_inputs[n]['StorageDescriptor'])
ver['Table']['PartitionKeys'].should.equal(version_inputs[n]['PartitionKeys'])
response = helpers.get_table_version(client, database_name, table_name, "3")
ver = response['TableVersion']
ver['VersionId'].should.equal("3")
ver['Table']['Name'].should.equal(table_name)
ver['Table']['StorageDescriptor']['Columns'].should.equal(columns)
@mock_glue
def test_get_table_version_not_found():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
with assert_raises(ClientError) as exc:
helpers.get_table_version(client, database_name, 'myfirsttable', "20")
exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException')
exc.exception.response['Error']['Message'].should.match('version', re.I)
@mock_glue
def test_get_table_version_invalid_input():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
with assert_raises(ClientError) as exc:
helpers.get_table_version(client, database_name, 'myfirsttable', "10not-an-int")
exc.exception.response['Error']['Code'].should.equal('InvalidInputException')
@mock_glue
def test_get_table_not_exits():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
helpers.create_database(client, database_name)
with assert_raises(ClientError) as exc:
helpers.get_table(client, database_name, 'myfirsttable')
exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException')
exc.exception.response['Error']['Message'].should.match('Table myfirsttable not found')
@mock_glue
def test_get_table_when_database_not_exits():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'nosuchdatabase'
with assert_raises(ClientError) as exc:
helpers.get_table(client, database_name, 'myfirsttable')
exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException')
exc.exception.response['Error']['Message'].should.match('Database nosuchdatabase not found')
@mock_glue
def test_get_partitions_empty():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
response = client.get_partitions(DatabaseName=database_name, TableName=table_name)
response['Partitions'].should.have.length_of(0)
@mock_glue
def test_create_partition():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
values = ['2018-10-01']
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
before = datetime.now(pytz.utc)
part_input = helpers.create_partition_input(database_name, table_name, values=values)
helpers.create_partition(client, database_name, table_name, part_input)
after = datetime.now(pytz.utc)
response = client.get_partitions(DatabaseName=database_name, TableName=table_name)
partitions = response['Partitions']
partitions.should.have.length_of(1)
partition = partitions[0]
partition['TableName'].should.equal(table_name)
partition['StorageDescriptor'].should.equal(part_input['StorageDescriptor'])
partition['Values'].should.equal(values)
partition['CreationTime'].should.be.greater_than(before)
partition['CreationTime'].should.be.lower_than(after)
@mock_glue
def test_create_partition_already_exist():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
values = ['2018-10-01']
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
helpers.create_partition(client, database_name, table_name, values=values)
with assert_raises(ClientError) as exc:
helpers.create_partition(client, database_name, table_name, values=values)
exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException')
@mock_glue
def test_get_partition_not_found():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
values = ['2018-10-01']
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
with assert_raises(ClientError) as exc:
helpers.get_partition(client, database_name, table_name, values)
exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException')
exc.exception.response['Error']['Message'].should.match('partition')
@mock_glue
def test_get_partition():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
values = [['2018-10-01'], ['2018-09-01']]
helpers.create_partition(client, database_name, table_name, values=values[0])
helpers.create_partition(client, database_name, table_name, values=values[1])
response = client.get_partition(DatabaseName=database_name, TableName=table_name, PartitionValues=values[1])
partition = response['Partition']
partition['TableName'].should.equal(table_name)
partition['Values'].should.equal(values[1])
@mock_glue
def test_update_partition_not_found_moving():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
with assert_raises(ClientError) as exc:
helpers.update_partition(client, database_name, table_name, old_values=['0000-00-00'], values=['2018-10-02'])
exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException')
exc.exception.response['Error']['Message'].should.match('partition')
@mock_glue
def test_update_partition_not_found_change_in_place():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
values = ['2018-10-01']
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
with assert_raises(ClientError) as exc:
helpers.update_partition(client, database_name, table_name, old_values=values, values=values)
exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException')
exc.exception.response['Error']['Message'].should.match('partition')
@mock_glue
def test_update_partition_cannot_overwrite():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
values = [['2018-10-01'], ['2018-09-01']]
helpers.create_partition(client, database_name, table_name, values=values[0])
helpers.create_partition(client, database_name, table_name, values=values[1])
with assert_raises(ClientError) as exc:
helpers.update_partition(client, database_name, table_name, old_values=values[0], values=values[1])
exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException')
@mock_glue
def test_update_partition():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
values = ['2018-10-01']
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
helpers.create_partition(client, database_name, table_name, values=values)
response = helpers.update_partition(
client,
database_name,
table_name,
old_values=values,
values=values,
columns=[{'Name': 'country', 'Type': 'string'}],
)
response = client.get_partition(DatabaseName=database_name, TableName=table_name, PartitionValues=values)
partition = response['Partition']
partition['TableName'].should.equal(table_name)
partition['StorageDescriptor']['Columns'].should.equal([{'Name': 'country', 'Type': 'string'}])
@mock_glue
def test_update_partition_move():
client = boto3.client('glue', region_name='us-east-1')
database_name = 'myspecialdatabase'
table_name = 'myfirsttable'
values = ['2018-10-01']
new_values = ['2018-09-01']
helpers.create_database(client, database_name)
helpers.create_table(client, database_name, table_name)
helpers.create_partition(client, database_name, table_name, values=values)
response = helpers.update_partition(
client,
database_name,
table_name,
old_values=values,
values=new_values,
columns=[{'Name': 'country', 'Type': 'string'}],
)
with assert_raises(ClientError) as exc:
helpers.get_partition(client, database_name, table_name, values)
# Old partition shouldn't exist anymore
exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException')
response = client.get_partition(DatabaseName=database_name, TableName=table_name, PartitionValues=new_values)
partition = response['Partition']
partition['TableName'].should.equal(table_name)
partition['StorageDescriptor']['Columns'].should.equal([{'Name': 'country', 'Type': 'string'}])

View File

@ -286,6 +286,16 @@ def test_create_policy_versions():
PolicyDocument='{"some":"policy"}') PolicyDocument='{"some":"policy"}')
version.get('PolicyVersion').get('Document').should.equal({'some': 'policy'}) version.get('PolicyVersion').get('Document').should.equal({'some': 'policy'})
@mock_iam
def test_get_policy():
conn = boto3.client('iam', region_name='us-east-1')
response = conn.create_policy(
PolicyName="TestGetPolicy",
PolicyDocument='{"some":"policy"}')
policy = conn.get_policy(
PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicy")
response['Policy']['Arn'].should.equal("arn:aws:iam::123456789012:policy/TestGetPolicy")
@mock_iam @mock_iam
def test_get_policy_version(): def test_get_policy_version():
@ -314,17 +324,22 @@ def test_list_policy_versions():
PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions")
conn.create_policy( conn.create_policy(
PolicyName="TestListPolicyVersions", PolicyName="TestListPolicyVersions",
PolicyDocument='{"some":"policy"}')
conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions",
PolicyDocument='{"first":"policy"}') PolicyDocument='{"first":"policy"}')
versions = conn.list_policy_versions(
PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions")
versions.get('Versions')[0].get('VersionId').should.equal('v1')
conn.create_policy_version( conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions", PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions",
PolicyDocument='{"second":"policy"}') PolicyDocument='{"second":"policy"}')
conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions",
PolicyDocument='{"third":"policy"}')
versions = conn.list_policy_versions( versions = conn.list_policy_versions(
PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions")
versions.get('Versions')[0].get('Document').should.equal({'first': 'policy'}) print(versions.get('Versions'))
versions.get('Versions')[1].get('Document').should.equal({'second': 'policy'}) versions.get('Versions')[1].get('Document').should.equal({'second': 'policy'})
versions.get('Versions')[2].get('Document').should.equal({'third': 'policy'})
@mock_iam @mock_iam
@ -332,20 +347,20 @@ def test_delete_policy_version():
conn = boto3.client('iam', region_name='us-east-1') conn = boto3.client('iam', region_name='us-east-1')
conn.create_policy( conn.create_policy(
PolicyName="TestDeletePolicyVersion", PolicyName="TestDeletePolicyVersion",
PolicyDocument='{"some":"policy"}') PolicyDocument='{"first":"policy"}')
conn.create_policy_version( conn.create_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion",
PolicyDocument='{"first":"policy"}') PolicyDocument='{"second":"policy"}')
with assert_raises(ClientError): with assert_raises(ClientError):
conn.delete_policy_version( conn.delete_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion",
VersionId='v2-nope-this-does-not-exist') VersionId='v2-nope-this-does-not-exist')
conn.delete_policy_version( conn.delete_policy_version(
PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion",
VersionId='v1') VersionId='v2')
versions = conn.list_policy_versions( versions = conn.list_policy_versions(
PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion") PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion")
len(versions.get('Versions')).should.equal(0) len(versions.get('Versions')).should.equal(1)
@mock_iam_deprecated() @mock_iam_deprecated()

View File

@ -1,5 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re import os, re
import boto3 import boto3
import boto.kms import boto.kms
@ -8,6 +8,9 @@ from boto.kms.exceptions import AlreadyExistsException, NotFoundException
import sure # noqa import sure # noqa
from moto import mock_kms, mock_kms_deprecated from moto import mock_kms, mock_kms_deprecated
from nose.tools import assert_raises from nose.tools import assert_raises
from freezegun import freeze_time
from datetime import datetime, timedelta
from dateutil.tz import tzlocal
@mock_kms_deprecated @mock_kms_deprecated
@ -617,3 +620,100 @@ def test_kms_encrypt_boto3():
response = client.decrypt(CiphertextBlob=response['CiphertextBlob']) response = client.decrypt(CiphertextBlob=response['CiphertextBlob'])
response['Plaintext'].should.equal(b'bar') response['Plaintext'].should.equal(b'bar')
@mock_kms
def test_disable_key():
client = boto3.client('kms', region_name='us-east-1')
key = client.create_key(Description='disable-key')
client.disable_key(
KeyId=key['KeyMetadata']['KeyId']
)
result = client.describe_key(KeyId=key['KeyMetadata']['KeyId'])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == 'Disabled'
@mock_kms
def test_enable_key():
client = boto3.client('kms', region_name='us-east-1')
key = client.create_key(Description='enable-key')
client.disable_key(
KeyId=key['KeyMetadata']['KeyId']
)
client.enable_key(
KeyId=key['KeyMetadata']['KeyId']
)
result = client.describe_key(KeyId=key['KeyMetadata']['KeyId'])
assert result["KeyMetadata"]["Enabled"] == True
assert result["KeyMetadata"]["KeyState"] == 'Enabled'
@mock_kms
def test_schedule_key_deletion():
client = boto3.client('kms', region_name='us-east-1')
key = client.create_key(Description='schedule-key-deletion')
if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false':
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(
KeyId=key['KeyMetadata']['KeyId']
)
assert response['KeyId'] == key['KeyMetadata']['KeyId']
assert response['DeletionDate'] == datetime(2015, 1, 31, 12, 0, tzinfo=tzlocal())
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(
KeyId=key['KeyMetadata']['KeyId']
)
assert response['KeyId'] == key['KeyMetadata']['KeyId']
result = client.describe_key(KeyId=key['KeyMetadata']['KeyId'])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == 'PendingDeletion'
assert 'DeletionDate' in result["KeyMetadata"]
@mock_kms
def test_schedule_key_deletion_custom():
client = boto3.client('kms', region_name='us-east-1')
key = client.create_key(Description='schedule-key-deletion')
if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false':
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(
KeyId=key['KeyMetadata']['KeyId'],
PendingWindowInDays=7
)
assert response['KeyId'] == key['KeyMetadata']['KeyId']
assert response['DeletionDate'] == datetime(2015, 1, 8, 12, 0, tzinfo=tzlocal())
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(
KeyId=key['KeyMetadata']['KeyId'],
PendingWindowInDays=7
)
assert response['KeyId'] == key['KeyMetadata']['KeyId']
result = client.describe_key(KeyId=key['KeyMetadata']['KeyId'])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == 'PendingDeletion'
assert 'DeletionDate' in result["KeyMetadata"]
@mock_kms
def test_cancel_key_deletion():
client = boto3.client('kms', region_name='us-east-1')
key = client.create_key(Description='cancel-key-deletion')
client.schedule_key_deletion(
KeyId=key['KeyMetadata']['KeyId']
)
response = client.cancel_key_deletion(
KeyId=key['KeyMetadata']['KeyId']
)
assert response['KeyId'] == key['KeyMetadata']['KeyId']
result = client.describe_key(KeyId=key['KeyMetadata']['KeyId'])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == 'Disabled'
assert 'DeletionDate' not in result["KeyMetadata"]

View File

@ -1,5 +1,6 @@
import boto3 import boto3
import sure # noqa import sure # noqa
import six
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto import mock_logs, settings from moto import mock_logs, settings
@ -79,7 +80,7 @@ def test_put_logs():
{'timestamp': 0, 'message': 'hello'}, {'timestamp': 0, 'message': 'hello'},
{'timestamp': 0, 'message': 'world'} {'timestamp': 0, 'message': 'world'}
] ]
conn.put_log_events( putRes = conn.put_log_events(
logGroupName=log_group_name, logGroupName=log_group_name,
logStreamName=log_stream_name, logStreamName=log_stream_name,
logEvents=messages logEvents=messages
@ -89,6 +90,9 @@ def test_put_logs():
logStreamName=log_stream_name logStreamName=log_stream_name
) )
events = res['events'] events = res['events']
nextSequenceToken = putRes['nextSequenceToken']
assert isinstance(nextSequenceToken, six.string_types) == True
assert len(nextSequenceToken) == 56
events.should.have.length_of(2) events.should.have.length_of(2)
@ -117,4 +121,8 @@ def test_filter_logs_interleaved():
interleaved=True, interleaved=True,
) )
events = res['events'] events = res['events']
events.should.have.length_of(2) for original_message, resulting_event in zip(messages, events):
resulting_event['eventId'].should.equal(str(resulting_event['eventId']))
resulting_event['timestamp'].should.equal(original_message['timestamp'])
resulting_event['message'].should.equal(original_message['message'])

View File

View File

@ -0,0 +1,136 @@
from __future__ import unicode_literals
import six
import sure # noqa
import datetime
from moto.organizations import utils
EMAIL_REGEX = "^.+@[a-zA-Z0-9-.]+.[a-zA-Z]{2,3}|[0-9]{1,3}$"
ORG_ID_REGEX = r'o-[a-z0-9]{%s}' % utils.ORG_ID_SIZE
ROOT_ID_REGEX = r'r-[a-z0-9]{%s}' % utils.ROOT_ID_SIZE
OU_ID_REGEX = r'ou-[a-z0-9]{%s}-[a-z0-9]{%s}' % (utils.ROOT_ID_SIZE, utils.OU_ID_SUFFIX_SIZE)
ACCOUNT_ID_REGEX = r'[0-9]{%s}' % utils.ACCOUNT_ID_SIZE
CREATE_ACCOUNT_STATUS_ID_REGEX = r'car-[a-z0-9]{%s}' % utils.CREATE_ACCOUNT_STATUS_ID_SIZE
def test_make_random_org_id():
org_id = utils.make_random_org_id()
org_id.should.match(ORG_ID_REGEX)
def test_make_random_root_id():
root_id = utils.make_random_root_id()
root_id.should.match(ROOT_ID_REGEX)
def test_make_random_ou_id():
root_id = utils.make_random_root_id()
ou_id = utils.make_random_ou_id(root_id)
ou_id.should.match(OU_ID_REGEX)
def test_make_random_account_id():
account_id = utils.make_random_account_id()
account_id.should.match(ACCOUNT_ID_REGEX)
def test_make_random_create_account_status_id():
create_account_status_id = utils.make_random_create_account_status_id()
create_account_status_id.should.match(CREATE_ACCOUNT_STATUS_ID_REGEX)
def validate_organization(response):
org = response['Organization']
sorted(org.keys()).should.equal([
'Arn',
'AvailablePolicyTypes',
'FeatureSet',
'Id',
'MasterAccountArn',
'MasterAccountEmail',
'MasterAccountId',
])
org['Id'].should.match(ORG_ID_REGEX)
org['MasterAccountId'].should.equal(utils.MASTER_ACCOUNT_ID)
org['MasterAccountArn'].should.equal(utils.MASTER_ACCOUNT_ARN_FORMAT.format(
org['MasterAccountId'],
org['Id'],
))
org['Arn'].should.equal(utils.ORGANIZATION_ARN_FORMAT.format(
org['MasterAccountId'],
org['Id'],
))
org['MasterAccountEmail'].should.equal(utils.MASTER_ACCOUNT_EMAIL)
org['FeatureSet'].should.be.within(['ALL', 'CONSOLIDATED_BILLING'])
org['AvailablePolicyTypes'].should.equal([{
'Type': 'SERVICE_CONTROL_POLICY',
'Status': 'ENABLED'
}])
def validate_roots(org, response):
response.should.have.key('Roots').should.be.a(list)
response['Roots'].should_not.be.empty
root = response['Roots'][0]
root.should.have.key('Id').should.match(ROOT_ID_REGEX)
root.should.have.key('Arn').should.equal(utils.ROOT_ARN_FORMAT.format(
org['MasterAccountId'],
org['Id'],
root['Id'],
))
root.should.have.key('Name').should.be.a(six.string_types)
root.should.have.key('PolicyTypes').should.be.a(list)
root['PolicyTypes'][0].should.have.key('Type').should.equal('SERVICE_CONTROL_POLICY')
root['PolicyTypes'][0].should.have.key('Status').should.equal('ENABLED')
def validate_organizational_unit(org, response):
response.should.have.key('OrganizationalUnit').should.be.a(dict)
ou = response['OrganizationalUnit']
ou.should.have.key('Id').should.match(OU_ID_REGEX)
ou.should.have.key('Arn').should.equal(utils.OU_ARN_FORMAT.format(
org['MasterAccountId'],
org['Id'],
ou['Id'],
))
ou.should.have.key('Name').should.be.a(six.string_types)
def validate_account(org, account):
sorted(account.keys()).should.equal([
'Arn',
'Email',
'Id',
'JoinedMethod',
'JoinedTimestamp',
'Name',
'Status',
])
account['Id'].should.match(ACCOUNT_ID_REGEX)
account['Arn'].should.equal(utils.ACCOUNT_ARN_FORMAT.format(
org['MasterAccountId'],
org['Id'],
account['Id'],
))
account['Email'].should.match(EMAIL_REGEX)
account['JoinedMethod'].should.be.within(['INVITED', 'CREATED'])
account['Status'].should.be.within(['ACTIVE', 'SUSPENDED'])
account['Name'].should.be.a(six.string_types)
account['JoinedTimestamp'].should.be.a(datetime.datetime)
def validate_create_account_status(create_status):
sorted(create_status.keys()).should.equal([
'AccountId',
'AccountName',
'CompletedTimestamp',
'Id',
'RequestedTimestamp',
'State',
])
create_status['Id'].should.match(CREATE_ACCOUNT_STATUS_ID_REGEX)
create_status['AccountId'].should.match(ACCOUNT_ID_REGEX)
create_status['AccountName'].should.be.a(six.string_types)
create_status['State'].should.equal('SUCCEEDED')
create_status['RequestedTimestamp'].should.be.a(datetime.datetime)
create_status['CompletedTimestamp'].should.be.a(datetime.datetime)

View File

@ -0,0 +1,322 @@
from __future__ import unicode_literals
import boto3
import sure # noqa
from botocore.exceptions import ClientError
from nose.tools import assert_raises
from moto import mock_organizations
from moto.organizations import utils
from .organizations_test_utils import (
validate_organization,
validate_roots,
validate_organizational_unit,
validate_account,
validate_create_account_status,
)
@mock_organizations
def test_create_organization():
client = boto3.client('organizations', region_name='us-east-1')
response = client.create_organization(FeatureSet='ALL')
validate_organization(response)
response['Organization']['FeatureSet'].should.equal('ALL')
@mock_organizations
def test_describe_organization():
client = boto3.client('organizations', region_name='us-east-1')
client.create_organization(FeatureSet='ALL')
response = client.describe_organization()
validate_organization(response)
@mock_organizations
def test_describe_organization_exception():
client = boto3.client('organizations', region_name='us-east-1')
with assert_raises(ClientError) as e:
response = client.describe_organization()
ex = e.exception
ex.operation_name.should.equal('DescribeOrganization')
ex.response['Error']['Code'].should.equal('400')
ex.response['Error']['Message'].should.contain('AWSOrganizationsNotInUseException')
# Organizational Units
@mock_organizations
def test_list_roots():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
response = client.list_roots()
validate_roots(org, response)
@mock_organizations
def test_create_organizational_unit():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
root_id = client.list_roots()['Roots'][0]['Id']
ou_name = 'ou01'
response = client.create_organizational_unit(
ParentId=root_id,
Name=ou_name,
)
validate_organizational_unit(org, response)
response['OrganizationalUnit']['Name'].should.equal(ou_name)
@mock_organizations
def test_describe_organizational_unit():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
root_id = client.list_roots()['Roots'][0]['Id']
ou_id = client.create_organizational_unit(
ParentId=root_id,
Name='ou01',
)['OrganizationalUnit']['Id']
response = client.describe_organizational_unit(OrganizationalUnitId=ou_id)
validate_organizational_unit(org, response)
@mock_organizations
def test_describe_organizational_unit_exception():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
with assert_raises(ClientError) as e:
response = client.describe_organizational_unit(
OrganizationalUnitId=utils.make_random_root_id()
)
ex = e.exception
ex.operation_name.should.equal('DescribeOrganizationalUnit')
ex.response['Error']['Code'].should.equal('400')
ex.response['Error']['Message'].should.contain('OrganizationalUnitNotFoundException')
@mock_organizations
def test_list_organizational_units_for_parent():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
root_id = client.list_roots()['Roots'][0]['Id']
client.create_organizational_unit(ParentId=root_id, Name='ou01')
client.create_organizational_unit(ParentId=root_id, Name='ou02')
client.create_organizational_unit(ParentId=root_id, Name='ou03')
response = client.list_organizational_units_for_parent(ParentId=root_id)
response.should.have.key('OrganizationalUnits').should.be.a(list)
for ou in response['OrganizationalUnits']:
validate_organizational_unit(org, dict(OrganizationalUnit=ou))
@mock_organizations
def test_list_organizational_units_for_parent_exception():
client = boto3.client('organizations', region_name='us-east-1')
with assert_raises(ClientError) as e:
response = client.list_organizational_units_for_parent(
ParentId=utils.make_random_root_id()
)
ex = e.exception
ex.operation_name.should.equal('ListOrganizationalUnitsForParent')
ex.response['Error']['Code'].should.equal('400')
ex.response['Error']['Message'].should.contain('ParentNotFoundException')
# Accounts
mockname = 'mock-account'
mockdomain = 'moto-example.org'
mockemail = '@'.join([mockname, mockdomain])
@mock_organizations
def test_create_account():
client = boto3.client('organizations', region_name='us-east-1')
client.create_organization(FeatureSet='ALL')
create_status = client.create_account(
AccountName=mockname, Email=mockemail
)['CreateAccountStatus']
validate_create_account_status(create_status)
create_status['AccountName'].should.equal(mockname)
@mock_organizations
def test_describe_account():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
account_id = client.create_account(
AccountName=mockname, Email=mockemail
)['CreateAccountStatus']['AccountId']
response = client.describe_account(AccountId=account_id)
validate_account(org, response['Account'])
response['Account']['Name'].should.equal(mockname)
response['Account']['Email'].should.equal(mockemail)
@mock_organizations
def test_describe_account_exception():
client = boto3.client('organizations', region_name='us-east-1')
with assert_raises(ClientError) as e:
response = client.describe_account(AccountId=utils.make_random_account_id())
ex = e.exception
ex.operation_name.should.equal('DescribeAccount')
ex.response['Error']['Code'].should.equal('400')
ex.response['Error']['Message'].should.contain('AccountNotFoundException')
@mock_organizations
def test_list_accounts():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
for i in range(5):
name = mockname + str(i)
email = name + '@' + mockdomain
client.create_account(AccountName=name, Email=email)
response = client.list_accounts()
response.should.have.key('Accounts')
accounts = response['Accounts']
len(accounts).should.equal(5)
for account in accounts:
validate_account(org, account)
accounts[3]['Name'].should.equal(mockname + '3')
accounts[2]['Email'].should.equal(mockname + '2' + '@' + mockdomain)
@mock_organizations
def test_list_accounts_for_parent():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
root_id = client.list_roots()['Roots'][0]['Id']
account_id = client.create_account(
AccountName=mockname,
Email=mockemail,
)['CreateAccountStatus']['AccountId']
response = client.list_accounts_for_parent(ParentId=root_id)
account_id.should.be.within([account['Id'] for account in response['Accounts']])
@mock_organizations
def test_move_account():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
root_id = client.list_roots()['Roots'][0]['Id']
account_id = client.create_account(
AccountName=mockname, Email=mockemail
)['CreateAccountStatus']['AccountId']
ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01')
ou01_id = ou01['OrganizationalUnit']['Id']
client.move_account(
AccountId=account_id,
SourceParentId=root_id,
DestinationParentId=ou01_id,
)
response = client.list_accounts_for_parent(ParentId=ou01_id)
account_id.should.be.within([account['Id'] for account in response['Accounts']])
@mock_organizations
def test_list_parents_for_ou():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
root_id = client.list_roots()['Roots'][0]['Id']
ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01')
ou01_id = ou01['OrganizationalUnit']['Id']
response01 = client.list_parents(ChildId=ou01_id)
response01.should.have.key('Parents').should.be.a(list)
response01['Parents'][0].should.have.key('Id').should.equal(root_id)
response01['Parents'][0].should.have.key('Type').should.equal('ROOT')
ou02 = client.create_organizational_unit(ParentId=ou01_id, Name='ou02')
ou02_id = ou02['OrganizationalUnit']['Id']
response02 = client.list_parents(ChildId=ou02_id)
response02.should.have.key('Parents').should.be.a(list)
response02['Parents'][0].should.have.key('Id').should.equal(ou01_id)
response02['Parents'][0].should.have.key('Type').should.equal('ORGANIZATIONAL_UNIT')
@mock_organizations
def test_list_parents_for_accounts():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
root_id = client.list_roots()['Roots'][0]['Id']
ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01')
ou01_id = ou01['OrganizationalUnit']['Id']
account01_id = client.create_account(
AccountName='account01',
Email='account01@moto-example.org'
)['CreateAccountStatus']['AccountId']
account02_id = client.create_account(
AccountName='account02',
Email='account02@moto-example.org'
)['CreateAccountStatus']['AccountId']
client.move_account(
AccountId=account02_id,
SourceParentId=root_id,
DestinationParentId=ou01_id,
)
response01 = client.list_parents(ChildId=account01_id)
response01.should.have.key('Parents').should.be.a(list)
response01['Parents'][0].should.have.key('Id').should.equal(root_id)
response01['Parents'][0].should.have.key('Type').should.equal('ROOT')
response02 = client.list_parents(ChildId=account02_id)
response02.should.have.key('Parents').should.be.a(list)
response02['Parents'][0].should.have.key('Id').should.equal(ou01_id)
response02['Parents'][0].should.have.key('Type').should.equal('ORGANIZATIONAL_UNIT')
@mock_organizations
def test_list_children():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
root_id = client.list_roots()['Roots'][0]['Id']
ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01')
ou01_id = ou01['OrganizationalUnit']['Id']
ou02 = client.create_organizational_unit(ParentId=ou01_id, Name='ou02')
ou02_id = ou02['OrganizationalUnit']['Id']
account01_id = client.create_account(
AccountName='account01',
Email='account01@moto-example.org'
)['CreateAccountStatus']['AccountId']
account02_id = client.create_account(
AccountName='account02',
Email='account02@moto-example.org'
)['CreateAccountStatus']['AccountId']
client.move_account(
AccountId=account02_id,
SourceParentId=root_id,
DestinationParentId=ou01_id,
)
response01 = client.list_children(ParentId=root_id, ChildType='ACCOUNT')
response02 = client.list_children(ParentId=root_id, ChildType='ORGANIZATIONAL_UNIT')
response03 = client.list_children(ParentId=ou01_id, ChildType='ACCOUNT')
response04 = client.list_children(ParentId=ou01_id, ChildType='ORGANIZATIONAL_UNIT')
response01['Children'][0]['Id'].should.equal(account01_id)
response01['Children'][0]['Type'].should.equal('ACCOUNT')
response02['Children'][0]['Id'].should.equal(ou01_id)
response02['Children'][0]['Type'].should.equal('ORGANIZATIONAL_UNIT')
response03['Children'][0]['Id'].should.equal(account02_id)
response03['Children'][0]['Type'].should.equal('ACCOUNT')
response04['Children'][0]['Id'].should.equal(ou02_id)
response04['Children'][0]['Type'].should.equal('ORGANIZATIONAL_UNIT')
@mock_organizations
def test_list_children_exception():
client = boto3.client('organizations', region_name='us-east-1')
org = client.create_organization(FeatureSet='ALL')['Organization']
root_id = client.list_roots()['Roots'][0]['Id']
with assert_raises(ClientError) as e:
response = client.list_children(
ParentId=utils.make_random_root_id(),
ChildType='ACCOUNT'
)
ex = e.exception
ex.operation_name.should.equal('ListChildren')
ex.response['Error']['Code'].should.equal('400')
ex.response['Error']['Message'].should.contain('ParentNotFoundException')
with assert_raises(ClientError) as e:
response = client.list_children(
ParentId=root_id,
ChildType='BLEE'
)
ex = e.exception
ex.operation_name.should.equal('ListChildren')
ex.response['Error']['Code'].should.equal('400')
ex.response['Error']['Message'].should.contain('InvalidInputException')

View File

@ -33,6 +33,7 @@ def test_create_database():
db_instance['DBInstanceIdentifier'].should.equal("db-master-1") db_instance['DBInstanceIdentifier'].should.equal("db-master-1")
db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False) db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False)
db_instance['DbiResourceId'].should.contain("db-") db_instance['DbiResourceId'].should.contain("db-")
db_instance['CopyTagsToSnapshot'].should.equal(False)
@mock_rds2 @mock_rds2
@ -339,6 +340,49 @@ def test_create_db_snapshots():
snapshot.get('Engine').should.equal('postgres') snapshot.get('Engine').should.equal('postgres')
snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1') snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1')
snapshot.get('DBSnapshotIdentifier').should.equal('g-1') snapshot.get('DBSnapshotIdentifier').should.equal('g-1')
result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshotArn'])
result['TagList'].should.equal([])
@mock_rds2
def test_create_db_snapshots_copy_tags():
conn = boto3.client('rds', region_name='us-west-2')
conn.create_db_snapshot.when.called_with(
DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='snapshot-1').should.throw(ClientError)
conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
MasterUsername='root',
MasterUserPassword='hunter2',
Port=1234,
DBSecurityGroups=["my_sg"],
CopyTagsToSnapshot=True,
Tags=[
{
'Key': 'foo',
'Value': 'bar',
},
{
'Key': 'foo1',
'Value': 'bar1',
},
])
snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='g-1').get('DBSnapshot')
snapshot.get('Engine').should.equal('postgres')
snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1')
snapshot.get('DBSnapshotIdentifier').should.equal('g-1')
result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshotArn'])
result['TagList'].should.equal([{'Value': 'bar',
'Key': 'foo'},
{'Value': 'bar1',
'Key': 'foo1'}])
@mock_rds2 @mock_rds2
@ -656,6 +700,117 @@ def test_remove_tags_db():
len(result['TagList']).should.equal(1) len(result['TagList']).should.equal(1)
@mock_rds2
def test_list_tags_snapshot():
conn = boto3.client('rds', region_name='us-west-2')
result = conn.list_tags_for_resource(
ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:foo')
result['TagList'].should.equal([])
conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
MasterUsername='root',
MasterUserPassword='hunter2',
Port=1234,
DBSecurityGroups=["my_sg"])
snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='snapshot-with-tags',
Tags=[
{
'Key': 'foo',
'Value': 'bar',
},
{
'Key': 'foo1',
'Value': 'bar1',
},
])
result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshot']['DBSnapshotArn'])
result['TagList'].should.equal([{'Value': 'bar',
'Key': 'foo'},
{'Value': 'bar1',
'Key': 'foo1'}])
@mock_rds2
def test_add_tags_snapshot():
conn = boto3.client('rds', region_name='us-west-2')
conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
MasterUsername='root',
MasterUserPassword='hunter2',
Port=1234,
DBSecurityGroups=["my_sg"])
snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='snapshot-without-tags',
Tags=[
{
'Key': 'foo',
'Value': 'bar',
},
{
'Key': 'foo1',
'Value': 'bar1',
},
])
result = conn.list_tags_for_resource(
ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags')
list(result['TagList']).should.have.length_of(2)
conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags',
Tags=[
{
'Key': 'foo',
'Value': 'fish',
},
{
'Key': 'foo2',
'Value': 'bar2',
},
])
result = conn.list_tags_for_resource(
ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags')
list(result['TagList']).should.have.length_of(3)
@mock_rds2
def test_remove_tags_snapshot():
conn = boto3.client('rds', region_name='us-west-2')
conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
MasterUsername='root',
MasterUserPassword='hunter2',
Port=1234,
DBSecurityGroups=["my_sg"])
snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='snapshot-with-tags',
Tags=[
{
'Key': 'foo',
'Value': 'bar',
},
{
'Key': 'foo1',
'Value': 'bar1',
},
])
result = conn.list_tags_for_resource(
ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags')
list(result['TagList']).should.have.length_of(2)
conn.remove_tags_from_resource(
ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags', TagKeys=['foo'])
result = conn.list_tags_for_resource(
ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags')
len(result['TagList']).should.equal(1)
@mock_rds2 @mock_rds2
def test_add_tags_option_group(): def test_add_tags_option_group():
conn = boto3.client('rds', region_name='us-west-2') conn = boto3.client('rds', region_name='us-west-2')

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import datetime
import boto import boto
import boto3 import boto3
from boto.redshift.exceptions import ( from boto.redshift.exceptions import (
@ -32,6 +34,8 @@ def test_create_cluster_boto3():
MasterUserPassword='password', MasterUserPassword='password',
) )
response['Cluster']['NodeType'].should.equal('ds2.xlarge') response['Cluster']['NodeType'].should.equal('ds2.xlarge')
create_time = response['Cluster']['ClusterCreateTime']
create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo))
@mock_redshift @mock_redshift

View File

@ -2471,6 +2471,72 @@ def test_boto3_delete_markers():
oldest['Key'].should.equal('key-with-versions-and-unicode-ó') oldest['Key'].should.equal('key-with-versions-and-unicode-ó')
@mock_s3
def test_boto3_multiple_delete_markers():
s3 = boto3.client('s3', region_name='us-east-1')
bucket_name = 'mybucket'
key = u'key-with-versions-and-unicode-ó'
s3.create_bucket(Bucket=bucket_name)
s3.put_bucket_versioning(
Bucket=bucket_name,
VersioningConfiguration={
'Status': 'Enabled'
}
)
items = (six.b('v1'), six.b('v2'))
for body in items:
s3.put_object(
Bucket=bucket_name,
Key=key,
Body=body
)
# Delete the object twice to add multiple delete markers
s3.delete_object(Bucket=bucket_name, Key=key)
s3.delete_object(Bucket=bucket_name, Key=key)
response = s3.list_object_versions(Bucket=bucket_name)
response['DeleteMarkers'].should.have.length_of(2)
with assert_raises(ClientError) as e:
s3.get_object(
Bucket=bucket_name,
Key=key
)
e.response['Error']['Code'].should.equal('404')
# Remove both delete markers to restore the object
s3.delete_object(
Bucket=bucket_name,
Key=key,
VersionId='2'
)
s3.delete_object(
Bucket=bucket_name,
Key=key,
VersionId='3'
)
response = s3.get_object(
Bucket=bucket_name,
Key=key
)
response['Body'].read().should.equal(items[-1])
response = s3.list_object_versions(Bucket=bucket_name)
response['Versions'].should.have.length_of(2)
# We've asserted there is only 2 records so one is newest, one is oldest
latest = list(filter(lambda item: item['IsLatest'], response['Versions']))[0]
oldest = list(filter(lambda item: not item['IsLatest'], response['Versions']))[0]
# Double check ordering of version ID's
latest['VersionId'].should.equal('1')
oldest['VersionId'].should.equal('0')
# Double check the name is still unicode
latest['Key'].should.equal('key-with-versions-and-unicode-ó')
oldest['Key'].should.equal('key-with-versions-and-unicode-ó')
@mock_s3 @mock_s3
def test_get_stream_gzipped(): def test_get_stream_gzipped():
payload = b"this is some stuff here" payload = b"this is some stuff here"

View File

@ -191,6 +191,127 @@ def test_lifecycle_with_eodm():
assert err.exception.response["Error"]["Code"] == "MalformedXML" assert err.exception.response["Error"]["Code"] == "MalformedXML"
@mock_s3
def test_lifecycle_with_nve():
client = boto3.client("s3")
client.create_bucket(Bucket="bucket")
lfc = {
"Rules": [
{
"NoncurrentVersionExpiration": {
"NoncurrentDays": 30
},
"ID": "wholebucket",
"Filter": {
"Prefix": ""
},
"Status": "Enabled"
}
]
}
client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc)
result = client.get_bucket_lifecycle_configuration(Bucket="bucket")
assert len(result["Rules"]) == 1
assert result["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] == 30
# Change NoncurrentDays:
lfc["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] = 10
client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc)
result = client.get_bucket_lifecycle_configuration(Bucket="bucket")
assert len(result["Rules"]) == 1
assert result["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] == 10
# TODO: Add test for failures due to missing children
@mock_s3
def test_lifecycle_with_nvt():
client = boto3.client("s3")
client.create_bucket(Bucket="bucket")
lfc = {
"Rules": [
{
"NoncurrentVersionTransitions": [{
"NoncurrentDays": 30,
"StorageClass": "ONEZONE_IA"
}],
"ID": "wholebucket",
"Filter": {
"Prefix": ""
},
"Status": "Enabled"
}
]
}
client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc)
result = client.get_bucket_lifecycle_configuration(Bucket="bucket")
assert len(result["Rules"]) == 1
assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] == 30
assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] == "ONEZONE_IA"
# Change NoncurrentDays:
lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] = 10
client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc)
result = client.get_bucket_lifecycle_configuration(Bucket="bucket")
assert len(result["Rules"]) == 1
assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] == 10
# Change StorageClass:
lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] = "GLACIER"
client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc)
result = client.get_bucket_lifecycle_configuration(Bucket="bucket")
assert len(result["Rules"]) == 1
assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] == "GLACIER"
# With failures for missing children:
del lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"]
with assert_raises(ClientError) as err:
client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc)
assert err.exception.response["Error"]["Code"] == "MalformedXML"
lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] = 30
del lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"]
with assert_raises(ClientError) as err:
client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc)
assert err.exception.response["Error"]["Code"] == "MalformedXML"
@mock_s3
def test_lifecycle_with_aimu():
client = boto3.client("s3")
client.create_bucket(Bucket="bucket")
lfc = {
"Rules": [
{
"AbortIncompleteMultipartUpload": {
"DaysAfterInitiation": 7
},
"ID": "wholebucket",
"Filter": {
"Prefix": ""
},
"Status": "Enabled"
}
]
}
client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc)
result = client.get_bucket_lifecycle_configuration(Bucket="bucket")
assert len(result["Rules"]) == 1
assert result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] == 7
# Change DaysAfterInitiation:
lfc["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] = 30
client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc)
result = client.get_bucket_lifecycle_configuration(Bucket="bucket")
assert len(result["Rules"]) == 1
assert result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] == 30
# TODO: Add test for failures due to missing children
@mock_s3_deprecated @mock_s3_deprecated
def test_lifecycle_with_glacier_transition(): def test_lifecycle_with_glacier_transition():
conn = boto.s3.connect_to_region("us-west-1") conn = boto.s3.connect_to_region("us-west-1")

View File

@ -26,13 +26,13 @@ def test_get_secret_that_does_not_exist():
result = conn.get_secret_value(SecretId='i-dont-exist') result = conn.get_secret_value(SecretId='i-dont-exist')
@mock_secretsmanager @mock_secretsmanager
def test_get_secret_with_mismatched_id(): def test_get_secret_that_does_not_match():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
create_secret = conn.create_secret(Name='java-util-test-password', create_secret = conn.create_secret(Name='java-util-test-password',
SecretString="foosecret") SecretString="foosecret")
with assert_raises(ClientError): with assert_raises(ClientError):
result = conn.get_secret_value(SecretId='i-dont-exist') result = conn.get_secret_value(SecretId='i-dont-match')
@mock_secretsmanager @mock_secretsmanager
def test_create_secret(): def test_create_secret():
@ -179,3 +179,108 @@ def test_describe_secret_that_does_not_match():
with assert_raises(ClientError): with assert_raises(ClientError):
result = conn.get_secret_value(SecretId='i-dont-match') result = conn.get_secret_value(SecretId='i-dont-match')
@mock_secretsmanager
def test_rotate_secret():
secret_name = 'test-secret'
conn = boto3.client('secretsmanager', region_name='us-west-2')
conn.create_secret(Name=secret_name,
SecretString='foosecret')
rotated_secret = conn.rotate_secret(SecretId=secret_name)
assert rotated_secret
assert rotated_secret['ARN'] == (
'arn:aws:secretsmanager:us-west-2:1234567890:secret:test-secret-rIjad'
)
assert rotated_secret['Name'] == secret_name
assert rotated_secret['VersionId'] != ''
@mock_secretsmanager
def test_rotate_secret_enable_rotation():
secret_name = 'test-secret'
conn = boto3.client('secretsmanager', region_name='us-west-2')
conn.create_secret(Name=secret_name,
SecretString='foosecret')
initial_description = conn.describe_secret(SecretId=secret_name)
assert initial_description
assert initial_description['RotationEnabled'] is False
assert initial_description['RotationRules']['AutomaticallyAfterDays'] == 0
conn.rotate_secret(SecretId=secret_name,
RotationRules={'AutomaticallyAfterDays': 42})
rotated_description = conn.describe_secret(SecretId=secret_name)
assert rotated_description
assert rotated_description['RotationEnabled'] is True
assert rotated_description['RotationRules']['AutomaticallyAfterDays'] == 42
@mock_secretsmanager
def test_rotate_secret_that_does_not_exist():
conn = boto3.client('secretsmanager', 'us-west-2')
with assert_raises(ClientError):
result = conn.rotate_secret(SecretId='i-dont-exist')
@mock_secretsmanager
def test_rotate_secret_that_does_not_match():
conn = boto3.client('secretsmanager', region_name='us-west-2')
conn.create_secret(Name='test-secret',
SecretString='foosecret')
with assert_raises(ClientError):
result = conn.rotate_secret(SecretId='i-dont-match')
@mock_secretsmanager
def test_rotate_secret_client_request_token_too_short():
# Test is intentionally empty. Boto3 catches too short ClientRequestToken
# and raises ParamValidationError before Moto can see it.
# test_server actually handles this error.
assert True
@mock_secretsmanager
def test_rotate_secret_client_request_token_too_long():
secret_name = 'test-secret'
conn = boto3.client('secretsmanager', region_name='us-west-2')
conn.create_secret(Name=secret_name,
SecretString='foosecret')
client_request_token = (
'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-'
'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C'
)
with assert_raises(ClientError):
result = conn.rotate_secret(SecretId=secret_name,
ClientRequestToken=client_request_token)
@mock_secretsmanager
def test_rotate_secret_rotation_lambda_arn_too_long():
secret_name = 'test-secret'
conn = boto3.client('secretsmanager', region_name='us-west-2')
conn.create_secret(Name=secret_name,
SecretString='foosecret')
rotation_lambda_arn = '85B7-446A-B7E4' * 147 # == 2058 characters
with assert_raises(ClientError):
result = conn.rotate_secret(SecretId=secret_name,
RotationLambdaARN=rotation_lambda_arn)
@mock_secretsmanager
def test_rotate_secret_rotation_period_zero():
# Test is intentionally empty. Boto3 catches zero day rotation period
# and raises ParamValidationError before Moto can see it.
# test_server actually handles this error.
assert True
@mock_secretsmanager
def test_rotate_secret_rotation_period_too_long():
secret_name = 'test-secret'
conn = boto3.client('secretsmanager', region_name='us-west-2')
conn.create_secret(Name=secret_name,
SecretString='foosecret')
rotation_rules = {'AutomaticallyAfterDays': 1001}
with assert_raises(ClientError):
result = conn.rotate_secret(SecretId=secret_name,
RotationRules=rotation_rules)

View File

@ -49,6 +49,27 @@ def test_get_secret_that_does_not_exist():
assert json_data['message'] == "Secrets Manager can't find the specified secret" assert json_data['message'] == "Secrets Manager can't find the specified secret"
assert json_data['__type'] == 'ResourceNotFoundException' assert json_data['__type'] == 'ResourceNotFoundException'
@mock_secretsmanager
def test_get_secret_that_does_not_match():
backend = server.create_backend_app("secretsmanager")
test_client = backend.test_client()
create_secret = test_client.post('/',
data={"Name": "test-secret",
"SecretString": "foo-secret"},
headers={
"X-Amz-Target": "secretsmanager.CreateSecret"},
)
get_secret = test_client.post('/',
data={"SecretId": "i-dont-match",
"VersionStage": "AWSCURRENT"},
headers={
"X-Amz-Target": "secretsmanager.GetSecretValue"},
)
json_data = json.loads(get_secret.data.decode("utf-8"))
assert json_data['message'] == "Secrets Manager can't find the specified secret"
assert json_data['__type'] == 'ResourceNotFoundException'
@mock_secretsmanager @mock_secretsmanager
def test_create_secret(): def test_create_secret():
@ -133,3 +154,268 @@ def test_describe_secret_that_does_not_match():
json_data = json.loads(describe_secret.data.decode("utf-8")) json_data = json.loads(describe_secret.data.decode("utf-8"))
assert json_data['message'] == "Secrets Manager can't find the specified secret" assert json_data['message'] == "Secrets Manager can't find the specified secret"
assert json_data['__type'] == 'ResourceNotFoundException' assert json_data['__type'] == 'ResourceNotFoundException'
@mock_secretsmanager
def test_rotate_secret():
backend = server.create_backend_app('secretsmanager')
test_client = backend.test_client()
create_secret = test_client.post('/',
data={"Name": "test-secret",
"SecretString": "foosecret"},
headers={
"X-Amz-Target": "secretsmanager.CreateSecret"
},
)
client_request_token = "EXAMPLE2-90ab-cdef-fedc-ba987SECRET2"
rotate_secret = test_client.post('/',
data={"SecretId": "test-secret",
"ClientRequestToken": client_request_token},
headers={
"X-Amz-Target": "secretsmanager.RotateSecret"
},
)
json_data = json.loads(rotate_secret.data.decode("utf-8"))
assert json_data # Returned dict is not empty
assert json_data['ARN'] == (
'arn:aws:secretsmanager:us-east-1:1234567890:secret:test-secret-rIjad'
)
assert json_data['Name'] == 'test-secret'
assert json_data['VersionId'] == client_request_token
# @mock_secretsmanager
# def test_rotate_secret_enable_rotation():
# backend = server.create_backend_app('secretsmanager')
# test_client = backend.test_client()
# create_secret = test_client.post(
# '/',
# data={
# "Name": "test-secret",
# "SecretString": "foosecret"
# },
# headers={
# "X-Amz-Target": "secretsmanager.CreateSecret"
# },
# )
# initial_description = test_client.post(
# '/',
# data={
# "SecretId": "test-secret"
# },
# headers={
# "X-Amz-Target": "secretsmanager.DescribeSecret"
# },
# )
# json_data = json.loads(initial_description.data.decode("utf-8"))
# assert json_data # Returned dict is not empty
# assert json_data['RotationEnabled'] is False
# assert json_data['RotationRules']['AutomaticallyAfterDays'] == 0
# rotate_secret = test_client.post(
# '/',
# data={
# "SecretId": "test-secret",
# "RotationRules": {"AutomaticallyAfterDays": 42}
# },
# headers={
# "X-Amz-Target": "secretsmanager.RotateSecret"
# },
# )
# rotated_description = test_client.post(
# '/',
# data={
# "SecretId": "test-secret"
# },
# headers={
# "X-Amz-Target": "secretsmanager.DescribeSecret"
# },
# )
# json_data = json.loads(rotated_description.data.decode("utf-8"))
# assert json_data # Returned dict is not empty
# assert json_data['RotationEnabled'] is True
# assert json_data['RotationRules']['AutomaticallyAfterDays'] == 42
@mock_secretsmanager
def test_rotate_secret_that_does_not_exist():
backend = server.create_backend_app('secretsmanager')
test_client = backend.test_client()
rotate_secret = test_client.post('/',
data={"SecretId": "i-dont-exist"},
headers={
"X-Amz-Target": "secretsmanager.RotateSecret"
},
)
json_data = json.loads(rotate_secret.data.decode("utf-8"))
assert json_data['message'] == "Secrets Manager can't find the specified secret"
assert json_data['__type'] == 'ResourceNotFoundException'
@mock_secretsmanager
def test_rotate_secret_that_does_not_match():
backend = server.create_backend_app('secretsmanager')
test_client = backend.test_client()
create_secret = test_client.post('/',
data={"Name": "test-secret",
"SecretString": "foosecret"},
headers={
"X-Amz-Target": "secretsmanager.CreateSecret"
},
)
rotate_secret = test_client.post('/',
data={"SecretId": "i-dont-match"},
headers={
"X-Amz-Target": "secretsmanager.RotateSecret"
},
)
json_data = json.loads(rotate_secret.data.decode("utf-8"))
assert json_data['message'] == "Secrets Manager can't find the specified secret"
assert json_data['__type'] == 'ResourceNotFoundException'
@mock_secretsmanager
def test_rotate_secret_client_request_token_too_short():
backend = server.create_backend_app('secretsmanager')
test_client = backend.test_client()
create_secret = test_client.post('/',
data={"Name": "test-secret",
"SecretString": "foosecret"},
headers={
"X-Amz-Target": "secretsmanager.CreateSecret"
},
)
client_request_token = "ED9F8B6C-85B7-B7E4-38F2A3BEB13C"
rotate_secret = test_client.post('/',
data={"SecretId": "test-secret",
"ClientRequestToken": client_request_token},
headers={
"X-Amz-Target": "secretsmanager.RotateSecret"
},
)
json_data = json.loads(rotate_secret.data.decode("utf-8"))
assert json_data['message'] == "ClientRequestToken must be 32-64 characters long."
assert json_data['__type'] == 'InvalidParameterException'
@mock_secretsmanager
def test_rotate_secret_client_request_token_too_long():
backend = server.create_backend_app('secretsmanager')
test_client = backend.test_client()
create_secret = test_client.post('/',
data={"Name": "test-secret",
"SecretString": "foosecret"},
headers={
"X-Amz-Target": "secretsmanager.CreateSecret"
},
)
client_request_token = (
'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-'
'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C'
)
rotate_secret = test_client.post('/',
data={"SecretId": "test-secret",
"ClientRequestToken": client_request_token},
headers={
"X-Amz-Target": "secretsmanager.RotateSecret"
},
)
json_data = json.loads(rotate_secret.data.decode("utf-8"))
assert json_data['message'] == "ClientRequestToken must be 32-64 characters long."
assert json_data['__type'] == 'InvalidParameterException'
@mock_secretsmanager
def test_rotate_secret_rotation_lambda_arn_too_long():
backend = server.create_backend_app('secretsmanager')
test_client = backend.test_client()
create_secret = test_client.post('/',
data={"Name": "test-secret",
"SecretString": "foosecret"},
headers={
"X-Amz-Target": "secretsmanager.CreateSecret"
},
)
rotation_lambda_arn = '85B7-446A-B7E4' * 147 # == 2058 characters
rotate_secret = test_client.post('/',
data={"SecretId": "test-secret",
"RotationLambdaARN": rotation_lambda_arn},
headers={
"X-Amz-Target": "secretsmanager.RotateSecret"
},
)
json_data = json.loads(rotate_secret.data.decode("utf-8"))
assert json_data['message'] == "RotationLambdaARN must <= 2048 characters long."
assert json_data['__type'] == 'InvalidParameterException'
#
# The following tests should work, but fail on the embedded dict in
# RotationRules. The error message suggests a problem deeper in the code, which
# needs further investigation.
#
# @mock_secretsmanager
# def test_rotate_secret_rotation_period_zero():
# backend = server.create_backend_app('secretsmanager')
# test_client = backend.test_client()
# create_secret = test_client.post('/',
# data={"Name": "test-secret",
# "SecretString": "foosecret"},
# headers={
# "X-Amz-Target": "secretsmanager.CreateSecret"
# },
# )
# rotate_secret = test_client.post('/',
# data={"SecretId": "test-secret",
# "RotationRules": {"AutomaticallyAfterDays": 0}},
# headers={
# "X-Amz-Target": "secretsmanager.RotateSecret"
# },
# )
# json_data = json.loads(rotate_secret.data.decode("utf-8"))
# assert json_data['message'] == "RotationRules.AutomaticallyAfterDays must be within 1-1000."
# assert json_data['__type'] == 'InvalidParameterException'
# @mock_secretsmanager
# def test_rotate_secret_rotation_period_too_long():
# backend = server.create_backend_app('secretsmanager')
# test_client = backend.test_client()
# create_secret = test_client.post('/',
# data={"Name": "test-secret",
# "SecretString": "foosecret"},
# headers={
# "X-Amz-Target": "secretsmanager.CreateSecret"
# },
# )
# rotate_secret = test_client.post('/',
# data={"SecretId": "test-secret",
# "RotationRules": {"AutomaticallyAfterDays": 1001}},
# headers={
# "X-Amz-Target": "secretsmanager.RotateSecret"
# },
# )
# json_data = json.loads(rotate_secret.data.decode("utf-8"))
# assert json_data['message'] == "RotationRules.AutomaticallyAfterDays must be within 1-1000."
# assert json_data['__type'] == 'InvalidParameterException'

View File

@ -40,6 +40,33 @@ def test_create_fifo_queue_fail():
raise RuntimeError('Should of raised InvalidParameterValue Exception') raise RuntimeError('Should of raised InvalidParameterValue Exception')
@mock_sqs
def test_create_queue_with_same_attributes():
sqs = boto3.client('sqs', region_name='us-east-1')
dlq_url = sqs.create_queue(QueueName='test-queue-dlq')['QueueUrl']
dlq_arn = sqs.get_queue_attributes(QueueUrl=dlq_url)['Attributes']['QueueArn']
attributes = {
'DelaySeconds': '900',
'MaximumMessageSize': '262144',
'MessageRetentionPeriod': '1209600',
'ReceiveMessageWaitTimeSeconds': '20',
'RedrivePolicy': '{"deadLetterTargetArn": "%s", "maxReceiveCount": 100}' % (dlq_arn),
'VisibilityTimeout': '43200'
}
sqs.create_queue(
QueueName='test-queue',
Attributes=attributes
)
sqs.create_queue(
QueueName='test-queue',
Attributes=attributes
)
@mock_sqs @mock_sqs
def test_create_queue_with_different_attributes_fail(): def test_create_queue_with_different_attributes_fail():
sqs = boto3.client('sqs', region_name='us-east-1') sqs = boto3.client('sqs', region_name='us-east-1')
@ -1195,3 +1222,16 @@ def test_receive_messages_with_message_group_id_on_visibility_timeout():
messages = queue.receive_messages() messages = queue.receive_messages()
messages.should.have.length_of(1) messages.should.have.length_of(1)
messages[0].message_id.should.equal(message.message_id) messages[0].message_id.should.equal(message.message_id)
@mock_sqs
def test_receive_message_for_queue_with_receive_message_wait_time_seconds_set():
sqs = boto3.resource('sqs', region_name='us-east-1')
queue = sqs.create_queue(
QueueName='test-queue',
Attributes={
'ReceiveMessageWaitTimeSeconds': '2',
}
)
queue.receive_messages()

View File

@ -5,11 +5,12 @@ import botocore.exceptions
import sure # noqa import sure # noqa
import datetime import datetime
import uuid import uuid
import json
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from nose.tools import assert_raises from nose.tools import assert_raises
from moto import mock_ssm from moto import mock_ssm, mock_cloudformation
@mock_ssm @mock_ssm
@ -668,3 +669,118 @@ def test_list_commands():
with assert_raises(ClientError): with assert_raises(ClientError):
response = client.list_commands( response = client.list_commands(
CommandId=str(uuid.uuid4())) CommandId=str(uuid.uuid4()))
@mock_ssm
def test_get_command_invocation():
client = boto3.client('ssm', region_name='us-east-1')
ssm_document = 'AWS-RunShellScript'
params = {'commands': ['#!/bin/bash\necho \'hello world\'']}
response = client.send_command(
InstanceIds=['i-123456', 'i-234567', 'i-345678'],
DocumentName=ssm_document,
Parameters=params,
OutputS3Region='us-east-2',
OutputS3BucketName='the-bucket',
OutputS3KeyPrefix='pref')
cmd = response['Command']
cmd_id = cmd['CommandId']
instance_id = 'i-345678'
invocation_response = client.get_command_invocation(
CommandId=cmd_id,
InstanceId=instance_id,
PluginName='aws:runShellScript')
invocation_response['CommandId'].should.equal(cmd_id)
invocation_response['InstanceId'].should.equal(instance_id)
# test the error case for an invalid instance id
with assert_raises(ClientError):
invocation_response = client.get_command_invocation(
CommandId=cmd_id,
InstanceId='i-FAKE')
# test the error case for an invalid plugin name
with assert_raises(ClientError):
invocation_response = client.get_command_invocation(
CommandId=cmd_id,
InstanceId=instance_id,
PluginName='FAKE')
@mock_ssm
@mock_cloudformation
def test_get_command_invocations_from_stack():
stack_template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Description": "Test Stack",
"Resources": {
"EC2Instance1": {
"Type": "AWS::EC2::Instance",
"Properties": {
"ImageId": "ami-test-image-id",
"KeyName": "test",
"InstanceType": "t2.micro",
"Tags": [
{
"Key": "Test Description",
"Value": "Test tag"
},
{
"Key": "Test Name",
"Value": "Name tag for tests"
}
]
}
}
},
"Outputs": {
"test": {
"Description": "Test Output",
"Value": "Test output value",
"Export": {
"Name": "Test value to export"
}
},
"PublicIP": {
"Value": "Test public ip"
}
}
}
cloudformation_client = boto3.client(
'cloudformation',
region_name='us-east-1')
stack_template_str = json.dumps(stack_template)
response = cloudformation_client.create_stack(
StackName='test_stack',
TemplateBody=stack_template_str,
Capabilities=('CAPABILITY_IAM', ))
client = boto3.client('ssm', region_name='us-east-1')
ssm_document = 'AWS-RunShellScript'
params = {'commands': ['#!/bin/bash\necho \'hello world\'']}
response = client.send_command(
Targets=[{
'Key': 'tag:aws:cloudformation:stack-name',
'Values': ('test_stack', )}],
DocumentName=ssm_document,
Parameters=params,
OutputS3Region='us-east-2',
OutputS3BucketName='the-bucket',
OutputS3KeyPrefix='pref')
cmd = response['Command']
cmd_id = cmd['CommandId']
instance_ids = cmd['InstanceIds']
invocation_response = client.get_command_invocation(
CommandId=cmd_id,
InstanceId=instance_ids[0],
PluginName='aws:runShellScript')