extended the list/add/delete tags functions to support more resource types.

This commit is contained in:
Mike Fuller 2015-01-29 07:15:03 +11:00
parent 15fcec9c33
commit e42046aeda
3 changed files with 113 additions and 20 deletions

View File

@ -265,6 +265,7 @@ class SecurityGroup(object):
self.status = "authorized"
self.ip_ranges = []
self.ec2_security_groups = []
self.tags = []
def to_xml(self):
template = Template("""<DBSecurityGroup>
@ -323,6 +324,18 @@ class SecurityGroup(object):
security_group.authorize_security_group(subnet)
return security_group
def get_tags(self):
return self.tags
def add_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in new_keys]
self.tags.extend(tags)
return self.tags
def remove_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in tag_keys]
class SubnetGroup(object):
def __init__(self, subnet_name, description, subnets):
@ -330,7 +343,7 @@ class SubnetGroup(object):
self.description = description
self.subnets = subnets
self.status = "Complete"
self.tags = []
self.vpc_id = self.subnets[0].vpc_id
def to_xml(self):
@ -395,6 +408,18 @@ class SubnetGroup(object):
)
return subnet_group
def get_tags(self):
return self.tags
def add_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in new_keys]
self.tags.extend(tags)
return self.tags
def remove_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in tag_keys]
class RDS2Backend(BaseBackend):
@ -602,21 +627,58 @@ class RDS2Backend(BaseBackend):
def list_tags_for_resource(self, arn):
if self.arn_regex.match(arn):
arn_breakdown = arn.split(':')
db_instance_name = arn_breakdown[len(arn_breakdown)-1]
if db_instance_name in self.databases:
return self.databases[db_instance_name].get_tags()
else:
resource_type = arn_breakdown[len(arn_breakdown)-2]
resource_name = arn_breakdown[len(arn_breakdown)-1]
if resource_type == 'db': # Database
if resource_name in self.databases:
return self.databases[resource_name].get_tags()
elif resource_type == 'es': # Event Subscription
return []
elif resource_type == 'og': # Option Group
if resource_name in self.option_groups:
return self.option_groups[resource_name].get_tags()
elif resource_type == 'pg': # Parameter Group
return []
elif resource_type == 'ri': # Reserved DB instance
return []
elif resource_type == 'secgrp': # DB security group
if resource_type in self.security_groups:
return self.security_groups[resource_name].get_tags()
elif resource_type == 'snapshot': # DB Snapshot
return []
elif resource_type == 'subgrp': # DB subnet group
if resource_type in self.subnet_groups:
return self.subnet_groups[resource_name].get_tags()
else:
raise RDSClientError('InvalidParameterValue',
'Invalid resource name: {}'.format(arn))
return []
def remove_tags_from_resource(self, arn, tag_keys):
if self.arn_regex.match(arn):
arn_breakdown = arn.split(':')
db_instance_name = arn_breakdown[len(arn_breakdown)-1]
if db_instance_name in self.databases:
self.databases[db_instance_name].remove_tags(tag_keys)
resource_type = arn_breakdown[len(arn_breakdown)-2]
resource_name = arn_breakdown[len(arn_breakdown)-1]
if resource_type == 'db': # Database
if resource_name in self.databases:
self.databases[resource_name].remove_tags(tag_keys)
elif resource_type == 'es': # Event Subscription
return None
elif resource_type == 'og': # Option Group
if resource_name in self.option_groups:
return self.option_groups[resource_name].remove_tags(tag_keys)
elif resource_type == 'pg': # Parameter Group
return None
elif resource_type == 'ri': # Reserved DB instance
return None
elif resource_type == 'secgrp': # DB security group
if resource_type in self.security_groups:
return self.security_groups[resource_name].remove_tags(tag_keys)
elif resource_type == 'snapshot': # DB Snapshot
return None
elif resource_type == 'subgrp': # DB subnet group
if resource_type in self.subnet_groups:
return self.subnet_groups[resource_name].remove_tags(tag_keys)
else:
raise RDSClientError('InvalidParameterValue',
'Invalid resource name: {}'.format(arn))
@ -624,15 +686,33 @@ class RDS2Backend(BaseBackend):
def add_tags_to_resource(self, arn, tags):
if self.arn_regex.match(arn):
arn_breakdown = arn.split(':')
db_instance_name = arn_breakdown[len(arn_breakdown)-1]
if db_instance_name in self.databases:
return self.databases[db_instance_name].add_tags(tags)
else:
resource_type = arn_breakdown[len(arn_breakdown)-2]
resource_name = arn_breakdown[len(arn_breakdown)-1]
if resource_type == 'db': # Database
if resource_name in self.databases:
return self.databases[resource_name].add_tags(tags)
elif resource_type == 'es': # Event Subscription
return []
elif resource_type == 'og': # Option Group
if resource_name in self.option_groups:
return self.option_groups[resource_name].add_tags(tags)
elif resource_type == 'pg': # Parameter Group
return []
elif resource_type == 'ri': # Reserved DB instance
return []
elif resource_type == 'secgrp': # DB security group
if resource_type in self.security_groups:
return self.security_groups[resource_name].add_tags(tags)
elif resource_type == 'snapshot': # DB Snapshot
return []
elif resource_type == 'subgrp': # DB subnet group
if resource_type in self.subnet_groups:
return self.subnet_groups[resource_name].add_tags(tags)
else:
raise RDSClientError('InvalidParameterValue',
'Invalid resource name: {}'.format(arn))
class OptionGroup(object):
def __init__(self, name, engine_name, major_engine_version, description=None):
self.engine_name = engine_name
@ -642,6 +722,7 @@ class OptionGroup(object):
self.vpc_and_non_vpc_instance_memberships = False
self.options = {}
self.vpcId = 'null'
self.tags = []
def to_json(self):
template = Template("""{
@ -663,6 +744,18 @@ class OptionGroup(object):
# TODO: Validate option and add it to self.options. If invalid raise error
return
def get_tags(self):
return self.tags
def add_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in new_keys]
self.tags.extend(tags)
return self.tags
def remove_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags if tag_set['Key'] not in tag_keys]
class OptionGroupOption(object):
def __init__(self, engine_name, major_engine_version):
@ -687,6 +780,7 @@ class OptionGroupOption(object):
}""")
return template.render(option_group=self)
rds2_backends = {}
for region in boto.rds2.regions():
rds2_backends[region.name] = RDS2Backend()

View File

@ -155,7 +155,6 @@ class RDS2Response(BaseResponse):
template = self.response_template(ADD_TAGS_TO_RESOURCE_TEMPLATE)
return template.render(tags=tags)
def remove_tags_from_resource(self):
arn = self._get_param('ResourceName')
tag_keys = self.unpack_list_params('TagKeys.member')

View File

@ -273,7 +273,7 @@ def test_list_tags_invalid_arn():
@disable_on_py3()
@mock_rds2
def test_list_tags():
def test_list_tags_db():
conn = boto.rds2.connect_to_region("us-west-2")
result = conn.list_tags_for_resource('arn:aws:rds:us-west-2:1234567890:db:foo')
result['ListTagsForResourceResponse']['ListTagsForResourceResult']['TagList'].should.equal([])
@ -294,7 +294,7 @@ def test_list_tags():
@disable_on_py3()
@mock_rds2
def test_add_tags():
def test_add_tags_db():
conn = boto.rds2.connect_to_region("us-west-2")
conn.create_db_instance(db_instance_identifier='db-without-tags',
allocated_storage=10,
@ -314,7 +314,7 @@ def test_add_tags():
@disable_on_py3()
@mock_rds2
def test_remove_tags():
def test_remove_tags_db():
conn = boto.rds2.connect_to_region("us-west-2")
conn.create_db_instance(db_instance_identifier='db-with-tags',
allocated_storage=10,
@ -424,8 +424,8 @@ def test_remove_tags():
# subnet_group.name.should.equal('db_subnet')
# subnet_group.description.should.equal("my db subnet")
# list(subnet_group.subnet_ids).should.equal(subnet_ids)
#
#
@mock_ec2
@mock_rds2
def test_describe_database_subnet_group():
@ -451,8 +451,8 @@ def test_describe_database_subnet_group():
list(conn.describe_db_subnet_groups("db_subnet1")).should.have.length_of(1)
conn.describe_db_subnet_groups.when.called_with("not-a-subnet").should.throw(BotoServerError)
#
#
#@mock_ec2
#@mock_rds2
#def test_delete_database_subnet_group():