extended the list/add/delete tags functions to support more resource types.

This commit is contained in:
Mike Fuller 2015-01-29 07:15:03 +11:00
parent 15fcec9c33
commit e42046aeda
3 changed files with 113 additions and 20 deletions

View File

@ -265,6 +265,7 @@ class SecurityGroup(object):
self.status = "authorized" self.status = "authorized"
self.ip_ranges = [] self.ip_ranges = []
self.ec2_security_groups = [] self.ec2_security_groups = []
self.tags = []
def to_xml(self): def to_xml(self):
template = Template("""<DBSecurityGroup> template = Template("""<DBSecurityGroup>
@ -323,6 +324,18 @@ class SecurityGroup(object):
security_group.authorize_security_group(subnet) security_group.authorize_security_group(subnet)
return security_group return security_group
def get_tags(self):
return self.tags
def add_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in new_keys]
self.tags.extend(tags)
return self.tags
def remove_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in tag_keys]
class SubnetGroup(object): class SubnetGroup(object):
def __init__(self, subnet_name, description, subnets): def __init__(self, subnet_name, description, subnets):
@ -330,7 +343,7 @@ class SubnetGroup(object):
self.description = description self.description = description
self.subnets = subnets self.subnets = subnets
self.status = "Complete" self.status = "Complete"
self.tags = []
self.vpc_id = self.subnets[0].vpc_id self.vpc_id = self.subnets[0].vpc_id
def to_xml(self): def to_xml(self):
@ -395,6 +408,18 @@ class SubnetGroup(object):
) )
return subnet_group return subnet_group
def get_tags(self):
return self.tags
def add_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in new_keys]
self.tags.extend(tags)
return self.tags
def remove_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in tag_keys]
class RDS2Backend(BaseBackend): class RDS2Backend(BaseBackend):
@ -602,21 +627,58 @@ class RDS2Backend(BaseBackend):
def list_tags_for_resource(self, arn): def list_tags_for_resource(self, arn):
if self.arn_regex.match(arn): if self.arn_regex.match(arn):
arn_breakdown = arn.split(':') arn_breakdown = arn.split(':')
db_instance_name = arn_breakdown[len(arn_breakdown)-1] resource_type = arn_breakdown[len(arn_breakdown)-2]
if db_instance_name in self.databases: resource_name = arn_breakdown[len(arn_breakdown)-1]
return self.databases[db_instance_name].get_tags() if resource_type == 'db': # Database
else: if resource_name in self.databases:
return self.databases[resource_name].get_tags()
elif resource_type == 'es': # Event Subscription
return [] return []
elif resource_type == 'og': # Option Group
if resource_name in self.option_groups:
return self.option_groups[resource_name].get_tags()
elif resource_type == 'pg': # Parameter Group
return []
elif resource_type == 'ri': # Reserved DB instance
return []
elif resource_type == 'secgrp': # DB security group
if resource_type in self.security_groups:
return self.security_groups[resource_name].get_tags()
elif resource_type == 'snapshot': # DB Snapshot
return []
elif resource_type == 'subgrp': # DB subnet group
if resource_type in self.subnet_groups:
return self.subnet_groups[resource_name].get_tags()
else: else:
raise RDSClientError('InvalidParameterValue', raise RDSClientError('InvalidParameterValue',
'Invalid resource name: {}'.format(arn)) 'Invalid resource name: {}'.format(arn))
return []
def remove_tags_from_resource(self, arn, tag_keys): def remove_tags_from_resource(self, arn, tag_keys):
if self.arn_regex.match(arn): if self.arn_regex.match(arn):
arn_breakdown = arn.split(':') arn_breakdown = arn.split(':')
db_instance_name = arn_breakdown[len(arn_breakdown)-1] resource_type = arn_breakdown[len(arn_breakdown)-2]
if db_instance_name in self.databases: resource_name = arn_breakdown[len(arn_breakdown)-1]
self.databases[db_instance_name].remove_tags(tag_keys) if resource_type == 'db': # Database
if resource_name in self.databases:
self.databases[resource_name].remove_tags(tag_keys)
elif resource_type == 'es': # Event Subscription
return None
elif resource_type == 'og': # Option Group
if resource_name in self.option_groups:
return self.option_groups[resource_name].remove_tags(tag_keys)
elif resource_type == 'pg': # Parameter Group
return None
elif resource_type == 'ri': # Reserved DB instance
return None
elif resource_type == 'secgrp': # DB security group
if resource_type in self.security_groups:
return self.security_groups[resource_name].remove_tags(tag_keys)
elif resource_type == 'snapshot': # DB Snapshot
return None
elif resource_type == 'subgrp': # DB subnet group
if resource_type in self.subnet_groups:
return self.subnet_groups[resource_name].remove_tags(tag_keys)
else: else:
raise RDSClientError('InvalidParameterValue', raise RDSClientError('InvalidParameterValue',
'Invalid resource name: {}'.format(arn)) 'Invalid resource name: {}'.format(arn))
@ -624,15 +686,33 @@ class RDS2Backend(BaseBackend):
def add_tags_to_resource(self, arn, tags): def add_tags_to_resource(self, arn, tags):
if self.arn_regex.match(arn): if self.arn_regex.match(arn):
arn_breakdown = arn.split(':') arn_breakdown = arn.split(':')
db_instance_name = arn_breakdown[len(arn_breakdown)-1] resource_type = arn_breakdown[len(arn_breakdown)-2]
if db_instance_name in self.databases: resource_name = arn_breakdown[len(arn_breakdown)-1]
return self.databases[db_instance_name].add_tags(tags) if resource_type == 'db': # Database
else: if resource_name in self.databases:
return self.databases[resource_name].add_tags(tags)
elif resource_type == 'es': # Event Subscription
return [] return []
elif resource_type == 'og': # Option Group
if resource_name in self.option_groups:
return self.option_groups[resource_name].add_tags(tags)
elif resource_type == 'pg': # Parameter Group
return []
elif resource_type == 'ri': # Reserved DB instance
return []
elif resource_type == 'secgrp': # DB security group
if resource_type in self.security_groups:
return self.security_groups[resource_name].add_tags(tags)
elif resource_type == 'snapshot': # DB Snapshot
return []
elif resource_type == 'subgrp': # DB subnet group
if resource_type in self.subnet_groups:
return self.subnet_groups[resource_name].add_tags(tags)
else: else:
raise RDSClientError('InvalidParameterValue', raise RDSClientError('InvalidParameterValue',
'Invalid resource name: {}'.format(arn)) 'Invalid resource name: {}'.format(arn))
class OptionGroup(object): class OptionGroup(object):
def __init__(self, name, engine_name, major_engine_version, description=None): def __init__(self, name, engine_name, major_engine_version, description=None):
self.engine_name = engine_name self.engine_name = engine_name
@ -642,6 +722,7 @@ class OptionGroup(object):
self.vpc_and_non_vpc_instance_memberships = False self.vpc_and_non_vpc_instance_memberships = False
self.options = {} self.options = {}
self.vpcId = 'null' self.vpcId = 'null'
self.tags = []
def to_json(self): def to_json(self):
template = Template("""{ template = Template("""{
@ -663,6 +744,18 @@ class OptionGroup(object):
# TODO: Validate option and add it to self.options. If invalid raise error # TODO: Validate option and add it to self.options. If invalid raise error
return return
def get_tags(self):
return self.tags
def add_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in new_keys]
self.tags.extend(tags)
return self.tags
def remove_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in tag_keys]
class OptionGroupOption(object): class OptionGroupOption(object):
def __init__(self, engine_name, major_engine_version): def __init__(self, engine_name, major_engine_version):
@ -687,6 +780,7 @@ class OptionGroupOption(object):
}""") }""")
return template.render(option_group=self) return template.render(option_group=self)
rds2_backends = {} rds2_backends = {}
for region in boto.rds2.regions(): for region in boto.rds2.regions():
rds2_backends[region.name] = RDS2Backend() rds2_backends[region.name] = RDS2Backend()

View File

@ -155,7 +155,6 @@ class RDS2Response(BaseResponse):
template = self.response_template(ADD_TAGS_TO_RESOURCE_TEMPLATE) template = self.response_template(ADD_TAGS_TO_RESOURCE_TEMPLATE)
return template.render(tags=tags) return template.render(tags=tags)
def remove_tags_from_resource(self): def remove_tags_from_resource(self):
arn = self._get_param('ResourceName') arn = self._get_param('ResourceName')
tag_keys = self.unpack_list_params('TagKeys.member') tag_keys = self.unpack_list_params('TagKeys.member')

View File

@ -273,7 +273,7 @@ def test_list_tags_invalid_arn():
@disable_on_py3() @disable_on_py3()
@mock_rds2 @mock_rds2
def test_list_tags(): def test_list_tags_db():
conn = boto.rds2.connect_to_region("us-west-2") conn = boto.rds2.connect_to_region("us-west-2")
result = conn.list_tags_for_resource('arn:aws:rds:us-west-2:1234567890:db:foo') result = conn.list_tags_for_resource('arn:aws:rds:us-west-2:1234567890:db:foo')
result['ListTagsForResourceResponse']['ListTagsForResourceResult']['TagList'].should.equal([]) result['ListTagsForResourceResponse']['ListTagsForResourceResult']['TagList'].should.equal([])
@ -294,7 +294,7 @@ def test_list_tags():
@disable_on_py3() @disable_on_py3()
@mock_rds2 @mock_rds2
def test_add_tags(): def test_add_tags_db():
conn = boto.rds2.connect_to_region("us-west-2") conn = boto.rds2.connect_to_region("us-west-2")
conn.create_db_instance(db_instance_identifier='db-without-tags', conn.create_db_instance(db_instance_identifier='db-without-tags',
allocated_storage=10, allocated_storage=10,
@ -314,7 +314,7 @@ def test_add_tags():
@disable_on_py3() @disable_on_py3()
@mock_rds2 @mock_rds2
def test_remove_tags(): def test_remove_tags_db():
conn = boto.rds2.connect_to_region("us-west-2") conn = boto.rds2.connect_to_region("us-west-2")
conn.create_db_instance(db_instance_identifier='db-with-tags', conn.create_db_instance(db_instance_identifier='db-with-tags',
allocated_storage=10, allocated_storage=10,
@ -424,8 +424,8 @@ def test_remove_tags():
# subnet_group.name.should.equal('db_subnet') # subnet_group.name.should.equal('db_subnet')
# subnet_group.description.should.equal("my db subnet") # subnet_group.description.should.equal("my db subnet")
# list(subnet_group.subnet_ids).should.equal(subnet_ids) # list(subnet_group.subnet_ids).should.equal(subnet_ids)
#
#
@mock_ec2 @mock_ec2
@mock_rds2 @mock_rds2
def test_describe_database_subnet_group(): def test_describe_database_subnet_group():
@ -451,8 +451,8 @@ def test_describe_database_subnet_group():
list(conn.describe_db_subnet_groups("db_subnet1")).should.have.length_of(1) list(conn.describe_db_subnet_groups("db_subnet1")).should.have.length_of(1)
conn.describe_db_subnet_groups.when.called_with("not-a-subnet").should.throw(BotoServerError) conn.describe_db_subnet_groups.when.called_with("not-a-subnet").should.throw(BotoServerError)
#
#
#@mock_ec2 #@mock_ec2
#@mock_rds2 #@mock_rds2
#def test_delete_database_subnet_group(): #def test_delete_database_subnet_group():