From e42046aeda908acda36ee5e25179335b37bc9d20 Mon Sep 17 00:00:00 2001 From: Mike Fuller Date: Thu, 29 Jan 2015 07:15:03 +1100 Subject: [PATCH] extended the list/add/delete tags functions to support more resource types. --- moto/rds2/models.py | 118 +++++++++++++++++++++++++++++++---- moto/rds2/responses.py | 1 - tests/test_rds2/test_rds2.py | 14 ++--- 3 files changed, 113 insertions(+), 20 deletions(-) diff --git a/moto/rds2/models.py b/moto/rds2/models.py index bc73d16a5..c9c617f5b 100644 --- a/moto/rds2/models.py +++ b/moto/rds2/models.py @@ -265,6 +265,7 @@ class SecurityGroup(object): self.status = "authorized" self.ip_ranges = [] self.ec2_security_groups = [] + self.tags = [] def to_xml(self): template = Template(""" @@ -323,6 +324,18 @@ class SecurityGroup(object): security_group.authorize_security_group(subnet) return security_group + def get_tags(self): + return self.tags + + def add_tags(self, tags): + new_keys = [tag_set['Key'] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in new_keys] + self.tags.extend(tags) + return self.tags + + def remove_tags(self, tag_keys): + self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in tag_keys] + class SubnetGroup(object): def __init__(self, subnet_name, description, subnets): @@ -330,7 +343,7 @@ class SubnetGroup(object): self.description = description self.subnets = subnets self.status = "Complete" - + self.tags = [] self.vpc_id = self.subnets[0].vpc_id def to_xml(self): @@ -395,6 +408,18 @@ class SubnetGroup(object): ) return subnet_group + def get_tags(self): + return self.tags + + def add_tags(self, tags): + new_keys = [tag_set['Key'] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in new_keys] + self.tags.extend(tags) + return self.tags + + def remove_tags(self, tag_keys): + self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in tag_keys] + class RDS2Backend(BaseBackend): @@ -602,21 +627,58 @@ class RDS2Backend(BaseBackend): def list_tags_for_resource(self, arn): if self.arn_regex.match(arn): arn_breakdown = arn.split(':') - db_instance_name = arn_breakdown[len(arn_breakdown)-1] - if db_instance_name in self.databases: - return self.databases[db_instance_name].get_tags() - else: + resource_type = arn_breakdown[len(arn_breakdown)-2] + resource_name = arn_breakdown[len(arn_breakdown)-1] + if resource_type == 'db': # Database + if resource_name in self.databases: + return self.databases[resource_name].get_tags() + elif resource_type == 'es': # Event Subscription return [] + elif resource_type == 'og': # Option Group + if resource_name in self.option_groups: + return self.option_groups[resource_name].get_tags() + elif resource_type == 'pg': # Parameter Group + return [] + elif resource_type == 'ri': # Reserved DB instance + return [] + elif resource_type == 'secgrp': # DB security group + if resource_type in self.security_groups: + return self.security_groups[resource_name].get_tags() + elif resource_type == 'snapshot': # DB Snapshot + return [] + elif resource_type == 'subgrp': # DB subnet group + if resource_type in self.subnet_groups: + return self.subnet_groups[resource_name].get_tags() else: raise RDSClientError('InvalidParameterValue', 'Invalid resource name: {}'.format(arn)) + return [] def remove_tags_from_resource(self, arn, tag_keys): if self.arn_regex.match(arn): arn_breakdown = arn.split(':') - db_instance_name = arn_breakdown[len(arn_breakdown)-1] - if db_instance_name in self.databases: - self.databases[db_instance_name].remove_tags(tag_keys) + resource_type = arn_breakdown[len(arn_breakdown)-2] + resource_name = arn_breakdown[len(arn_breakdown)-1] + if resource_type == 'db': # Database + if resource_name in self.databases: + self.databases[resource_name].remove_tags(tag_keys) + elif resource_type == 'es': # Event Subscription + return None + elif resource_type == 'og': # Option Group + if resource_name in self.option_groups: + return self.option_groups[resource_name].remove_tags(tag_keys) + elif resource_type == 'pg': # Parameter Group + return None + elif resource_type == 'ri': # Reserved DB instance + return None + elif resource_type == 'secgrp': # DB security group + if resource_type in self.security_groups: + return self.security_groups[resource_name].remove_tags(tag_keys) + elif resource_type == 'snapshot': # DB Snapshot + return None + elif resource_type == 'subgrp': # DB subnet group + if resource_type in self.subnet_groups: + return self.subnet_groups[resource_name].remove_tags(tag_keys) else: raise RDSClientError('InvalidParameterValue', 'Invalid resource name: {}'.format(arn)) @@ -624,15 +686,33 @@ class RDS2Backend(BaseBackend): def add_tags_to_resource(self, arn, tags): if self.arn_regex.match(arn): arn_breakdown = arn.split(':') - db_instance_name = arn_breakdown[len(arn_breakdown)-1] - if db_instance_name in self.databases: - return self.databases[db_instance_name].add_tags(tags) - else: + resource_type = arn_breakdown[len(arn_breakdown)-2] + resource_name = arn_breakdown[len(arn_breakdown)-1] + if resource_type == 'db': # Database + if resource_name in self.databases: + return self.databases[resource_name].add_tags(tags) + elif resource_type == 'es': # Event Subscription return [] + elif resource_type == 'og': # Option Group + if resource_name in self.option_groups: + return self.option_groups[resource_name].add_tags(tags) + elif resource_type == 'pg': # Parameter Group + return [] + elif resource_type == 'ri': # Reserved DB instance + return [] + elif resource_type == 'secgrp': # DB security group + if resource_type in self.security_groups: + return self.security_groups[resource_name].add_tags(tags) + elif resource_type == 'snapshot': # DB Snapshot + return [] + elif resource_type == 'subgrp': # DB subnet group + if resource_type in self.subnet_groups: + return self.subnet_groups[resource_name].add_tags(tags) else: raise RDSClientError('InvalidParameterValue', 'Invalid resource name: {}'.format(arn)) + class OptionGroup(object): def __init__(self, name, engine_name, major_engine_version, description=None): self.engine_name = engine_name @@ -642,6 +722,7 @@ class OptionGroup(object): self.vpc_and_non_vpc_instance_memberships = False self.options = {} self.vpcId = 'null' + self.tags = [] def to_json(self): template = Template("""{ @@ -663,6 +744,18 @@ class OptionGroup(object): # TODO: Validate option and add it to self.options. If invalid raise error return + def get_tags(self): + return self.tags + + def add_tags(self, tags): + new_keys = [tag_set['Key'] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in new_keys] + self.tags.extend(tags) + return self.tags + + def remove_tags(self, tag_keys): + self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in tag_keys] + class OptionGroupOption(object): def __init__(self, engine_name, major_engine_version): @@ -687,6 +780,7 @@ class OptionGroupOption(object): }""") return template.render(option_group=self) + rds2_backends = {} for region in boto.rds2.regions(): rds2_backends[region.name] = RDS2Backend() diff --git a/moto/rds2/responses.py b/moto/rds2/responses.py index 21bec15d3..1428765f4 100644 --- a/moto/rds2/responses.py +++ b/moto/rds2/responses.py @@ -155,7 +155,6 @@ class RDS2Response(BaseResponse): template = self.response_template(ADD_TAGS_TO_RESOURCE_TEMPLATE) return template.render(tags=tags) - def remove_tags_from_resource(self): arn = self._get_param('ResourceName') tag_keys = self.unpack_list_params('TagKeys.member') diff --git a/tests/test_rds2/test_rds2.py b/tests/test_rds2/test_rds2.py index 07ee8aa5d..a78657c09 100644 --- a/tests/test_rds2/test_rds2.py +++ b/tests/test_rds2/test_rds2.py @@ -273,7 +273,7 @@ def test_list_tags_invalid_arn(): @disable_on_py3() @mock_rds2 -def test_list_tags(): +def test_list_tags_db(): conn = boto.rds2.connect_to_region("us-west-2") result = conn.list_tags_for_resource('arn:aws:rds:us-west-2:1234567890:db:foo') result['ListTagsForResourceResponse']['ListTagsForResourceResult']['TagList'].should.equal([]) @@ -294,7 +294,7 @@ def test_list_tags(): @disable_on_py3() @mock_rds2 -def test_add_tags(): +def test_add_tags_db(): conn = boto.rds2.connect_to_region("us-west-2") conn.create_db_instance(db_instance_identifier='db-without-tags', allocated_storage=10, @@ -314,7 +314,7 @@ def test_add_tags(): @disable_on_py3() @mock_rds2 -def test_remove_tags(): +def test_remove_tags_db(): conn = boto.rds2.connect_to_region("us-west-2") conn.create_db_instance(db_instance_identifier='db-with-tags', allocated_storage=10, @@ -424,8 +424,8 @@ def test_remove_tags(): # subnet_group.name.should.equal('db_subnet') # subnet_group.description.should.equal("my db subnet") # list(subnet_group.subnet_ids).should.equal(subnet_ids) -# -# + + @mock_ec2 @mock_rds2 def test_describe_database_subnet_group(): @@ -451,8 +451,8 @@ def test_describe_database_subnet_group(): list(conn.describe_db_subnet_groups("db_subnet1")).should.have.length_of(1) conn.describe_db_subnet_groups.when.called_with("not-a-subnet").should.throw(BotoServerError) -# -# + + #@mock_ec2 #@mock_rds2 #def test_delete_database_subnet_group():