diff --git a/AUTHORS.md b/AUTHORS.md index 71bc6319e..e5a5dcc79 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -37,3 +37,5 @@ Moto is written by Steve Pulec with contributions from: * [Mike Fuller](https://github.com/mfulleratlassian) * [Andy](https://github.com/aaltepet) * [Mike Grima](https://github.com/mikegrima) +* [Marco Rucci](https://github.com/mrucci) +* [Zack Kourouma](https://github.com/zkourouma) diff --git a/moto/autoscaling/models.py b/moto/autoscaling/models.py index edd580dea..2be241554 100644 --- a/moto/autoscaling/models.py +++ b/moto/autoscaling/models.py @@ -186,7 +186,7 @@ class FakeAutoScalingGroup(object): if self.desired_capacity > curr_instance_count: # Need more instances - count_needed = self.desired_capacity - curr_instance_count + count_needed = int(self.desired_capacity) - int(curr_instance_count) reservation = self.autoscaling_backend.ec2_backend.add_instances( self.launch_config.image_id, count_needed, diff --git a/moto/core/models.py b/moto/core/models.py index 495c9a382..be25321c7 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -29,10 +29,11 @@ class MockAWS(object): def __exit__(self, *args): self.stop() - def start(self): + def start(self, reset=True): self.__class__.nested_count += 1 - for backend in self.backends.values(): - backend.reset() + if reset: + for backend in self.backends.values(): + backend.reset() if not HTTPretty.is_enabled(): HTTPretty.enable() diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index f24d7398d..56a8fb4c0 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -122,7 +122,7 @@ class Item(object): class Table(object): - def __init__(self, table_name, schema=None, attr=None, throughput=None, indexes=None): + def __init__(self, table_name, schema=None, attr=None, throughput=None, indexes=None, global_indexes=None): self.name = table_name self.attr = attr self.schema = schema @@ -143,6 +143,7 @@ class Table(object): self.throughput = throughput self.throughput["NumberOfDecreasesToday"] = 0 self.indexes = indexes + self.global_indexes = global_indexes if global_indexes else [] self.created_at = datetime.datetime.now() self.items = defaultdict(dict) @@ -158,6 +159,7 @@ class Table(object): 'KeySchema': self.schema, 'ItemCount': len(self), 'CreationDateTime': unix_time(self.created_at), + 'GlobalSecondaryIndexes': [index for index in self.global_indexes], } } return results @@ -171,6 +173,24 @@ class Table(object): count += 1 return count + @property + def hash_key_names(self): + keys = [self.hash_key_attr] + for index in self.global_indexes: + for key in index['KeySchema']: + if key['KeyType'] == 'HASH': + keys.append(key['AttributeName']) + return keys + + @property + def range_key_names(self): + keys = [self.range_key_attr] + for index in self.global_indexes: + for key in index['KeySchema']: + if key['KeyType'] == 'RANGE': + keys.append(key['AttributeName']) + return keys + def put_item(self, item_attrs): hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) if self.has_range_key: @@ -268,6 +288,16 @@ class Table(object): results.append(result) return results, scanned_count, last_page + def lookup(self, *args, **kwargs): + if not self.schema: + self.describe() + for x, arg in enumerate(args): + kwargs[self.schema[x].name] = arg + ret = self.get_item(**kwargs) + if not ret.keys(): + return None + return ret + class DynamoDBBackend(BaseBackend): @@ -293,12 +323,21 @@ class DynamoDBBackend(BaseBackend): return None return table.put_item(item_attrs) - def get_table_keys_name(self, table_name): + def get_table_keys_name(self, table_name, keys): + """ + Given a set of keys, extracts the key and range key + """ table = self.tables.get(table_name) if not table: return None, None else: - return table.hash_key_attr, table.range_key_attr + hash_key = range_key = None + for key in keys: + if key in table.hash_key_names: + hash_key = key + elif key in table.range_key_names: + range_key = key + return hash_key, range_key def get_keys_value(self, table, keys): if table.hash_key_attr not in keys or (table.has_range_key and table.range_key_attr not in keys): diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 4cc064bf6..8cee08ebe 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -99,10 +99,12 @@ class DynamoHandler(BaseResponse): # getting attribute definition attr = body["AttributeDefinitions"] # getting the indexes + global_indexes = body.get("GlobalSecondaryIndexes", []) table = dynamodb_backend2.create_table(table_name, schema=key_schema, throughput=throughput, - attr=attr) + attr=attr, + global_indexes=global_indexes) return dynamo_json_dump(table.describe) def delete_table(self): @@ -216,13 +218,14 @@ class DynamoHandler(BaseResponse): def query(self): name = self.body['TableName'] - keys = self.body['KeyConditions'] - hash_key_name, range_key_name = dynamodb_backend2.get_table_keys_name(name) + key_conditions = self.body['KeyConditions'] + hash_key_name, range_key_name = dynamodb_backend2.get_table_keys_name(name, key_conditions.keys()) + # hash_key_name, range_key_name = dynamodb_backend2.get_table_keys_name(name) if hash_key_name is None: er = "'com.amazonaws.dynamodb.v20120810#ResourceNotFoundException" return self.error(er) - hash_key = keys[hash_key_name]['AttributeValueList'][0] - if len(keys) == 1: + hash_key = key_conditions[hash_key_name]['AttributeValueList'][0] + if len(key_conditions) == 1: range_comparison = None range_values = [] else: @@ -230,7 +233,7 @@ class DynamoHandler(BaseResponse): er = "com.amazon.coral.validate#ValidationException" return self.error(er) else: - range_condition = keys[range_key_name] + range_condition = key_conditions[range_key_name] if range_condition: range_comparison = range_condition['ComparisonOperator'] range_values = range_condition['AttributeValueList'] diff --git a/moto/ec2/models.py b/moto/ec2/models.py index c5d3c256c..cdb891487 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -1149,6 +1149,10 @@ class SecurityGroupBackend(object): def __init__(self): # the key in the dict group is the vpc_id or None (non-vpc) self.groups = defaultdict(dict) + + # Create the default security group + self.create_security_group("default", "default group") + super(SecurityGroupBackend, self).__init__() def create_security_group(self, name, description, vpc_id=None, force=False): @@ -1212,11 +1216,6 @@ class SecurityGroupBackend(object): if group.name == name: return group - if name == 'default': - # If the request is for the default group and it does not exist, create it - default_group = self.create_security_group("default", "The default security group", vpc_id=vpc_id, force=True) - return default_group - def get_security_group_by_name_or_id(self, group_name_or_id, vpc_id): # try searching by id, fallbacks to name search group = self.get_security_group_from_id(group_name_or_id) @@ -1309,7 +1308,7 @@ class SecurityGroupIngress(object): from_port = properties.get("FromPort") source_security_group_id = properties.get("SourceSecurityGroupId") source_security_group_name = properties.get("SourceSecurityGroupName") - source_security_owner_id = properties.get("SourceSecurityGroupOwnerId") # IGNORED AT THE MOMENT + # source_security_owner_id = properties.get("SourceSecurityGroupOwnerId") # IGNORED AT THE MOMENT to_port = properties.get("ToPort") assert group_id or group_name @@ -1329,7 +1328,6 @@ class SecurityGroupIngress(object): else: ip_ranges = [] - if group_id: security_group = ec2_backend.describe_security_groups(group_ids=[group_id])[0] else: @@ -1697,41 +1695,66 @@ class VPCPeeringConnectionBackend(object): class Subnet(TaggedEC2Resource): - def __init__(self, ec2_backend, subnet_id, vpc_id, cidr_block): + def __init__(self, ec2_backend, subnet_id, vpc_id, cidr_block, availability_zone): self.ec2_backend = ec2_backend self.id = subnet_id self.vpc_id = vpc_id self.cidr_block = cidr_block + self._availability_zone = availability_zone @classmethod def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): properties = cloudformation_json['Properties'] vpc_id = properties['VpcId'] + cidr_block = properties['CidrBlock'] + availability_zone = properties.get('AvailabilityZone') ec2_backend = ec2_backends[region_name] subnet = ec2_backend.create_subnet( vpc_id=vpc_id, - cidr_block=properties['CidrBlock'] + cidr_block=cidr_block, + availability_zone=availability_zone, ) return subnet @property def availability_zone(self): - # This could probably be smarter, but there doesn't appear to be a - # way to pull AZs for a region in boto - return self.ec2_backend.region_name + "a" + if self._availability_zone is None: + # This could probably be smarter, but there doesn't appear to be a + # way to pull AZs for a region in boto + return self.ec2_backend.region_name + "a" + else: + return self._availability_zone @property def physical_resource_id(self): return self.id def get_filter_value(self, filter_name): + """ + API Version 2014-10-01 defines the following filters for DescribeSubnets: + + * availabilityZone + * available-ip-address-count + * cidrBlock + * defaultForAz + * state + * subnet-id + * tag:key=value + * tag-key + * tag-value + * vpc-id + + Taken from: http://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeSubnets.html + """ if filter_name in ['cidr', 'cidrBlock', 'cidr-block']: return self.cidr_block elif filter_name == 'vpc-id': return self.vpc_id elif filter_name == 'subnet-id': return self.id + elif filter_name == 'availabilityZone': + return self.availability_zone filter_value = super(Subnet, self).get_filter_value(filter_name) @@ -1758,9 +1781,9 @@ class SubnetBackend(object): raise InvalidSubnetIdError(subnet_id) return subnet - def create_subnet(self, vpc_id, cidr_block): + def create_subnet(self, vpc_id, cidr_block, availability_zone=None): subnet_id = random_subnet_id() - subnet = Subnet(self, subnet_id, vpc_id, cidr_block) + subnet = Subnet(self, subnet_id, vpc_id, cidr_block, availability_zone) self.get_vpc(vpc_id) # Validate VPC exists # AWS associates a new subnet with the default Network ACL diff --git a/moto/ec2/responses/elastic_block_store.py b/moto/ec2/responses/elastic_block_store.py index ad0857ea1..96586a9bb 100644 --- a/moto/ec2/responses/elastic_block_store.py +++ b/moto/ec2/responses/elastic_block_store.py @@ -103,7 +103,7 @@ CREATE_VOLUME_RESPONSE = """ @@ -166,7 +166,7 @@ DETATCH_VOLUME_RESPONSE = """ @@ -174,7 +174,7 @@ CREATE_SNAPSHOT_RESPONSE = """ {{ nic.id }} - {{ nic.subnet.id }} - {{ nic.subnet.vpc_id }} + {% if nic.subnet %} + {{ nic.subnet.id }} + {{ nic.subnet.vpc_id }} + {% endif %} Primary network interface 111122223333 in-use diff --git a/moto/ec2/responses/subnets.py b/moto/ec2/responses/subnets.py index 9f0808648..a0798a615 100644 --- a/moto/ec2/responses/subnets.py +++ b/moto/ec2/responses/subnets.py @@ -7,7 +7,15 @@ class Subnets(BaseResponse): def create_subnet(self): vpc_id = self.querystring.get('VpcId')[0] cidr_block = self.querystring.get('CidrBlock')[0] - subnet = self.ec2_backend.create_subnet(vpc_id, cidr_block) + if 'AvailabilityZone' in self.querystring: + availability_zone = self.querystring['AvailabilityZone'][0] + else: + availability_zone = None + subnet = self.ec2_backend.create_subnet( + vpc_id, + cidr_block, + availability_zone, + ) template = self.response_template(CREATE_SUBNET_RESPONSE) return template.render(subnet=subnet) @@ -33,7 +41,7 @@ CREATE_SUBNET_RESPONSE = """ {{ subnet.vpc_id }} {{ subnet.cidr_block }} 251 - us-east-1a + {{ subnet.availability_zone }} {% for tag in subnet.get_tags() %} @@ -64,7 +72,7 @@ DESCRIBE_SUBNETS_RESPONSE = """ {{ subnet.vpc_id }} {{ subnet.cidr_block }} 251 - us-east-1a + {{ subnet.availability_zone }} {% for tag in subnet.get_tags() %} diff --git a/moto/elb/models.py b/moto/elb/models.py index 94914a8e7..e1487f5aa 100644 --- a/moto/elb/models.py +++ b/moto/elb/models.py @@ -1,6 +1,13 @@ from __future__ import unicode_literals import boto.ec2.elb +from boto.ec2.elb.attributes import ( + LbAttributes, + ConnectionSettingAttribute, + ConnectionDrainingAttribute, + AccessLogAttribute, + CrossZoneLoadBalancingAttribute, +) from moto.core import BaseBackend @@ -29,6 +36,7 @@ class FakeLoadBalancer(object): self.instance_ids = [] self.zones = zones self.listeners = [] + self.attributes = FakeLoadBalancer.get_default_attributes() for protocol, lb_port, instance_port, ssl_certificate_id in ports: listener = FakeListener( @@ -73,6 +81,28 @@ class FakeLoadBalancer(object): raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.OwnerAlias" ]"') raise UnformattedGetAttTemplateException() + @classmethod + def get_default_attributes(cls): + attributes = LbAttributes() + + cross_zone_load_balancing = CrossZoneLoadBalancingAttribute() + cross_zone_load_balancing.enabled = False + attributes.cross_zone_load_balancing = cross_zone_load_balancing + + connection_draining = ConnectionDrainingAttribute() + connection_draining.enabled = False + attributes.connection_draining = connection_draining + + access_log = AccessLogAttribute() + access_log.enabled = False + attributes.access_log = access_log + + connection_settings = ConnectionSettingAttribute() + connection_settings.idle_timeout = 60 + attributes.connecting_settings = connection_settings + + return attributes + class ELBBackend(BaseBackend): @@ -151,6 +181,26 @@ class ELBBackend(BaseBackend): load_balancer.instance_ids = new_instance_ids return load_balancer + def set_cross_zone_load_balancing_attribute(self, load_balancer_name, attribute): + load_balancer = self.get_load_balancer(load_balancer_name) + load_balancer.attributes.cross_zone_load_balancing = attribute + return load_balancer + + def set_access_log_attribute(self, load_balancer_name, attribute): + load_balancer = self.get_load_balancer(load_balancer_name) + load_balancer.attributes.access_log = attribute + return load_balancer + + def set_connection_draining_attribute(self, load_balancer_name, attribute): + load_balancer = self.get_load_balancer(load_balancer_name) + load_balancer.attributes.connection_draining = attribute + return load_balancer + + def set_connection_settings_attribute(self, load_balancer_name, attribute): + load_balancer = self.get_load_balancer(load_balancer_name) + load_balancer.attributes.connecting_settings = attribute + return load_balancer + elb_backends = {} for region in boto.ec2.elb.regions(): diff --git a/moto/elb/responses.py b/moto/elb/responses.py index fb114fb22..d33a78fc8 100644 --- a/moto/elb/responses.py +++ b/moto/elb/responses.py @@ -1,4 +1,10 @@ from __future__ import unicode_literals +from boto.ec2.elb.attributes import ( + ConnectionSettingAttribute, + ConnectionDrainingAttribute, + AccessLogAttribute, + CrossZoneLoadBalancingAttribute, +) from moto.core.responses import BaseResponse from .models import elb_backends @@ -25,7 +31,7 @@ class ELBResponse(BaseResponse): break lb_port = self.querystring['Listeners.member.{0}.LoadBalancerPort'.format(port_index)][0] instance_port = self.querystring['Listeners.member.{0}.InstancePort'.format(port_index)][0] - ssl_certificate_id = self.querystring.get('Listeners.member.{0}.SSLCertificateId'.format(port_index)[0], None) + ssl_certificate_id = self.querystring.get('Listeners.member.{0}.SSLCertificateId'.format(port_index), [None])[0] ports.append([protocol, lb_port, instance_port, ssl_certificate_id]) port_index += 1 @@ -122,6 +128,64 @@ class ELBResponse(BaseResponse): load_balancer = self.elb_backend.deregister_instances(load_balancer_name, instance_ids) return template.render(load_balancer=load_balancer) + def describe_load_balancer_attributes(self): + load_balancer_name = self.querystring.get('LoadBalancerName')[0] + load_balancer = self.elb_backend.describe_load_balancers(load_balancer_name)[0] + template = self.response_template(DESCRIBE_ATTRIBUTES_TEMPLATE) + return template.render(attributes=load_balancer.attributes) + + def modify_load_balancer_attributes(self): + load_balancer_name = self.querystring.get('LoadBalancerName')[0] + load_balancer = self.elb_backend.describe_load_balancers(load_balancer_name)[0] + + def parse_attribute(attribute_name): + """ + Transform self.querystring parameters matching `LoadBalancerAttributes.attribute_name.attribute_key` + into a dictionary of (attribute_name, attribute_key)` pairs. + """ + attribute_prefix = "LoadBalancerAttributes." + attribute_name + return dict((key.lstrip(attribute_prefix), value[0]) for key, value in self.querystring.items() if key.startswith(attribute_prefix)) + + cross_zone = parse_attribute("CrossZoneLoadBalancing") + if cross_zone: + attribute = CrossZoneLoadBalancingAttribute() + attribute.enabled = cross_zone["Enabled"] == "true" + self.elb_backend.set_cross_zone_load_balancing_attribute(load_balancer_name, attribute) + + access_log = parse_attribute("AccessLog") + if access_log: + attribute = AccessLogAttribute() + attribute.enabled = access_log["Enabled"] == "true" + attribute.s3_bucket_name = access_log["S3BucketName"] + attribute.s3_bucket_prefix = access_log["S3BucketPrefix"] + attribute.emit_interval = access_log["EmitInterval"] + self.elb_backend.set_access_log_attribute(load_balancer_name, attribute) + + connection_draining = parse_attribute("ConnectionDraining") + if connection_draining: + attribute = ConnectionDrainingAttribute() + attribute.enabled = connection_draining["Enabled"] == "true" + attribute.timeout = connection_draining["Timeout"] + self.elb_backend.set_connection_draining_attribute(load_balancer_name, attribute) + + connection_settings = parse_attribute("ConnectionSettings") + if connection_settings: + attribute = ConnectionSettingAttribute() + attribute.idle_timeout = connection_settings["IdleTimeout"] + self.elb_backend.set_connection_settings_attribute(load_balancer_name, attribute) + + template = self.response_template(MODIFY_ATTRIBUTES_TEMPLATE) + return template.render(attributes=load_balancer.attributes) + + def describe_instance_health(self): + load_balancer_name = self.querystring.get('LoadBalancerName')[0] + instance_ids = [value[0] for key, value in self.querystring.items() if "Instances.member" in key] + if len(instance_ids) == 0: + instance_ids = self.elb_backend.describe_load_balancers(load_balancer_name)[0].instance_ids + template = self.response_template(DESCRIBE_INSTANCE_HEALTH_TEMPLATE) + return template.render(instance_ids=instance_ids) + + CREATE_LOAD_BALANCER_TEMPLATE = """ tests.us-east-1.elb.amazonaws.com """ @@ -253,3 +317,84 @@ DELETE_LOAD_BALANCER_LISTENERS = """ + + + + {{ attributes.access_log.enabled }} + {% if attributes.access_log.enabled %} + {{ attributes.access_log.s3_bucket_name }} + {{ attributes.access_log.s3_bucket_prefix }} + {{ attributes.access_log.emit_interval }} + {% endif %} + + + {{ attributes.connecting_settings.idle_timeout }} + + + {{ attributes.cross_zone_load_balancing.enabled }} + + + {{ attributes.connection_draining.enabled }} + {% if attributes.connection_draining.enabled %} + {{ attributes.connection_draining.timeout }} + {% endif %} + + + + + 83c88b9d-12b7-11e3-8b82-87b12EXAMPLE + + +""" + +MODIFY_ATTRIBUTES_TEMPLATE = """ + + my-loadbalancer + + + {{ attributes.access_log.enabled }} + {% if attributes.access_log.enabled %} + {{ attributes.access_log.s3_bucket_name }} + {{ attributes.access_log.s3_bucket_prefix }} + {{ attributes.access_log.emit_interval }} + {% endif %} + + + {{ attributes.connecting_settings.idle_timeout }} + + + {{ attributes.cross_zone_load_balancing.enabled }} + + + {{ attributes.connection_draining.enabled }} + {% if attributes.connection_draining.enabled %} + {{ attributes.connection_draining.timeout }} + {% endif %} + + + + + 83c88b9d-12b7-11e3-8b82-87b12EXAMPLE + + +""" + +DESCRIBE_INSTANCE_HEALTH_TEMPLATE = """ + + + {% for instance_id in instance_ids %} + + N/A + {{ instance_id }} + InService + N/A + + {% endfor %} + + + + 1549581b-12b7-11e3-895e-1334aEXAMPLE + +""" diff --git a/moto/iam/models.py b/moto/iam/models.py index 388984f51..2e9970785 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -6,6 +6,7 @@ from .utils import random_access_key, random_alphanumeric, random_resource_id from datetime import datetime import base64 + class Role(object): def __init__(self, role_id, name, assume_role_policy_document, path): @@ -212,16 +213,16 @@ class User(object): access_key_2_last_rotated = date_created.strftime(date_format) return '{0},{1},{2},{3},{4},{5},not_supported,false,{6},{7},{8},{9},false,N/A,false,N/A'.format(self.name, - self.arn, - date_created.strftime(date_format), - password_enabled, - password_last_used, - date_created.strftime(date_format), - access_key_1_active, - access_key_1_last_rotated, - access_key_2_active, - access_key_2_last_rotated - ) + self.arn, + date_created.strftime(date_format), + password_enabled, + password_last_used, + date_created.strftime(date_format), + access_key_1_active, + access_key_1_last_rotated, + access_key_2_active, + access_key_2_last_rotated + ) class IAMBackend(BaseBackend): @@ -337,6 +338,18 @@ class IAMBackend(BaseBackend): return group + def list_groups(self): + return self.groups.values() + + def get_groups_for_user(self, user_name): + user = self.get_user(user_name) + groups = [] + for group in self.list_groups(): + if user in group.users: + groups.append(group) + + return groups + def create_user(self, user_name, path='/'): if user_name in self.users: raise BotoServerError(409, 'Conflict') diff --git a/moto/iam/responses.py b/moto/iam/responses.py index 4ebfb74ec..be1601a83 100644 --- a/moto/iam/responses.py +++ b/moto/iam/responses.py @@ -131,6 +131,18 @@ class IamResponse(BaseResponse): template = self.response_template(GET_GROUP_TEMPLATE) return template.render(group=group) + def list_groups(self): + groups = iam_backend.list_groups() + template = self.response_template(LIST_GROUPS_TEMPLATE) + return template.render(groups=groups) + + def list_groups_for_user(self): + user_name = self._get_param('UserName') + + groups = iam_backend.get_groups_for_user(user_name) + template = self.response_template(LIST_GROUPS_FOR_USER_TEMPLATE) + return template.render(groups=groups) + def create_user(self): user_name = self._get_param('UserName') path = self._get_param('Path') @@ -502,6 +514,45 @@ GET_GROUP_TEMPLATE = """ """ +LIST_GROUPS_TEMPLATE = """ + + + {% for group in groups %} + + {{ group.path }} + {{ group.name }} + {{ group.id }} + arn:aws:iam::123456789012:group/{{ group.path }} + + {% endfor %} + + false + + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + +LIST_GROUPS_FOR_USER_TEMPLATE = """ + + + {% for group in groups %} + + {{ group.path }} + {{ group.name }} + {{ group.id }} + arn:aws:iam::123456789012:group/{{ group.path }} + + {% endfor %} + + false + + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + + USER_TEMPLATE = """<{{ action }}UserResponse> <{{ action }}UserResult> @@ -640,4 +691,4 @@ LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE = """ 6a8c3992-99f4-11e1-a4c3-27EXAMPLE804 -""" \ No newline at end of file +""" diff --git a/moto/route53/models.py b/moto/route53/models.py index 58c559f25..00e23c38e 100644 --- a/moto/route53/models.py +++ b/moto/route53/models.py @@ -106,9 +106,10 @@ class RecordSet(object): class FakeZone(object): - def __init__(self, name, id_): + def __init__(self, name, id_, comment=None): self.name = name self.id = id_ + self.comment = comment self.rrsets = [] def add_rrset(self, record_set): @@ -116,9 +117,12 @@ class FakeZone(object): self.rrsets.append(record_set) return record_set - def delete_rrset(self, name): + def delete_rrset_by_name(self, name): self.rrsets = [record_set for record_set in self.rrsets if record_set.name != name] + def delete_rrset_by_id(self, set_identifier): + self.rrsets = [record_set for record_set in self.rrsets if record_set.set_identifier != set_identifier] + def get_record_sets(self, type_filter, name_filter): record_sets = list(self.rrsets) # Copy the list if type_filter: @@ -170,9 +174,9 @@ class Route53Backend(BaseBackend): self.zones = {} self.health_checks = {} - def create_hosted_zone(self, name): + def create_hosted_zone(self, name, comment=None): new_id = get_random_hex() - new_zone = FakeZone(name, new_id) + new_zone = FakeZone(name, new_id, comment=comment) self.zones[new_id] = new_zone return new_zone diff --git a/moto/route53/responses.py b/moto/route53/responses.py index 5bbb8f451..3cd848607 100644 --- a/moto/route53/responses.py +++ b/moto/route53/responses.py @@ -9,7 +9,8 @@ def list_or_create_hostzone_response(request, full_url, headers): if request.method == "POST": elements = xmltodict.parse(request.body) - new_zone = route53_backend.create_hosted_zone(elements["CreateHostedZoneRequest"]["Name"]) + comment = elements["CreateHostedZoneRequest"]["HostedZoneConfig"]["Comment"] + new_zone = route53_backend.create_hosted_zone(elements["CreateHostedZoneRequest"]["Name"], comment=comment) template = Template(CREATE_HOSTED_ZONE_RESPONSE) return 201, headers, template.render(zone=new_zone) @@ -57,7 +58,10 @@ def rrset_response(request, full_url, headers): record_set['ResourceRecords'] = [x['Value'] for x in record_set['ResourceRecords'].values()] the_zone.add_rrset(record_set) elif action == "DELETE": - the_zone.delete_rrset(record_set["Name"]) + if 'SetIdentifier' in record_set: + the_zone.delete_rrset_by_id(record_set["SetIdentifier"]) + else: + the_zone.delete_rrset_by_name(record_set["Name"]) return 200, headers, CHANGE_RRSET_RESPONSE @@ -125,6 +129,9 @@ GET_HOSTED_ZONE_RESPONSE = """