RDS - support for EnableCloudWatchLogExports-parameter (#5107)

This commit is contained in:
Bert Blommers 2022-05-09 08:40:16 +00:00 committed by GitHub
parent e911341e6a
commit 515243eab0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 37 additions and 11 deletions

View File

@ -4294,7 +4294,7 @@
- [X] describe_db_clusters - [X] describe_db_clusters
- [ ] describe_db_engine_versions - [ ] describe_db_engine_versions
- [ ] describe_db_instance_automated_backups - [ ] describe_db_instance_automated_backups
- [ ] describe_db_instances - [X] describe_db_instances
- [ ] describe_db_log_files - [ ] describe_db_log_files
- [X] describe_db_parameter_groups - [X] describe_db_parameter_groups
- [ ] describe_db_parameters - [ ] describe_db_parameters

View File

@ -82,7 +82,7 @@ rds
- [X] describe_db_clusters - [X] describe_db_clusters
- [ ] describe_db_engine_versions - [ ] describe_db_engine_versions
- [ ] describe_db_instance_automated_backups - [ ] describe_db_instance_automated_backups
- [ ] describe_db_instances - [X] describe_db_instances
- [ ] describe_db_log_files - [ ] describe_db_log_files
- [X] describe_db_parameter_groups - [X] describe_db_parameter_groups
- [ ] describe_db_parameters - [ ] describe_db_parameters

View File

@ -7,14 +7,14 @@ class RDSClientError(BadRequest):
super().__init__() super().__init__()
template = Template( template = Template(
""" """
<RDSClientError> <ErrorResponse>
<Error> <Error>
<Code>{{ code }}</Code> <Code>{{ code }}</Code>
<Message>{{ message }}</Message> <Message>{{ message }}</Message>
<Type>Sender</Type> <Type>Sender</Type>
</Error> </Error>
<RequestId>6876f774-7273-11e4-85dc-39e55ca848d1</RequestId> <RequestId>6876f774-7273-11e4-85dc-39e55ca848d1</RequestId>
</RDSClientError>""" </ErrorResponse>"""
) )
self.description = template.render(code=code, message=message) self.description = template.render(code=code, message=message)

View File

@ -110,6 +110,9 @@ class Cluster:
random.choice(string.ascii_uppercase + string.digits) for _ in range(26) random.choice(string.ascii_uppercase + string.digits) for _ in range(26)
) )
self.tags = kwargs.get("tags", []) self.tags = kwargs.get("tags", [])
self.enabled_cloudwatch_logs_exports = (
kwargs.get("enable_cloudwatch_logs_exports") or []
)
@property @property
def db_cluster_arn(self): def db_cluster_arn(self):
@ -172,6 +175,11 @@ class Cluster:
<CopyTagsToSnapshot>{{ cluster.copy_tags_to_snapshot }}</CopyTagsToSnapshot> <CopyTagsToSnapshot>{{ cluster.copy_tags_to_snapshot }}</CopyTagsToSnapshot>
<CrossAccountClone>false</CrossAccountClone> <CrossAccountClone>false</CrossAccountClone>
<DomainMemberships></DomainMemberships> <DomainMemberships></DomainMemberships>
<EnabledCloudwatchLogsExports>
{% for export in cluster.enabled_cloudwatch_logs_exports %}
<member>{{ export }}</member>
{% endfor %}
</EnabledCloudwatchLogsExports>
<TagList> <TagList>
{%- for tag in cluster.tags -%} {%- for tag in cluster.tags -%}
<Tag> <Tag>
@ -426,6 +434,9 @@ class Database(CloudFormationModel):
self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U" self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U"
self.tags = kwargs.get("tags", []) self.tags = kwargs.get("tags", [])
self.deletion_protection = kwargs.get("deletion_protection", False) self.deletion_protection = kwargs.get("deletion_protection", False)
self.enabled_cloudwatch_logs_exports = (
kwargs.get("enable_cloudwatch_logs_exports") or []
)
@property @property
def db_instance_arn(self): def db_instance_arn(self):
@ -516,6 +527,11 @@ class Database(CloudFormationModel):
</DBInstanceStatusInfo> </DBInstanceStatusInfo>
{% endif %} {% endif %}
</StatusInfos> </StatusInfos>
<EnabledCloudwatchLogsExports>
{% for export in database.enabled_cloudwatch_logs_exports %}
<member>{{ export }}</member>
{% endfor %}
</EnabledCloudwatchLogsExports>
{% if database.is_replica %} {% if database.is_replica %}
<ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier> <ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier>
{% endif %} {% endif %}
@ -1331,7 +1347,7 @@ class RDSBackend(BaseBackend):
primary.add_replica(replica) primary.add_replica(replica)
return replica return replica
def describe_databases(self, db_instance_identifier=None, filters=None): def describe_db_instances(self, db_instance_identifier=None, filters=None):
databases = self.databases databases = self.databases
if db_instance_identifier: if db_instance_identifier:
filters = merge_filters( filters = merge_filters(
@ -1362,7 +1378,7 @@ class RDSBackend(BaseBackend):
return list(snapshots.values()) return list(snapshots.values())
def modify_db_instance(self, db_instance_identifier, db_kwargs): def modify_db_instance(self, db_instance_identifier, db_kwargs):
database = self.describe_databases(db_instance_identifier)[0] database = self.describe_db_instances(db_instance_identifier)[0]
if "new_db_instance_identifier" in db_kwargs: if "new_db_instance_identifier" in db_kwargs:
del self.databases[db_instance_identifier] del self.databases[db_instance_identifier]
db_instance_identifier = db_kwargs[ db_instance_identifier = db_kwargs[
@ -1373,7 +1389,7 @@ class RDSBackend(BaseBackend):
return database return database
def reboot_db_instance(self, db_instance_identifier): def reboot_db_instance(self, db_instance_identifier):
database = self.describe_databases(db_instance_identifier)[0] database = self.describe_db_instances(db_instance_identifier)[0]
return database return database
def restore_db_instance_from_db_snapshot(self, from_snapshot_id, overrides): def restore_db_instance_from_db_snapshot(self, from_snapshot_id, overrides):
@ -1394,7 +1410,7 @@ class RDSBackend(BaseBackend):
return self.create_db_instance(new_instance_props) return self.create_db_instance(new_instance_props)
def stop_db_instance(self, db_instance_identifier, db_snapshot_identifier=None): def stop_db_instance(self, db_instance_identifier, db_snapshot_identifier=None):
database = self.describe_databases(db_instance_identifier)[0] database = self.describe_db_instances(db_instance_identifier)[0]
# todo: certain rds types not allowed to be stopped at this time. # todo: certain rds types not allowed to be stopped at this time.
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_StopInstance.html#USER_StopInstance.Limitations # https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_StopInstance.html#USER_StopInstance.Limitations
if database.is_replica or ( if database.is_replica or (
@ -1410,7 +1426,7 @@ class RDSBackend(BaseBackend):
return database return database
def start_db_instance(self, db_instance_identifier): def start_db_instance(self, db_instance_identifier):
database = self.describe_databases(db_instance_identifier)[0] database = self.describe_db_instances(db_instance_identifier)[0]
# todo: bunch of different error messages to be generated from this api call # todo: bunch of different error messages to be generated from this api call
if database.status != "stopped": if database.status != "stopped":
raise InvalidDBInstanceStateError(db_instance_identifier, "start") raise InvalidDBInstanceStateError(db_instance_identifier, "start")
@ -1427,7 +1443,7 @@ class RDSBackend(BaseBackend):
backend = self backend = self
db_name = db_id db_name = db_id
return backend.describe_databases(db_name)[0] return backend.describe_db_instances(db_name)[0]
def delete_db_instance(self, db_instance_identifier, db_snapshot_name=None): def delete_db_instance(self, db_instance_identifier, db_snapshot_name=None):
if db_instance_identifier in self.databases: if db_instance_identifier in self.databases:

View File

@ -27,6 +27,9 @@ class RDSResponse(BaseResponse):
"db_subnet_group_name": self._get_param("DBSubnetGroupName"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"),
"engine": self._get_param("Engine"), "engine": self._get_param("Engine"),
"engine_version": self._get_param("EngineVersion"), "engine_version": self._get_param("EngineVersion"),
"enable_cloudwatch_logs_exports": self._get_params().get(
"EnableCloudwatchLogsExports"
),
"enable_iam_database_authentication": self._get_bool_param( "enable_iam_database_authentication": self._get_bool_param(
"EnableIAMDatabaseAuthentication" "EnableIAMDatabaseAuthentication"
), ),
@ -92,6 +95,9 @@ class RDSResponse(BaseResponse):
"availability_zones": self._get_multi_param( "availability_zones": self._get_multi_param(
"AvailabilityZones.AvailabilityZone" "AvailabilityZones.AvailabilityZone"
), ),
"enable_cloudwatch_logs_exports": self._get_params().get(
"EnableCloudwatchLogsExports"
),
"db_name": self._get_param("DatabaseName"), "db_name": self._get_param("DatabaseName"),
"db_cluster_identifier": self._get_param("DBClusterIdentifier"), "db_cluster_identifier": self._get_param("DBClusterIdentifier"),
"deletion_protection": self._get_bool_param("DeletionProtection"), "deletion_protection": self._get_bool_param("DeletionProtection"),
@ -174,7 +180,7 @@ class RDSResponse(BaseResponse):
filters = self._get_multi_param("Filters.Filter.") filters = self._get_multi_param("Filters.Filter.")
filters = {f["Name"]: f["Values"] for f in filters} filters = {f["Name"]: f["Values"] for f in filters}
all_instances = list( all_instances = list(
self.backend.describe_databases(db_instance_identifier, filters=filters) self.backend.describe_db_instances(db_instance_identifier, filters=filters)
) )
marker = self._get_param("Marker") marker = self._get_param("Marker")
all_ids = [instance.db_instance_identifier for instance in all_instances] all_ids = [instance.db_instance_identifier for instance in all_instances]

View File

@ -21,6 +21,7 @@ def test_create_database():
Port=1234, Port=1234,
DBSecurityGroups=["my_sg"], DBSecurityGroups=["my_sg"],
VpcSecurityGroupIds=["sg-123456"], VpcSecurityGroupIds=["sg-123456"],
EnableCloudwatchLogsExports=["audit", "error"],
) )
db_instance = database["DBInstance"] db_instance = database["DBInstance"]
db_instance["AllocatedStorage"].should.equal(10) db_instance["AllocatedStorage"].should.equal(10)
@ -40,6 +41,7 @@ def test_create_database():
db_instance["InstanceCreateTime"].should.be.a("datetime.datetime") db_instance["InstanceCreateTime"].should.be.a("datetime.datetime")
db_instance["VpcSecurityGroups"][0]["VpcSecurityGroupId"].should.equal("sg-123456") db_instance["VpcSecurityGroups"][0]["VpcSecurityGroupId"].should.equal("sg-123456")
db_instance["DeletionProtection"].should.equal(False) db_instance["DeletionProtection"].should.equal(False)
db_instance["EnabledCloudwatchLogsExports"].should.equal(["audit", "error"])
@mock_rds @mock_rds

View File

@ -157,6 +157,7 @@ def test_create_db_cluster_additional_parameters():
MasterUserPassword="hunter2_", MasterUserPassword="hunter2_",
Port=1234, Port=1234,
DeletionProtection=True, DeletionProtection=True,
EnableCloudwatchLogsExports=["audit"],
) )
cluster = resp["DBCluster"] cluster = resp["DBCluster"]
@ -167,6 +168,7 @@ def test_create_db_cluster_additional_parameters():
cluster.should.have.key("EngineMode").equal("serverless") cluster.should.have.key("EngineMode").equal("serverless")
cluster.should.have.key("Port").equal(1234) cluster.should.have.key("Port").equal(1234)
cluster.should.have.key("DeletionProtection").equal(True) cluster.should.have.key("DeletionProtection").equal(True)
cluster.should.have.key("EnabledCloudwatchLogsExports").equals(["audit"])
@mock_rds @mock_rds