Merge remote-tracking branch 'spulec/master'

This commit is contained in:
eric-weaver 2017-07-15 18:00:42 -04:00
commit 38880c4c9e
22 changed files with 676 additions and 53 deletions

View File

@ -414,6 +414,9 @@ class _RecursiveDictRef(object):
def __getattr__(self, key): def __getattr__(self, key):
return self.dic.__getattr__(key) return self.dic.__getattr__(key)
def __getitem__(self, key):
return self.dic.__getitem__(key)
def set_reference(self, key, dic): def set_reference(self, key, dic):
"""Set the RecursiveDictRef object to keep reference to dict object """Set the RecursiveDictRef object to keep reference to dict object
(dic) at the key. (dic) at the key.

View File

@ -3,6 +3,7 @@ from collections import defaultdict
import datetime import datetime
import decimal import decimal
import json import json
import re
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
@ -115,28 +116,31 @@ class Item(BaseModel):
} }
def update(self, update_expression, expression_attribute_names, expression_attribute_values): def update(self, update_expression, expression_attribute_names, expression_attribute_values):
ACTION_VALUES = ['SET', 'set', 'REMOVE', 'remove'] # Update subexpressions are identifiable by the operator keyword, so split on that and
# get rid of the empty leading string.
action = None parts = [p for p in re.split(r'\b(SET|REMOVE|ADD|DELETE)\b', update_expression) if p]
for value in update_expression.split(): # make sure that we correctly found only operator/value pairs
if value in ACTION_VALUES: assert len(parts) % 2 == 0, "Mismatched operators and values in update expression: '{}'".format(update_expression)
# An action for action, valstr in zip(parts[:-1:2], parts[1::2]):
action = value values = valstr.split(',')
continue for value in values:
else:
# A Real value # A Real value
value = value.lstrip(":").rstrip(",") value = value.lstrip(":").rstrip(",").strip()
for k, v in expression_attribute_names.items(): for k, v in expression_attribute_names.items():
value = value.replace(k, v) value = re.sub(r'{0}\b'.format(k), v, value)
if action == "REMOVE" or action == 'remove':
self.attrs.pop(value, None) if action == "REMOVE":
elif action == 'SET' or action == 'set': self.attrs.pop(value, None)
key, value = value.split("=") elif action == 'SET':
if value in expression_attribute_values: key, value = value.split("=")
self.attrs[key] = DynamoType( key = key.strip()
expression_attribute_values[value]) value = value.strip()
if value in expression_attribute_values:
self.attrs[key] = DynamoType(expression_attribute_values[value])
else:
self.attrs[key] = DynamoType({"S": value})
else: else:
self.attrs[key] = DynamoType({"S": value}) raise NotImplementedError('{} update action not yet supported'.format(action))
def update_with_attribute_updates(self, attribute_updates): def update_with_attribute_updates(self, attribute_updates):
for attribute_name, update_action in attribute_updates.items(): for attribute_name, update_action in attribute_updates.items():
@ -345,7 +349,6 @@ class Table(BaseModel):
def query(self, hash_key, range_comparison, range_objs, limit, def query(self, hash_key, range_comparison, range_objs, limit,
exclusive_start_key, scan_index_forward, index_name=None, **filter_kwargs): exclusive_start_key, scan_index_forward, index_name=None, **filter_kwargs):
results = [] results = []
if index_name: if index_name:
all_indexes = (self.global_indexes or []) + (self.indexes or []) all_indexes = (self.global_indexes or []) + (self.indexes or [])
indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) indexes_by_name = dict((i['IndexName'], i) for i in all_indexes)

View File

@ -316,24 +316,26 @@ class DynamoHandler(BaseResponse):
else: else:
index = table.schema index = table.schema
key_map = [column for _, column in sorted( reverse_attribute_lookup = dict((v, k) for k, v in
(k, v) for k, v in self.body['ExpressionAttributeNames'].items())] six.iteritems(self.body['ExpressionAttributeNames']))
if " AND " in key_condition_expression: if " AND " in key_condition_expression:
expressions = key_condition_expression.split(" AND ", 1) expressions = key_condition_expression.split(" AND ", 1)
index_hash_key = [ index_hash_key = [key for key in index if key['KeyType'] == 'HASH'][0]
key for key in index if key['KeyType'] == 'HASH'][0] hash_key_var = reverse_attribute_lookup.get(index_hash_key['AttributeName'],
hash_key_index_in_key_map = key_map.index( index_hash_key['AttributeName'])
index_hash_key['AttributeName']) hash_key_regex = r'(^|[\s(]){0}\b'.format(hash_key_var)
i, hash_key_expression = next((i, e) for i, e in enumerate(expressions)
if re.search(hash_key_regex, e))
hash_key_expression = hash_key_expression.strip('()')
expressions.pop(i)
hash_key_expression = expressions.pop( # TODO implement more than one range expression and OR operators
hash_key_index_in_key_map).strip('()')
# TODO implement more than one range expression and OR
# operators
range_key_expression = expressions[0].strip('()') range_key_expression = expressions[0].strip('()')
range_key_expression_components = range_key_expression.split() range_key_expression_components = range_key_expression.split()
range_comparison = range_key_expression_components[1] range_comparison = range_key_expression_components[1]
if 'AND' in range_key_expression: if 'AND' in range_key_expression:
range_comparison = 'BETWEEN' range_comparison = 'BETWEEN'
range_values = [ range_values = [

View File

@ -40,6 +40,15 @@ class BadHealthCheckDefinition(ELBClientError):
"HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL") "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL")
class DuplicateListenerError(ELBClientError):
def __init__(self, name, port):
super(DuplicateListenerError, self).__init__(
"DuplicateListener",
"A listener already exists for {0} with LoadBalancerPort {1}, but with a different InstancePort, Protocol, or SSLCertificateId"
.format(name, port))
class DuplicateLoadBalancerName(ELBClientError): class DuplicateLoadBalancerName(ELBClientError):
def __init__(self, name): def __init__(self, name):

View File

@ -18,6 +18,7 @@ from moto.ec2.models import ec2_backends
from .exceptions import ( from .exceptions import (
BadHealthCheckDefinition, BadHealthCheckDefinition,
DuplicateLoadBalancerName, DuplicateLoadBalancerName,
DuplicateListenerError,
EmptyListenersError, EmptyListenersError,
LoadBalancerNotFoundError, LoadBalancerNotFoundError,
TooManyTagsError, TooManyTagsError,
@ -257,6 +258,12 @@ class ELBBackend(BaseBackend):
ssl_certificate_id = port.get('sslcertificate_id') ssl_certificate_id = port.get('sslcertificate_id')
for listener in balancer.listeners: for listener in balancer.listeners:
if lb_port == listener.load_balancer_port: if lb_port == listener.load_balancer_port:
if protocol != listener.protocol:
raise DuplicateListenerError(name, lb_port)
if instance_port != listener.instance_port:
raise DuplicateListenerError(name, lb_port)
if ssl_certificate_id != listener.ssl_certificate_id:
raise DuplicateListenerError(name, lb_port)
break break
else: else:
balancer.listeners.append(FakeListener( balancer.listeners.append(FakeListener(

View File

@ -182,7 +182,7 @@ class Database(BaseModel):
<ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier> <ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier>
{% endif %} {% endif %}
<Engine>{{ database.engine }}</Engine> <Engine>{{ database.engine }}</Engine>
<LicenseModel>general-public-license</LicenseModel> <LicenseModel>{{ database.license_model }}</LicenseModel>
<EngineVersion>{{ database.engine_version }}</EngineVersion> <EngineVersion>{{ database.engine_version }}</EngineVersion>
<DBParameterGroups> <DBParameterGroups>
</DBParameterGroups> </DBParameterGroups>

View File

@ -28,6 +28,14 @@ class DBInstanceNotFoundError(RDSClientError):
"Database {0} not found.".format(database_identifier)) "Database {0} not found.".format(database_identifier))
class DBSnapshotNotFoundError(RDSClientError):
def __init__(self):
super(DBSnapshotNotFoundError, self).__init__(
'DBSnapshotNotFound',
"DBSnapshotIdentifier does not refer to an existing DB snapshot.")
class DBSecurityGroupNotFoundError(RDSClientError): class DBSecurityGroupNotFoundError(RDSClientError):
def __init__(self, security_group_name): def __init__(self, security_group_name):

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import copy import copy
import datetime
from collections import defaultdict from collections import defaultdict
import boto.rds2 import boto.rds2
@ -10,9 +11,11 @@ from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import get_random_hex from moto.core.utils import get_random_hex
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.ec2.models import ec2_backends from moto.ec2.models import ec2_backends
from .exceptions import (RDSClientError, from .exceptions import (RDSClientError,
DBInstanceNotFoundError, DBInstanceNotFoundError,
DBSnapshotNotFoundError,
DBSecurityGroupNotFoundError, DBSecurityGroupNotFoundError,
DBSubnetGroupNotFoundError, DBSubnetGroupNotFoundError,
DBParameterGroupNotFoundError) DBParameterGroupNotFoundError)
@ -86,8 +89,7 @@ class Database(BaseModel):
self.preferred_backup_window = kwargs.get( self.preferred_backup_window = kwargs.get(
'preferred_backup_window', '13:14-13:44') 'preferred_backup_window', '13:14-13:44')
self.license_model = kwargs.get( self.license_model = kwargs.get('license_model', 'general-public-license')
'license_model', 'general-public-license')
self.option_group_name = kwargs.get('option_group_name', None) self.option_group_name = kwargs.get('option_group_name', None)
self.default_option_groups = {"MySQL": "default.mysql5.6", self.default_option_groups = {"MySQL": "default.mysql5.6",
"mysql": "default.mysql5.6", "mysql": "default.mysql5.6",
@ -156,7 +158,7 @@ class Database(BaseModel):
<ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier> <ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier>
{% endif %} {% endif %}
<Engine>{{ database.engine }}</Engine> <Engine>{{ database.engine }}</Engine>
<LicenseModel>general-public-license</LicenseModel> <LicenseModel>{{ database.license_model }}</LicenseModel>
<EngineVersion>{{ database.engine_version }}</EngineVersion> <EngineVersion>{{ database.engine_version }}</EngineVersion>
<OptionGroupMemberships> <OptionGroupMemberships>
</OptionGroupMemberships> </OptionGroupMemberships>
@ -399,6 +401,53 @@ class Database(BaseModel):
backend.delete_database(self.db_instance_identifier) backend.delete_database(self.db_instance_identifier)
class Snapshot(BaseModel):
def __init__(self, database, snapshot_id, tags=None):
self.database = database
self.snapshot_id = snapshot_id
self.tags = tags or []
self.created_at = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
@property
def snapshot_arn(self):
return "arn:aws:rds:{0}:1234567890:snapshot:{1}".format(self.database.region, self.snapshot_id)
def to_xml(self):
template = Template("""<DBSnapshot>
<DBSnapshotIdentifier>{{ snapshot.snapshot_id }}</DBSnapshotIdentifier>
<DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier>
<SnapshotCreateTime>{{ snapshot.created_at }}</SnapshotCreateTime>
<Engine>{{ database.engine }}</Engine>
<AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage>
<Status>available</Status>
<Port>{{ database.port }}</Port>
<AvailabilityZone>{{ database.availability_zone }}</AvailabilityZone>
<VpcId>{{ database.db_subnet_group.vpc_id }}</VpcId>
<InstanceCreateTime>{{ snapshot.created_at }}</InstanceCreateTime>
<MasterUsername>{{ database.master_username }}</MasterUsername>
<EngineVersion>{{ database.engine_version }}</EngineVersion>
<LicenseModel>{{ database.license_model }}</LicenseModel>
<SnapshotType>manual</SnapshotType>
{% if database.iops %}
<Iops>{{ database.iops }}</Iops>
<StorageType>io1</StorageType>
{% else %}
<StorageType>{{ database.storage_type }}</StorageType>
{% endif %}
<OptionGroupName>{{ database.option_group_name }}</OptionGroupName>
<PercentProgress>{{ 100 }}</PercentProgress>
<SourceRegion>{{ database.region }}</SourceRegion>
<SourceDBSnapshotIdentifier></SourceDBSnapshotIdentifier>
<TdeCredentialArn></TdeCredentialArn>
<Encrypted>{{ database.storage_encrypted }}</Encrypted>
<KmsKeyId>{{ database.kms_key_id }}</KmsKeyId>
<DBSnapshotArn>{{ snapshot.snapshot_arn }}</DBSnapshotArn>
<Timezone></Timezone>
<IAMDatabaseAuthenticationEnabled>false</IAMDatabaseAuthenticationEnabled>
</DBSnapshot>""")
return template.render(snapshot=self, database=self.database)
class SecurityGroup(BaseModel): class SecurityGroup(BaseModel):
def __init__(self, group_name, description, tags): def __init__(self, group_name, description, tags):
@ -607,6 +656,7 @@ class RDS2Backend(BaseBackend):
self.arn_regex = re_compile( self.arn_regex = re_compile(
r'^arn:aws:rds:.*:[0-9]*:(db|es|og|pg|ri|secgrp|snapshot|subgrp):.*$') r'^arn:aws:rds:.*:[0-9]*:(db|es|og|pg|ri|secgrp|snapshot|subgrp):.*$')
self.databases = OrderedDict() self.databases = OrderedDict()
self.snapshots = OrderedDict()
self.db_parameter_groups = {} self.db_parameter_groups = {}
self.option_groups = {} self.option_groups = {}
self.security_groups = {} self.security_groups = {}
@ -624,6 +674,20 @@ class RDS2Backend(BaseBackend):
self.databases[database_id] = database self.databases[database_id] = database
return database return database
def create_snapshot(self, db_instance_identifier, db_snapshot_identifier, tags):
database = self.databases.get(db_instance_identifier)
if not database:
raise DBInstanceNotFoundError(db_instance_identifier)
snapshot = Snapshot(database, db_snapshot_identifier, tags)
self.snapshots[db_snapshot_identifier] = snapshot
return snapshot
def delete_snapshot(self, db_snapshot_identifier):
if db_snapshot_identifier not in self.snapshots:
raise DBSnapshotNotFoundError()
return self.snapshots.pop(db_snapshot_identifier)
def create_database_replica(self, db_kwargs): def create_database_replica(self, db_kwargs):
database_id = db_kwargs['db_instance_identifier'] database_id = db_kwargs['db_instance_identifier']
source_database_id = db_kwargs['source_db_identifier'] source_database_id = db_kwargs['source_db_identifier']
@ -646,6 +710,20 @@ class RDS2Backend(BaseBackend):
raise DBInstanceNotFoundError(db_instance_identifier) raise DBInstanceNotFoundError(db_instance_identifier)
return self.databases.values() return self.databases.values()
def describe_snapshots(self, db_instance_identifier, db_snapshot_identifier):
if db_instance_identifier:
for snapshot in self.snapshots.values():
if snapshot.database.db_instance_identifier == db_instance_identifier:
return [snapshot]
raise DBSnapshotNotFoundError()
if db_snapshot_identifier:
if db_snapshot_identifier in self.snapshots:
return [self.snapshots[db_snapshot_identifier]]
raise DBSnapshotNotFoundError()
return self.snapshots.values()
def modify_database(self, db_instance_identifier, db_kwargs): def modify_database(self, db_instance_identifier, db_kwargs):
database = self.describe_databases(db_instance_identifier)[0] database = self.describe_databases(db_instance_identifier)[0]
database.update(db_kwargs) database.update(db_kwargs)
@ -667,13 +745,15 @@ class RDS2Backend(BaseBackend):
return backend.describe_databases(db_name)[0] return backend.describe_databases(db_name)[0]
def delete_database(self, db_instance_identifier): def delete_database(self, db_instance_identifier, db_snapshot_name=None):
if db_instance_identifier in self.databases: if db_instance_identifier in self.databases:
database = self.databases.pop(db_instance_identifier) database = self.databases.pop(db_instance_identifier)
if database.is_replica: if database.is_replica:
primary = self.find_db_from_id(database.source_db_identifier) primary = self.find_db_from_id(database.source_db_identifier)
primary.remove_replica(database) primary.remove_replica(database)
database.status = 'deleting' database.status = 'deleting'
if db_snapshot_name:
self.snapshots[db_snapshot_name] = Snapshot(database, db_snapshot_name)
return database return database
else: else:
raise DBInstanceNotFoundError(db_instance_identifier) raise DBInstanceNotFoundError(db_instance_identifier)

View File

@ -26,6 +26,7 @@ class RDS2Response(BaseResponse):
"db_subnet_group_name": self._get_param("DBSubnetGroupName"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"),
"engine": self._get_param("Engine"), "engine": self._get_param("Engine"),
"engine_version": self._get_param("EngineVersion"), "engine_version": self._get_param("EngineVersion"),
"license_model": self._get_param("LicenseModel"),
"iops": self._get_int_param("Iops"), "iops": self._get_int_param("Iops"),
"kms_key_id": self._get_param("KmsKeyId"), "kms_key_id": self._get_param("KmsKeyId"),
"master_user_password": self._get_param('MasterUserPassword'), "master_user_password": self._get_param('MasterUserPassword'),
@ -39,7 +40,7 @@ class RDS2Response(BaseResponse):
"region": self.region, "region": self.region,
"security_groups": self._get_multi_param('DBSecurityGroups.DBSecurityGroupName'), "security_groups": self._get_multi_param('DBSecurityGroups.DBSecurityGroupName'),
"storage_encrypted": self._get_param("StorageEncrypted"), "storage_encrypted": self._get_param("StorageEncrypted"),
"storage_type": self._get_param("StorageType"), "storage_type": self._get_param("StorageType", 'standard'),
# VpcSecurityGroupIds.member.N # VpcSecurityGroupIds.member.N
"tags": list(), "tags": list(),
} }
@ -140,7 +141,8 @@ class RDS2Response(BaseResponse):
def delete_db_instance(self): def delete_db_instance(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier') db_instance_identifier = self._get_param('DBInstanceIdentifier')
database = self.backend.delete_database(db_instance_identifier) db_snapshot_name = self._get_param('FinalDBSnapshotIdentifier')
database = self.backend.delete_database(db_instance_identifier, db_snapshot_name)
template = self.response_template(DELETE_DATABASE_TEMPLATE) template = self.response_template(DELETE_DATABASE_TEMPLATE)
return template.render(database=database) return template.render(database=database)
@ -150,6 +152,27 @@ class RDS2Response(BaseResponse):
template = self.response_template(REBOOT_DATABASE_TEMPLATE) template = self.response_template(REBOOT_DATABASE_TEMPLATE)
return template.render(database=database) return template.render(database=database)
def create_db_snapshot(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier')
db_snapshot_identifier = self._get_param('DBSnapshotIdentifier')
tags = self._get_param('Tags', [])
snapshot = self.backend.create_snapshot(db_instance_identifier, db_snapshot_identifier, tags)
template = self.response_template(CREATE_SNAPSHOT_TEMPLATE)
return template.render(snapshot=snapshot)
def describe_db_snapshots(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier')
db_snapshot_identifier = self._get_param('DBSnapshotIdentifier')
snapshots = self.backend.describe_snapshots(db_instance_identifier, db_snapshot_identifier)
template = self.response_template(DESCRIBE_SNAPSHOTS_TEMPLATE)
return template.render(snapshots=snapshots)
def delete_db_snapshot(self):
db_snapshot_identifier = self._get_param('DBSnapshotIdentifier')
snapshot = self.backend.delete_snapshot(db_snapshot_identifier)
template = self.response_template(DELETE_SNAPSHOT_TEMPLATE)
return template.render(snapshot=snapshot)
def list_tags_for_resource(self): def list_tags_for_resource(self):
arn = self._get_param('ResourceName') arn = self._get_param('ResourceName')
template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE) template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE)
@ -397,6 +420,42 @@ DELETE_DATABASE_TEMPLATE = """<DeleteDBInstanceResponse xmlns="http://rds.amazon
</ResponseMetadata> </ResponseMetadata>
</DeleteDBInstanceResponse>""" </DeleteDBInstanceResponse>"""
CREATE_SNAPSHOT_TEMPLATE = """<CreateDBSnapshotResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/">
<CreateDBSnapshotResult>
{{ snapshot.to_xml() }}
</CreateDBSnapshotResult>
<ResponseMetadata>
<RequestId>523e3218-afc7-11c3-90f5-f90431260ab4</RequestId>
</ResponseMetadata>
</CreateDBSnapshotResponse>
"""
DESCRIBE_SNAPSHOTS_TEMPLATE = """<DescribeDBSnapshotsResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/">
<DescribeDBSnapshotsResult>
<DBSnapshots>
{%- for snapshot in snapshots -%}
{{ snapshot.to_xml() }}
{%- endfor -%}
</DBSnapshots>
{% if marker %}
<Marker>{{ marker }}</Marker>
{% endif %}
</DescribeDBSnapshotsResult>
<ResponseMetadata>
<RequestId>523e3218-afc7-11c3-90f5-f90431260ab4</RequestId>
</ResponseMetadata>
</DescribeDBSnapshotsResponse>"""
DELETE_SNAPSHOT_TEMPLATE = """<DeleteDBSnapshotResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/">
<DeleteDBSnapshotResult>
{{ snapshot.to_xml() }}
</DeleteDBSnapshotResult>
<ResponseMetadata>
<RequestId>523e3218-afc7-11c3-90f5-f90431260ab4</RequestId>
</ResponseMetadata>
</DeleteDBSnapshotResponse>
"""
CREATE_SECURITY_GROUP_TEMPLATE = """<CreateDBSecurityGroupResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/"> CREATE_SECURITY_GROUP_TEMPLATE = """<CreateDBSecurityGroupResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/">
<CreateDBSecurityGroupResult> <CreateDBSecurityGroupResult>
{{ security_group.to_xml() }} {{ security_group.to_xml() }}

View File

@ -171,6 +171,12 @@ def main(argv=sys.argv[1:]):
help='Reload server on a file change', help='Reload server on a file change',
default=False default=False
) )
parser.add_argument(
'-s', '--ssl',
action='store_true',
help='Enable SSL encrypted connection (use https://... URL)',
default=False
)
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -180,7 +186,8 @@ def main(argv=sys.argv[1:]):
main_app.debug = True main_app.debug = True
run_simple(args.host, args.port, main_app, run_simple(args.host, args.port, main_app,
threaded=True, use_reloader=args.reload) threaded=True, use_reloader=args.reload,
ssl_context='adhoc' if args.ssl else None)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -36,6 +36,7 @@ class SESBackend(BaseBackend):
def __init__(self): def __init__(self):
self.addresses = [] self.addresses = []
self.email_addresses = []
self.domains = [] self.domains = []
self.sent_messages = [] self.sent_messages = []
self.sent_message_count = 0 self.sent_message_count = 0
@ -49,12 +50,18 @@ class SESBackend(BaseBackend):
def verify_email_identity(self, address): def verify_email_identity(self, address):
self.addresses.append(address) self.addresses.append(address)
def verify_email_address(self, address):
self.email_addresses.append(address)
def verify_domain(self, domain): def verify_domain(self, domain):
self.domains.append(domain) self.domains.append(domain)
def list_identities(self): def list_identities(self):
return self.domains + self.addresses return self.domains + self.addresses
def list_verified_email_addresses(self):
return self.email_addresses
def delete_identity(self, identity): def delete_identity(self, identity):
if '@' in identity: if '@' in identity:
self.addresses.remove(identity) self.addresses.remove(identity)

View File

@ -15,11 +15,22 @@ class EmailResponse(BaseResponse):
template = self.response_template(VERIFY_EMAIL_IDENTITY) template = self.response_template(VERIFY_EMAIL_IDENTITY)
return template.render() return template.render()
def verify_email_address(self):
address = self.querystring.get('EmailAddress')[0]
ses_backend.verify_email_address(address)
template = self.response_template(VERIFY_EMAIL_ADDRESS)
return template.render()
def list_identities(self): def list_identities(self):
identities = ses_backend.list_identities() identities = ses_backend.list_identities()
template = self.response_template(LIST_IDENTITIES_RESPONSE) template = self.response_template(LIST_IDENTITIES_RESPONSE)
return template.render(identities=identities) return template.render(identities=identities)
def list_verified_email_addresses(self):
email_addresses = ses_backend.list_verified_email_addresses()
template = self.response_template(LIST_VERIFIED_EMAIL_RESPONSE)
return template.render(email_addresses=email_addresses)
def verify_domain_dkim(self): def verify_domain_dkim(self):
domain = self.querystring.get('Domain')[0] domain = self.querystring.get('Domain')[0]
ses_backend.verify_domain(domain) ses_backend.verify_domain(domain)
@ -95,6 +106,13 @@ VERIFY_EMAIL_IDENTITY = """<VerifyEmailIdentityResponse xmlns="http://ses.amazon
</ResponseMetadata> </ResponseMetadata>
</VerifyEmailIdentityResponse>""" </VerifyEmailIdentityResponse>"""
VERIFY_EMAIL_ADDRESS = """<VerifyEmailAddressResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<VerifyEmailAddressResult/>
<ResponseMetadata>
<RequestId>47e0ef1a-9bf2-11e1-9279-0100e8cf109a</RequestId>
</ResponseMetadata>
</VerifyEmailAddressResponse>"""
LIST_IDENTITIES_RESPONSE = """<ListIdentitiesResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/"> LIST_IDENTITIES_RESPONSE = """<ListIdentitiesResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<ListIdentitiesResult> <ListIdentitiesResult>
<Identities> <Identities>
@ -108,6 +126,19 @@ LIST_IDENTITIES_RESPONSE = """<ListIdentitiesResponse xmlns="http://ses.amazonaw
</ResponseMetadata> </ResponseMetadata>
</ListIdentitiesResponse>""" </ListIdentitiesResponse>"""
LIST_VERIFIED_EMAIL_RESPONSE = """<ListVerifiedEmailAddressesResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<ListVerifiedEmailAddressesResult>
<VerifiedEmailAddresses>
{% for email in email_addresses %}
<member>{{ email }}</member>
{% endfor %}
</VerifiedEmailAddresses>
</ListVerifiedEmailAddressesResult>
<ResponseMetadata>
<RequestId>cacecf23-9bf1-11e1-9279-0100e8cf109a</RequestId>
</ResponseMetadata>
</ListVerifiedEmailAddressesResponse>"""
VERIFY_DOMAIN_DKIM_RESPONSE = """<VerifyDomainDkimResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/"> VERIFY_DOMAIN_DKIM_RESPONSE = """<VerifyDomainDkimResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<VerifyDomainDkimResult> <VerifyDomainDkimResult>
<DkimTokens> <DkimTokens>

View File

@ -96,7 +96,7 @@ class Message(BaseModel):
return escape(self._body) return escape(self._body)
def mark_sent(self, delay_seconds=None): def mark_sent(self, delay_seconds=None):
self.sent_timestamp = unix_time_millis() self.sent_timestamp = int(unix_time_millis())
if delay_seconds: if delay_seconds:
self.delay(delay_seconds=delay_seconds) self.delay(delay_seconds=delay_seconds)
@ -111,7 +111,7 @@ class Message(BaseModel):
visibility_timeout = 0 visibility_timeout = 0
if not self.approximate_first_receive_timestamp: if not self.approximate_first_receive_timestamp:
self.approximate_first_receive_timestamp = unix_time_millis() self.approximate_first_receive_timestamp = int(unix_time_millis())
self.approximate_receive_count += 1 self.approximate_receive_count += 1

View File

@ -339,7 +339,9 @@ SEND_MESSAGE_RESPONSE = """<SendMessageResponse>
<MD5OfMessageBody> <MD5OfMessageBody>
{{- message.body_md5 -}} {{- message.body_md5 -}}
</MD5OfMessageBody> </MD5OfMessageBody>
{% if message.message_attributes.items()|count > 0 %}
<MD5OfMessageAttributes>{{- message.attribute_md5 -}}</MD5OfMessageAttributes> <MD5OfMessageAttributes>{{- message.attribute_md5 -}}</MD5OfMessageAttributes>
{% endif %}
<MessageId> <MessageId>
{{- message.id -}} {{- message.id -}}
</MessageId> </MessageId>
@ -373,7 +375,9 @@ RECEIVE_MESSAGE_RESPONSE = """<ReceiveMessageResponse>
<Name>ApproximateFirstReceiveTimestamp</Name> <Name>ApproximateFirstReceiveTimestamp</Name>
<Value>{{ message.approximate_first_receive_timestamp }}</Value> <Value>{{ message.approximate_first_receive_timestamp }}</Value>
</Attribute> </Attribute>
{% if message.message_attributes.items()|count > 0 %}
<MD5OfMessageAttributes>{{- message.attribute_md5 -}}</MD5OfMessageAttributes> <MD5OfMessageAttributes>{{- message.attribute_md5 -}}</MD5OfMessageAttributes>
{% endif %}
{% for name, value in message.message_attributes.items() %} {% for name, value in message.message_attributes.items() %}
<MessageAttribute> <MessageAttribute>
<Name>{{ name }}</Name> <Name>{{ name }}</Name>
@ -402,7 +406,9 @@ SEND_MESSAGE_BATCH_RESPONSE = """<SendMessageBatchResponse>
<Id>{{ message.user_id }}</Id> <Id>{{ message.user_id }}</Id>
<MessageId>{{ message.id }}</MessageId> <MessageId>{{ message.id }}</MessageId>
<MD5OfMessageBody>{{ message.body_md5 }}</MD5OfMessageBody> <MD5OfMessageBody>{{ message.body_md5 }}</MD5OfMessageBody>
{% if message.message_attributes.items()|count > 0 %}
<MD5OfMessageAttributes>{{- message.attribute_md5 -}}</MD5OfMessageAttributes> <MD5OfMessageAttributes>{{- message.attribute_md5 -}}</MD5OfMessageAttributes>
{% endif %}
</SendMessageBatchResultEntry> </SendMessageBatchResultEntry>
{% endfor %} {% endfor %}
</SendMessageBatchResult> </SendMessageBatchResult>

View File

@ -28,11 +28,14 @@ class Parameter(BaseModel):
return value[len(prefix):] return value[len(prefix):]
def response_object(self, decrypt=False): def response_object(self, decrypt=False):
return { r = {
'Name': self.name, 'Name': self.name,
'Type': self.type, 'Type': self.type,
'Value': self.decrypt(self.value) if decrypt else self.value 'Value': self.decrypt(self.value) if decrypt else self.value
} }
if self.keyid:
r['KeyId'] = self.keyid
return r
class SimpleSystemManagerBackend(BaseBackend): class SimpleSystemManagerBackend(BaseBackend):
@ -46,6 +49,12 @@ class SimpleSystemManagerBackend(BaseBackend):
except KeyError: except KeyError:
pass pass
def get_all_parameters(self):
result = []
for k, _ in self._parameters.items():
result.append(self._parameters[k])
return result
def get_parameters(self, names, with_decryption): def get_parameters(self, names, with_decryption):
result = [] result = []
for name in names: for name in names:

View File

@ -43,6 +43,60 @@ class SimpleSystemManagerResponse(BaseResponse):
return json.dumps(response) return json.dumps(response)
def describe_parameters(self):
page_size = 10
filters = self._get_param('Filters')
token = self._get_param('NextToken')
if hasattr(token, 'strip'):
token = token.strip()
if not token:
token = '0'
token = int(token)
result = self.ssm_backend.get_all_parameters()
response = {
'Parameters': [],
}
end = token + page_size
for parameter in result[token:]:
param_data = parameter.response_object(False)
add = False
if filters:
for filter in filters:
if filter['Key'] == 'Name':
k = param_data['Name']
for v in filter['Values']:
if k.startswith(v):
add = True
break
elif filter['Key'] == 'Type':
k = param_data['Type']
for v in filter['Values']:
if k == v:
add = True
break
elif filter['Key'] == 'KeyId':
k = param_data.get('KeyId')
if k:
for v in filter['Values']:
if k == v:
add = True
break
else:
add = True
if add:
response['Parameters'].append(param_data)
token = token + 1
if len(response['Parameters']) == page_size:
response['NextToken'] = str(end)
break
return json.dumps(response)
def put_parameter(self): def put_parameter(self):
name = self._get_param('Name') name = self._get_param('Name')
description = self._get_param('Description') description = self._get_param('Description')

View File

@ -214,6 +214,21 @@ def test_create_and_delete_listener_boto3_support():
balancer['ListenerDescriptions'][1]['Listener'][ balancer['ListenerDescriptions'][1]['Listener'][
'InstancePort'].should.equal(8443) 'InstancePort'].should.equal(8443)
# Creating this listener with an conflicting definition throws error
with assert_raises(ClientError):
client.create_load_balancer_listeners(
LoadBalancerName='my-lb',
Listeners=[
{'Protocol': 'tcp', 'LoadBalancerPort': 443, 'InstancePort': 1234}]
)
client.delete_load_balancer_listeners(
LoadBalancerName='my-lb',
LoadBalancerPorts=[443])
balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0]
list(balancer['ListenerDescriptions']).should.have.length_of(1)
@mock_elb_deprecated @mock_elb_deprecated
def test_set_sslcertificate(): def test_set_sslcertificate():

View File

@ -64,7 +64,18 @@ def test_describe_cluster():
args['Configurations'] = [ args['Configurations'] = [
{'Classification': 'yarn-site', {'Classification': 'yarn-site',
'Properties': {'someproperty': 'somevalue', 'Properties': {'someproperty': 'somevalue',
'someotherproperty': 'someothervalue'}}] 'someotherproperty': 'someothervalue'}},
{'Classification': 'nested-configs',
'Properties': {},
'Configurations': [
{
'Classification': 'nested-config',
'Properties': {
'nested-property': 'nested-value'
}
}
]}
]
args['Instances']['AdditionalMasterSecurityGroups'] = ['additional-master'] args['Instances']['AdditionalMasterSecurityGroups'] = ['additional-master']
args['Instances']['AdditionalSlaveSecurityGroups'] = ['additional-slave'] args['Instances']['AdditionalSlaveSecurityGroups'] = ['additional-slave']
args['Instances']['Ec2KeyName'] = 'mykey' args['Instances']['Ec2KeyName'] = 'mykey'
@ -87,6 +98,10 @@ def test_describe_cluster():
config['Classification'].should.equal('yarn-site') config['Classification'].should.equal('yarn-site')
config['Properties'].should.equal(args['Configurations'][0]['Properties']) config['Properties'].should.equal(args['Configurations'][0]['Properties'])
nested_config = cl['Configurations'][1]
nested_config['Classification'].should.equal('nested-configs')
nested_config['Properties'].should.equal(args['Configurations'][1]['Properties'])
attrs = cl['Ec2InstanceAttributes'] attrs = cl['Ec2InstanceAttributes']
attrs['AdditionalMasterSecurityGroups'].should.equal( attrs['AdditionalMasterSecurityGroups'].should.equal(
args['Instances']['AdditionalMasterSecurityGroups']) args['Instances']['AdditionalMasterSecurityGroups'])

View File

@ -14,6 +14,7 @@ def test_create_database():
Engine='postgres', Engine='postgres',
DBName='staging-postgres', DBName='staging-postgres',
DBInstanceClass='db.m1.small', DBInstanceClass='db.m1.small',
LicenseModel='license-included',
MasterUsername='root', MasterUsername='root',
MasterUserPassword='hunter2', MasterUserPassword='hunter2',
Port=1234, Port=1234,
@ -23,6 +24,7 @@ def test_create_database():
database['DBInstance']['DBInstanceIdentifier'].should.equal("db-master-1") database['DBInstance']['DBInstanceIdentifier'].should.equal("db-master-1")
database['DBInstance']['AllocatedStorage'].should.equal(10) database['DBInstance']['AllocatedStorage'].should.equal(10)
database['DBInstance']['DBInstanceClass'].should.equal("db.m1.small") database['DBInstance']['DBInstanceClass'].should.equal("db.m1.small")
database['DBInstance']['LicenseModel'].should.equal("license-included")
database['DBInstance']['MasterUsername'].should.equal("root") database['DBInstance']['MasterUsername'].should.equal("root")
database['DBInstance']['DBSecurityGroups'][0][ database['DBInstance']['DBSecurityGroups'][0][
'DBSecurityGroupName'].should.equal('my_sg') 'DBSecurityGroupName'].should.equal('my_sg')
@ -145,10 +147,10 @@ def test_delete_database():
conn = boto3.client('rds', region_name='us-west-2') conn = boto3.client('rds', region_name='us-west-2')
instances = conn.describe_db_instances() instances = conn.describe_db_instances()
list(instances['DBInstances']).should.have.length_of(0) list(instances['DBInstances']).should.have.length_of(0)
conn.create_db_instance(DBInstanceIdentifier='db-master-1', conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
AllocatedStorage=10, AllocatedStorage=10,
DBInstanceClass='postgres', Engine='postgres',
Engine='db.m1.small', DBInstanceClass='db.m1.small',
MasterUsername='root', MasterUsername='root',
MasterUserPassword='hunter2', MasterUserPassword='hunter2',
Port=1234, Port=1234,
@ -156,10 +158,16 @@ def test_delete_database():
instances = conn.describe_db_instances() instances = conn.describe_db_instances()
list(instances['DBInstances']).should.have.length_of(1) list(instances['DBInstances']).should.have.length_of(1)
conn.delete_db_instance(DBInstanceIdentifier="db-master-1") conn.delete_db_instance(DBInstanceIdentifier="db-primary-1",
FinalDBSnapshotIdentifier='primary-1-snapshot')
instances = conn.describe_db_instances() instances = conn.describe_db_instances()
list(instances['DBInstances']).should.have.length_of(0) list(instances['DBInstances']).should.have.length_of(0)
# Saved the snapshot
snapshots = conn.describe_db_snapshots(DBInstanceIdentifier="db-primary-1").get('DBSnapshots')
snapshots[0].get('Engine').should.equal('postgres')
@mock_rds2 @mock_rds2
def test_delete_non_existant_database(): def test_delete_non_existant_database():
@ -168,6 +176,81 @@ def test_delete_non_existant_database():
DBInstanceIdentifier="not-a-db").should.throw(ClientError) DBInstanceIdentifier="not-a-db").should.throw(ClientError)
@mock_rds2
def test_create_db_snapshots():
conn = boto3.client('rds', region_name='us-west-2')
conn.create_db_snapshot.when.called_with(
DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='snapshot-1').should.throw(ClientError)
conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
MasterUsername='root',
MasterUserPassword='hunter2',
Port=1234,
DBSecurityGroups=["my_sg"])
snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='g-1').get('DBSnapshot')
snapshot.get('Engine').should.equal('postgres')
snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1')
snapshot.get('DBSnapshotIdentifier').should.equal('g-1')
@mock_rds2
def test_describe_db_snapshots():
conn = boto3.client('rds', region_name='us-west-2')
conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
MasterUsername='root',
MasterUserPassword='hunter2',
Port=1234,
DBSecurityGroups=["my_sg"])
conn.describe_db_snapshots.when.called_with(
DBInstanceIdentifier="db-primary-1").should.throw(ClientError)
created = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='snapshot-1').get('DBSnapshot')
created.get('Engine').should.equal('postgres')
by_database_id = conn.describe_db_snapshots(DBInstanceIdentifier='db-primary-1').get('DBSnapshots')
by_snapshot_id = conn.describe_db_snapshots(DBSnapshotIdentifier='snapshot-1').get('DBSnapshots')
by_snapshot_id.should.equal(by_database_id)
snapshot = by_snapshot_id[0]
snapshot.should.equal(created)
snapshot.get('Engine').should.equal('postgres')
@mock_rds2
def test_delete_db_snapshot():
conn = boto3.client('rds', region_name='us-west-2')
conn.create_db_instance(DBInstanceIdentifier='db-primary-1',
AllocatedStorage=10,
Engine='postgres',
DBName='staging-postgres',
DBInstanceClass='db.m1.small',
MasterUsername='root',
MasterUserPassword='hunter2',
Port=1234,
DBSecurityGroups=["my_sg"])
conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1',
DBSnapshotIdentifier='snapshot-1')
conn.describe_db_snapshots(DBSnapshotIdentifier='snapshot-1').get('DBSnapshots')[0]
conn.delete_db_snapshot(DBSnapshotIdentifier='snapshot-1')
conn.describe_db_snapshots.when.called_with(
DBSnapshotIdentifier='snapshot-1').should.throw(ClientError)
@mock_rds2 @mock_rds2
def test_create_option_group(): def test_create_option_group():
conn = boto3.client('rds', region_name='us-west-2') conn = boto3.client('rds', region_name='us-west-2')

View File

@ -19,6 +19,13 @@ def test_verify_email_identity():
address = identities['Identities'][0] address = identities['Identities'][0]
address.should.equal('test@example.com') address.should.equal('test@example.com')
@mock_ses
def test_verify_email_address():
conn = boto3.client('ses', region_name='us-east-1')
conn.verify_email_address(EmailAddress="test@example.com")
email_addresses = conn.list_verified_email_addresses()
email = email_addresses['VerifiedEmailAddresses'][0]
email.should.equal('test@example.com')
@mock_ses @mock_ses
def test_domain_verify(): def test_domain_verify():

View File

@ -39,9 +39,25 @@ def test_get_inexistent_queue():
sqs.get_queue_by_name.when.called_with( sqs.get_queue_by_name.when.called_with(
QueueName='nonexisting-queue').should.throw(botocore.exceptions.ClientError) QueueName='nonexisting-queue').should.throw(botocore.exceptions.ClientError)
@mock_sqs
def test_message_send_without_attributes():
sqs = boto3.resource('sqs', region_name='us-east-1')
queue = sqs.create_queue(QueueName="blah")
msg = queue.send_message(
MessageBody="derp"
)
msg.get('MD5OfMessageBody').should.equal(
'58fd9edd83341c29f1aebba81c31e257')
msg.shouldnt.have.key('MD5OfMessageAttributes')
msg.get('ResponseMetadata', {}).get('RequestId').should.equal(
'27daac76-34dd-47df-bd01-1f6e873584a0')
msg.get('MessageId').should_not.contain(' \n')
messages = queue.receive_messages()
messages.should.have.length_of(1)
@mock_sqs @mock_sqs
def test_message_send(): def test_message_send_with_attributes():
sqs = boto3.resource('sqs', region_name='us-east-1') sqs = boto3.resource('sqs', region_name='us-east-1')
queue = sqs.create_queue(QueueName="blah") queue = sqs.create_queue(QueueName="blah")
msg = queue.send_message( msg = queue.send_message(
@ -189,7 +205,7 @@ def test_set_queue_attribute():
@mock_sqs @mock_sqs
def test_send_message(): def test_send_receive_message_without_attributes():
sqs = boto3.resource('sqs', region_name='us-east-1') sqs = boto3.resource('sqs', region_name='us-east-1')
conn = boto3.client("sqs", region_name='us-east-1') conn = boto3.client("sqs", region_name='us-east-1')
conn.create_queue(QueueName="test-queue") conn.create_queue(QueueName="test-queue")
@ -198,14 +214,81 @@ def test_send_message():
body_one = 'this is a test message' body_one = 'this is a test message'
body_two = 'this is another test message' body_two = 'this is another test message'
response = queue.send_message(MessageBody=body_one) queue.send_message(MessageBody=body_one)
response = queue.send_message(MessageBody=body_two) queue.send_message(MessageBody=body_two)
messages = conn.receive_message( messages = conn.receive_message(
QueueUrl=queue.url, MaxNumberOfMessages=2)['Messages'] QueueUrl=queue.url, MaxNumberOfMessages=2)['Messages']
messages[0]['Body'].should.equal(body_one) message1 = messages[0]
messages[1]['Body'].should.equal(body_two) message2 = messages[1]
message1['Body'].should.equal(body_one)
message2['Body'].should.equal(body_two)
message1.shouldnt.have.key('MD5OfMessageAttributes')
message2.shouldnt.have.key('MD5OfMessageAttributes')
@mock_sqs
def test_send_receive_message_with_attributes():
sqs = boto3.resource('sqs', region_name='us-east-1')
conn = boto3.client("sqs", region_name='us-east-1')
conn.create_queue(QueueName="test-queue")
queue = sqs.Queue("test-queue")
body_one = 'this is a test message'
body_two = 'this is another test message'
queue.send_message(
MessageBody=body_one,
MessageAttributes={
'timestamp': {
'StringValue': '1493147359900',
'DataType': 'Number',
}
}
)
queue.send_message(
MessageBody=body_two,
MessageAttributes={
'timestamp': {
'StringValue': '1493147359901',
'DataType': 'Number',
}
}
)
messages = conn.receive_message(
QueueUrl=queue.url, MaxNumberOfMessages=2)['Messages']
message1 = messages[0]
message2 = messages[1]
message1.get('Body').should.equal(body_one)
message2.get('Body').should.equal(body_two)
message1.get('MD5OfMessageAttributes').should.equal('235c5c510d26fb653d073faed50ae77c')
message2.get('MD5OfMessageAttributes').should.equal('994258b45346a2cc3f9cbb611aa7af30')
@mock_sqs
def test_send_receive_message_timestamps():
sqs = boto3.resource('sqs', region_name='us-east-1')
conn = boto3.client("sqs", region_name='us-east-1')
conn.create_queue(QueueName="test-queue")
queue = sqs.Queue("test-queue")
queue.send_message(MessageBody="derp")
messages = conn.receive_message(
QueueUrl=queue.url, MaxNumberOfMessages=1)['Messages']
message = messages[0]
sent_timestamp = message.get('Attributes').get('SentTimestamp')
approximate_first_receive_timestamp = message.get('Attributes').get('ApproximateFirstReceiveTimestamp')
int.when.called_with(sent_timestamp).shouldnt.throw(ValueError)
int.when.called_with(approximate_first_receive_timestamp).shouldnt.throw(ValueError)
@mock_sqs @mock_sqs

View File

@ -47,6 +47,141 @@ def test_put_parameter():
response['Parameters'][0]['Type'].should.equal('String') response['Parameters'][0]['Type'].should.equal('String')
@mock_ssm
def test_describe_parameters():
client = boto3.client('ssm', region_name='us-east-1')
client.put_parameter(
Name='test',
Description='A test parameter',
Value='value',
Type='String')
response = client.describe_parameters()
len(response['Parameters']).should.equal(1)
response['Parameters'][0]['Name'].should.equal('test')
response['Parameters'][0]['Type'].should.equal('String')
@mock_ssm
def test_describe_parameters_paging():
client = boto3.client('ssm', region_name='us-east-1')
for i in range(50):
client.put_parameter(
Name="param-%d" % i,
Value="value-%d" % i,
Type="String"
)
response = client.describe_parameters()
len(response['Parameters']).should.equal(10)
response['NextToken'].should.equal('10')
response = client.describe_parameters(NextToken=response['NextToken'])
len(response['Parameters']).should.equal(10)
response['NextToken'].should.equal('20')
response = client.describe_parameters(NextToken=response['NextToken'])
len(response['Parameters']).should.equal(10)
response['NextToken'].should.equal('30')
response = client.describe_parameters(NextToken=response['NextToken'])
len(response['Parameters']).should.equal(10)
response['NextToken'].should.equal('40')
response = client.describe_parameters(NextToken=response['NextToken'])
len(response['Parameters']).should.equal(10)
response['NextToken'].should.equal('50')
response = client.describe_parameters(NextToken=response['NextToken'])
len(response['Parameters']).should.equal(0)
''.should.equal(response.get('NextToken', ''))
@mock_ssm
def test_describe_parameters_filter_names():
client = boto3.client('ssm', region_name='us-east-1')
for i in range(50):
p = {
'Name': "param-%d" % i,
'Value': "value-%d" % i,
'Type': "String"
}
if i % 5 == 0:
p['Type'] = 'SecureString'
p['KeyId'] = 'a key'
client.put_parameter(**p)
response = client.describe_parameters(Filters=[
{
'Key': 'Name',
'Values': ['param-22']
},
])
len(response['Parameters']).should.equal(1)
response['Parameters'][0]['Name'].should.equal('param-22')
response['Parameters'][0]['Type'].should.equal('String')
''.should.equal(response.get('NextToken', ''))
@mock_ssm
def test_describe_parameters_filter_type():
client = boto3.client('ssm', region_name='us-east-1')
for i in range(50):
p = {
'Name': "param-%d" % i,
'Value': "value-%d" % i,
'Type': "String"
}
if i % 5 == 0:
p['Type'] = 'SecureString'
p['KeyId'] = 'a key'
client.put_parameter(**p)
response = client.describe_parameters(Filters=[
{
'Key': 'Type',
'Values': ['SecureString']
},
])
len(response['Parameters']).should.equal(10)
response['Parameters'][0]['Type'].should.equal('SecureString')
'10'.should.equal(response.get('NextToken', ''))
@mock_ssm
def test_describe_parameters_filter_keyid():
client = boto3.client('ssm', region_name='us-east-1')
for i in range(50):
p = {
'Name': "param-%d" % i,
'Value': "value-%d" % i,
'Type': "String"
}
if i % 5 == 0:
p['Type'] = 'SecureString'
p['KeyId'] = "key:%d" % i
client.put_parameter(**p)
response = client.describe_parameters(Filters=[
{
'Key': 'KeyId',
'Values': ['key:10']
},
])
len(response['Parameters']).should.equal(1)
response['Parameters'][0]['Name'].should.equal('param-10')
response['Parameters'][0]['Type'].should.equal('SecureString')
''.should.equal(response.get('NextToken', ''))
@mock_ssm @mock_ssm
def test_put_parameter_secure_default_kms(): def test_put_parameter_secure_default_kms():
client = boto3.client('ssm', region_name='us-east-1') client = boto3.client('ssm', region_name='us-east-1')