RDS - improve tagging support (#4994)

This commit is contained in:
Bert Blommers 2022-03-31 13:12:49 +00:00 committed by GitHub
parent e533b1a3ff
commit 15b49396ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 165 additions and 132 deletions

View File

@ -4196,7 +4196,7 @@
## rds
<details>
<summary>21% implemented</summary>
<summary>26% implemented</summary>
- [ ] add_role_to_db_cluster
- [ ] add_role_to_db_instance
@ -4217,13 +4217,13 @@
- [ ] create_db_cluster_endpoint
- [ ] create_db_cluster_parameter_group
- [X] create_db_cluster_snapshot
- [ ] create_db_instance
- [X] create_db_instance
- [ ] create_db_instance_read_replica
- [X] create_db_parameter_group
- [ ] create_db_proxy
- [ ] create_db_proxy_endpoint
- [ ] create_db_security_group
- [ ] create_db_snapshot
- [X] create_db_security_group
- [X] create_db_snapshot
- [ ] create_db_subnet_group
- [X] create_event_subscription
- [ ] create_global_cluster
@ -4234,13 +4234,13 @@
- [ ] delete_db_cluster_endpoint
- [ ] delete_db_cluster_parameter_group
- [X] delete_db_cluster_snapshot
- [ ] delete_db_instance
- [X] delete_db_instance
- [ ] delete_db_instance_automated_backup
- [X] delete_db_parameter_group
- [ ] delete_db_proxy
- [ ] delete_db_proxy_endpoint
- [ ] delete_db_security_group
- [ ] delete_db_snapshot
- [X] delete_db_snapshot
- [ ] delete_db_subnet_group
- [X] delete_event_subscription
- [ ] delete_global_cluster
@ -4299,7 +4299,7 @@
- [ ] modify_db_cluster_endpoint
- [ ] modify_db_cluster_parameter_group
- [ ] modify_db_cluster_snapshot_attribute
- [ ] modify_db_instance
- [X] modify_db_instance
- [X] modify_db_parameter_group
- [ ] modify_db_proxy
- [ ] modify_db_proxy_endpoint
@ -4332,12 +4332,12 @@
- [ ] revoke_db_security_group_ingress
- [ ] start_activity_stream
- [X] start_db_cluster
- [ ] start_db_instance
- [X] start_db_instance
- [ ] start_db_instance_automated_backups_replication
- [X] start_export_task
- [ ] stop_activity_stream
- [X] stop_db_cluster
- [ ] stop_db_instance
- [X] stop_db_instance
- [ ] stop_db_instance_automated_backups_replication
</details>

View File

@ -44,13 +44,13 @@ rds
- [ ] create_db_cluster_endpoint
- [ ] create_db_cluster_parameter_group
- [X] create_db_cluster_snapshot
- [ ] create_db_instance
- [X] create_db_instance
- [ ] create_db_instance_read_replica
- [X] create_db_parameter_group
- [ ] create_db_proxy
- [ ] create_db_proxy_endpoint
- [ ] create_db_security_group
- [ ] create_db_snapshot
- [X] create_db_security_group
- [X] create_db_snapshot
- [ ] create_db_subnet_group
- [X] create_event_subscription
- [ ] create_global_cluster
@ -61,13 +61,13 @@ rds
- [ ] delete_db_cluster_endpoint
- [ ] delete_db_cluster_parameter_group
- [X] delete_db_cluster_snapshot
- [ ] delete_db_instance
- [X] delete_db_instance
- [ ] delete_db_instance_automated_backup
- [X] delete_db_parameter_group
- [ ] delete_db_proxy
- [ ] delete_db_proxy_endpoint
- [ ] delete_db_security_group
- [ ] delete_db_snapshot
- [X] delete_db_snapshot
- [ ] delete_db_subnet_group
- [X] delete_event_subscription
- [ ] delete_global_cluster
@ -126,7 +126,7 @@ rds
- [ ] modify_db_cluster_endpoint
- [ ] modify_db_cluster_parameter_group
- [ ] modify_db_cluster_snapshot_attribute
- [ ] modify_db_instance
- [X] modify_db_instance
- [X] modify_db_parameter_group
- [ ] modify_db_proxy
- [ ] modify_db_proxy_endpoint
@ -159,11 +159,11 @@ rds
- [ ] revoke_db_security_group_ingress
- [ ] start_activity_stream
- [X] start_db_cluster
- [ ] start_db_instance
- [X] start_db_instance
- [ ] start_db_instance_automated_backups_replication
- [X] start_export_task
- [ ] stop_activity_stream
- [X] stop_db_cluster
- [ ] stop_db_instance
- [X] stop_db_instance
- [ ] stop_db_instance_automated_backups_replication

View File

@ -728,7 +728,7 @@ class Database(CloudFormationModel):
db_kwargs["source_db_identifier"] = source_db_identifier
database = rds_backend.create_database_replica(db_kwargs)
else:
database = rds_backend.create_database(db_kwargs)
database = rds_backend.create_db_instance(db_kwargs)
return database
def to_json(self):
@ -820,7 +820,7 @@ class Database(CloudFormationModel):
def delete(self, region_name):
backend = rds_backends[region_name]
backend.delete_database(self.db_instance_identifier)
backend.delete_db_instance(self.db_instance_identifier)
class DatabaseSnapshot(BaseModel):
@ -1094,7 +1094,7 @@ class SecurityGroup(CloudFormationModel):
ec2_backend = ec2_backends[region_name]
rds_backend = rds_backends[region_name]
security_group = rds_backend.create_security_group(
security_group = rds_backend.create_db_security_group(
group_name, description, tags
)
for security_group_ingress in security_group_ingress_rules:
@ -1257,13 +1257,13 @@ class RDSBackend(BaseBackend):
service_region, zones, "rds-data"
)
def create_database(self, db_kwargs):
def create_db_instance(self, db_kwargs):
database_id = db_kwargs["db_instance_identifier"]
database = Database(**db_kwargs)
self.databases[database_id] = database
return database
def create_database_snapshot(
def create_db_snapshot(
self, db_instance_identifier, db_snapshot_identifier, tags=None
):
database = self.databases.get(db_instance_identifier)
@ -1307,7 +1307,7 @@ class RDSBackend(BaseBackend):
return target_snapshot
def delete_database_snapshot(self, db_snapshot_identifier):
def delete_db_snapshot(self, db_snapshot_identifier):
if db_snapshot_identifier not in self.database_snapshots:
raise DBSnapshotNotFoundError(db_snapshot_identifier)
@ -1358,7 +1358,7 @@ class RDSBackend(BaseBackend):
raise DBSnapshotNotFoundError(db_snapshot_identifier)
return list(snapshots.values())
def modify_database(self, db_instance_identifier, db_kwargs):
def modify_db_instance(self, db_instance_identifier, db_kwargs):
database = self.describe_databases(db_instance_identifier)[0]
if "new_db_instance_identifier" in db_kwargs:
del self.databases[db_instance_identifier]
@ -1388,9 +1388,9 @@ class RDSBackend(BaseBackend):
if value:
new_instance_props[key] = value
return self.create_database(new_instance_props)
return self.create_db_instance(new_instance_props)
def stop_database(self, db_instance_identifier, db_snapshot_identifier=None):
def stop_db_instance(self, db_instance_identifier, db_snapshot_identifier=None):
database = self.describe_databases(db_instance_identifier)[0]
# todo: certain rds types not allowed to be stopped at this time.
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_StopInstance.html#USER_StopInstance.Limitations
@ -1402,13 +1402,11 @@ class RDSBackend(BaseBackend):
if database.status != "available":
raise InvalidDBInstanceStateError(db_instance_identifier, "stop")
if db_snapshot_identifier:
self.create_database_snapshot(
db_instance_identifier, db_snapshot_identifier
)
self.create_db_snapshot(db_instance_identifier, db_snapshot_identifier)
database.status = "stopped"
return database
def start_database(self, db_instance_identifier):
def start_db_instance(self, db_instance_identifier):
database = self.describe_databases(db_instance_identifier)[0]
# todo: bunch of different error messages to be generated from this api call
if database.status != "stopped":
@ -1428,14 +1426,14 @@ class RDSBackend(BaseBackend):
return backend.describe_databases(db_name)[0]
def delete_database(self, db_instance_identifier, db_snapshot_name=None):
def delete_db_instance(self, db_instance_identifier, db_snapshot_name=None):
if db_instance_identifier in self.databases:
if self.databases[db_instance_identifier].deletion_protection:
raise InvalidParameterValue(
"Can't delete Instance with protection enabled"
)
if db_snapshot_name:
self.create_database_snapshot(db_instance_identifier, db_snapshot_name)
self.create_db_snapshot(db_instance_identifier, db_snapshot_name)
database = self.databases.pop(db_instance_identifier)
if database.is_replica:
primary = self.find_db_from_id(database.source_db_identifier)
@ -1445,7 +1443,7 @@ class RDSBackend(BaseBackend):
else:
raise DBInstanceNotFoundError(db_instance_identifier)
def create_security_group(self, group_name, description, tags):
def create_db_security_group(self, group_name, description, tags):
security_group = SecurityGroup(group_name, description, tags)
self.security_groups[group_name] = security_group
return security_group
@ -1985,6 +1983,9 @@ class RDSBackend(BaseBackend):
elif resource_type == "snapshot": # DB Snapshot
if resource_name in self.database_snapshots:
return self.database_snapshots[resource_name].remove_tags(tag_keys)
elif resource_type == "cluster":
if resource_name in self.clusters:
return self.clusters[resource_name].remove_tags(tag_keys)
elif resource_type == "cluster-snapshot": # DB Cluster Snapshot
if resource_name in self.cluster_snapshots:
return self.cluster_snapshots[resource_name].remove_tags(tag_keys)
@ -1999,8 +2000,8 @@ class RDSBackend(BaseBackend):
def add_tags_to_resource(self, arn, tags):
if self.arn_regex.match(arn):
arn_breakdown = arn.split(":")
resource_type = arn_breakdown[len(arn_breakdown) - 2]
resource_name = arn_breakdown[len(arn_breakdown) - 1]
resource_type = arn_breakdown[-2]
resource_name = arn_breakdown[-1]
if resource_type == "db": # Database
if resource_name in self.databases:
return self.databases[resource_name].add_tags(tags)
@ -2021,6 +2022,9 @@ class RDSBackend(BaseBackend):
elif resource_type == "snapshot": # DB Snapshot
if resource_name in self.database_snapshots:
return self.database_snapshots[resource_name].add_tags(tags)
elif resource_type == "cluster":
if resource_name in self.clusters:
return self.clusters[resource_name].add_tags(tags)
elif resource_type == "cluster-snapshot": # DB Cluster Snapshot
if resource_name in self.cluster_snapshots:
return self.cluster_snapshots[resource_name].add_tags(tags)
@ -2113,95 +2117,6 @@ class OptionGroup(object):
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
class OptionGroupOption(object):
def __init__(self, **kwargs):
self.default_port = kwargs.get("default_port")
self.description = kwargs.get("description")
self.engine_name = kwargs.get("engine_name")
self.major_engine_version = kwargs.get("major_engine_version")
self.name = kwargs.get("name")
self.option_group_option_settings = self._make_option_group_option_settings(
kwargs.get("option_group_option_settings", [])
)
self.options_depended_on = kwargs.get("options_depended_on", [])
self.permanent = kwargs.get("permanent")
self.persistent = kwargs.get("persistent")
self.port_required = kwargs.get("port_required")
def _make_option_group_option_settings(self, option_group_option_settings_kwargs):
return [
OptionGroupOptionSetting(**setting_kwargs)
for setting_kwargs in option_group_option_settings_kwargs
]
def to_json(self):
template = Template(
"""{ "MinimumRequiredMinorEngineVersion":
"2789.0.v1",
"OptionsDependedOn": [],
"MajorEngineVersion": "10.50",
"Persistent": false,
"DefaultPort": null,
"Permanent": false,
"OptionGroupOptionSettings": [],
"EngineName": "sqlserver-se",
"Name": "Mirroring",
"PortRequired": false,
"Description": "SQLServer Database Mirroring"
}"""
)
return template.render(option_group=self)
def to_xml(self):
template = Template(
"""<OptionGroupOption>
<MajorEngineVersion>{{ option_group.major_engine_version }}</MajorEngineVersion>
<DefaultPort>{{ option_group.default_port }}</DefaultPort>
<PortRequired>{{ option_group.port_required }}</PortRequired>
<Persistent>{{ option_group.persistent }}</Persistent>
<OptionsDependedOn>
{%- for option_name in option_group.options_depended_on -%}
<OptionName>{{ option_name }}</OptionName>
{%- endfor -%}
</OptionsDependedOn>
<Permanent>{{ option_group.permanent }}</Permanent>
<Description>{{ option_group.description }}</Description>
<Name>{{ option_group.name }}</Name>
<OptionGroupOptionSettings>
{%- for setting in option_group.option_group_option_settings -%}
{{ setting.to_xml() }}
{%- endfor -%}
</OptionGroupOptionSettings>
<EngineName>{{ option_group.engine_name }}</EngineName>
<MinimumRequiredMinorEngineVersion>{{ option_group.minimum_required_minor_engine_version }}</MinimumRequiredMinorEngineVersion>
</OptionGroupOption>"""
)
return template.render(option_group=self)
class OptionGroupOptionSetting(object):
def __init__(self, *kwargs):
self.allowed_values = kwargs.get("allowed_values")
self.apply_type = kwargs.get("apply_type")
self.default_value = kwargs.get("default_value")
self.is_modifiable = kwargs.get("is_modifiable")
self.setting_description = kwargs.get("setting_description")
self.setting_name = kwargs.get("setting_name")
def to_xml(self):
template = Template(
"""<OptionGroupOptionSetting>
<AllowedValues>{{ option_group_option_setting.allowed_values }}</AllowedValues>
<ApplyType>{{ option_group_option_setting.apply_type }}</ApplyType>
<DefaultValue>{{ option_group_option_setting.default_value }}</DefaultValue>
<IsModifiable>{{ option_group_option_setting.is_modifiable }}</IsModifiable>
<SettingDescription>{{ option_group_option_setting.setting_description }}</SettingDescription>
<SettingName>{{ option_group_option_setting.setting_name }}</SettingName>
</OptionGroupOptionSetting>"""
)
return template.render(option_group_option_setting=self)
def make_rds_arn(region, name):
return "arn:aws:rds:{0}:{1}:pg:{2}".format(region, ACCOUNT_ID, name)

View File

@ -157,7 +157,7 @@ class RDSResponse(BaseResponse):
def create_db_instance(self):
db_kwargs = self._get_db_kwargs()
database = self.backend.create_database(db_kwargs)
database = self.backend.create_db_instance(db_kwargs)
template = self.response_template(CREATE_DATABASE_TEMPLATE)
return template.render(database=database)
@ -198,14 +198,14 @@ class RDSResponse(BaseResponse):
new_db_instance_identifier = self._get_param("NewDBInstanceIdentifier")
if new_db_instance_identifier:
db_kwargs["new_db_instance_identifier"] = new_db_instance_identifier
database = self.backend.modify_database(db_instance_identifier, db_kwargs)
database = self.backend.modify_db_instance(db_instance_identifier, db_kwargs)
template = self.response_template(MODIFY_DATABASE_TEMPLATE)
return template.render(database=database)
def delete_db_instance(self):
db_instance_identifier = self._get_param("DBInstanceIdentifier")
db_snapshot_name = self._get_param("FinalDBSnapshotIdentifier")
database = self.backend.delete_database(
database = self.backend.delete_db_instance(
db_instance_identifier, db_snapshot_name
)
template = self.response_template(DELETE_DATABASE_TEMPLATE)
@ -221,7 +221,7 @@ class RDSResponse(BaseResponse):
db_instance_identifier = self._get_param("DBInstanceIdentifier")
db_snapshot_identifier = self._get_param("DBSnapshotIdentifier")
tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value"))
snapshot = self.backend.create_database_snapshot(
snapshot = self.backend.create_db_snapshot(
db_instance_identifier, db_snapshot_identifier, tags
)
template = self.response_template(CREATE_SNAPSHOT_TEMPLATE)
@ -250,7 +250,7 @@ class RDSResponse(BaseResponse):
def delete_db_snapshot(self):
db_snapshot_identifier = self._get_param("DBSnapshotIdentifier")
snapshot = self.backend.delete_database_snapshot(db_snapshot_identifier)
snapshot = self.backend.delete_db_snapshot(db_snapshot_identifier)
template = self.response_template(DELETE_SNAPSHOT_TEMPLATE)
return template.render(snapshot=snapshot)
@ -286,7 +286,7 @@ class RDSResponse(BaseResponse):
def stop_db_instance(self):
db_instance_identifier = self._get_param("DBInstanceIdentifier")
db_snapshot_identifier = self._get_param("DBSnapshotIdentifier")
database = self.backend.stop_database(
database = self.backend.stop_db_instance(
db_instance_identifier, db_snapshot_identifier
)
template = self.response_template(STOP_DATABASE_TEMPLATE)
@ -294,7 +294,7 @@ class RDSResponse(BaseResponse):
def start_db_instance(self):
db_instance_identifier = self._get_param("DBInstanceIdentifier")
database = self.backend.start_database(db_instance_identifier)
database = self.backend.start_db_instance(db_instance_identifier)
template = self.response_template(START_DATABASE_TEMPLATE)
return template.render(database=database)
@ -302,7 +302,7 @@ class RDSResponse(BaseResponse):
group_name = self._get_param("DBSecurityGroupName")
description = self._get_param("DBSecurityGroupDescription")
tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value"))
security_group = self.backend.create_security_group(
security_group = self.backend.create_db_security_group(
group_name, description, tags
)
template = self.response_template(CREATE_SECURITY_GROUP_TEMPLATE)

View File

@ -68,6 +68,10 @@ def test_create_dbinstance_via_cf():
},
}
},
"Outputs": {
"db_address": {"Value": {"Fn::GetAtt": ["db", "Endpoint.Address"]}},
"db_port": {"Value": {"Fn::GetAtt": ["db", "Endpoint.Port"]}},
},
}
template_json = json.dumps(template)
cf.create_stack(StackName="test_stack", TemplateBody=template_json)
@ -85,6 +89,13 @@ def test_create_dbinstance_via_cf():
created["Engine"].should.equal("mysql")
created["DBInstanceStatus"].should.equal("available")
# Verify the stack outputs are correct
o = _get_stack_outputs(cf, stack_name="test_stack")
o.should.have.key("db_address").equals(
f"{db_instance_identifier}.aaaaaaaaaa.us-west-2.rds.amazonaws.com"
)
o.should.have.key("db_port").equals("3307")
@mock_ec2
@mock_rds
@ -248,3 +259,49 @@ def test_rds_mysql_with_read_replica_in_vpc():
"DBSubnetGroups"
][0]
subnet_group.should.have.key("DBSubnetGroupDescription").equal("my db subnet group")
@mock_ec2
@mock_rds
@mock_cloudformation
def test_delete_dbinstance_via_cf():
vpc_conn = boto3.client("ec2", "us-west-2")
vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]
vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")
rds = boto3.client("rds", region_name="us-west-2")
cf = boto3.client("cloudformation", region_name="us-west-2")
template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Resources": {
"db": {
"Type": "AWS::RDS::DBInstance",
"Properties": {
"Port": 3307,
"Engine": "mysql",
# Required - throws exception when describing an instance without tags
"Tags": [],
},
}
},
}
template_json = json.dumps(template)
cf.create_stack(StackName="test_stack", TemplateBody=template_json)
resp = rds.describe_db_instances()["DBInstances"]
resp.should.have.length_of(1)
cf.delete_stack(StackName="test_stack")
resp = rds.describe_db_instances()["DBInstances"]
resp.should.have.length_of(0)
def _get_stack_outputs(cf_client, stack_name):
"""Returns the outputs for the first entry in describe_stacks."""
stack_description = cf_client.describe_stacks(StackName=stack_name)["Stacks"][0]
return {
output["OutputKey"]: output["OutputValue"]
for output in stack_description["Outputs"]
}

View File

@ -620,3 +620,64 @@ def test_restore_db_cluster_from_snapshot_and_override_params():
new_cluster["DBClusterParameterGroup"].should.equal("default.aurora8.0")
new_cluster["DBClusterInstanceClass"].should.equal("db.r6g.xlarge")
new_cluster["Port"].should.equal(10000)
@mock_rds
def test_add_tags_to_cluster():
conn = boto3.client("rds", region_name="us-west-2")
resp = conn.create_db_cluster(
DBClusterIdentifier="db-primary-1",
AllocatedStorage=10,
Engine="postgres",
DatabaseName="staging-postgres",
DBClusterInstanceClass="db.m1.small",
MasterUsername="root",
MasterUserPassword="hunter2000",
Port=1234,
Tags=[{"Key": "k1", "Value": "v1"}],
)
cluster_arn = resp["DBCluster"]["DBClusterArn"]
conn.add_tags_to_resource(
ResourceName=cluster_arn, Tags=[{"Key": "k2", "Value": "v2"}]
)
tags = conn.list_tags_for_resource(ResourceName=cluster_arn)["TagList"]
tags.should.equal([{"Key": "k1", "Value": "v1"}, {"Key": "k2", "Value": "v2"}])
conn.remove_tags_from_resource(ResourceName=cluster_arn, TagKeys=["k1"])
tags = conn.list_tags_for_resource(ResourceName=cluster_arn)["TagList"]
tags.should.equal([{"Key": "k2", "Value": "v2"}])
@mock_rds
def test_add_tags_to_cluster_snapshot():
conn = boto3.client("rds", region_name="us-west-2")
conn.create_db_cluster(
DBClusterIdentifier="db-primary-1",
AllocatedStorage=10,
Engine="postgres",
DatabaseName="staging-postgres",
DBClusterInstanceClass="db.m1.small",
MasterUsername="root",
MasterUserPassword="hunter2000",
Port=1234,
Tags=[{"Key": "k1", "Value": "v1"}],
)
resp = conn.create_db_cluster_snapshot(
DBClusterIdentifier="db-primary-1", DBClusterSnapshotIdentifier="snapshot-1"
)
snapshot_arn = resp["DBClusterSnapshot"]["DBClusterSnapshotArn"]
conn.add_tags_to_resource(
ResourceName=snapshot_arn, Tags=[{"Key": "k2", "Value": "v2"}]
)
tags = conn.list_tags_for_resource(ResourceName=snapshot_arn)["TagList"]
tags.should.equal([{"Key": "k1", "Value": "v1"}, {"Key": "k2", "Value": "v2"}])
conn.remove_tags_from_resource(ResourceName=snapshot_arn, TagKeys=["k1"])
tags = conn.list_tags_for_resource(ResourceName=snapshot_arn)["TagList"]
tags.should.equal([{"Key": "k2", "Value": "v2"}])