Add RDS modify_db_cluster() (#5550)

This commit is contained in:
Marshall Mamiya 2022-10-11 07:51:17 -07:00 committed by GitHub
parent 56ca48cfdd
commit 97011ba19b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 175 additions and 21 deletions

View File

@ -4806,7 +4806,7 @@
- [ ] modify_certificates - [ ] modify_certificates
- [ ] modify_current_db_cluster_capacity - [ ] modify_current_db_cluster_capacity
- [ ] modify_custom_db_engine_version - [ ] modify_custom_db_engine_version
- [ ] modify_db_cluster - [X] modify_db_cluster
- [ ] modify_db_cluster_endpoint - [ ] modify_db_cluster_endpoint
- [ ] modify_db_cluster_parameter_group - [ ] modify_db_cluster_parameter_group
- [ ] modify_db_cluster_snapshot_attribute - [ ] modify_db_cluster_snapshot_attribute

View File

@ -73,14 +73,7 @@ class Cluster:
"The parameter MasterUsername must be provided and must not be blank." "The parameter MasterUsername must be provided and must not be blank."
) )
self.master_user_password = kwargs.get("master_user_password") self.master_user_password = kwargs.get("master_user_password")
if not self.master_user_password:
raise InvalidParameterValue(
"The parameter MasterUserPassword must be provided and must not be blank."
)
if len(self.master_user_password) < 8:
raise InvalidParameterValue(
"The parameter MasterUserPassword is not a valid password because it is shorter than 8 characters."
)
self.availability_zones = kwargs.get("availability_zones") self.availability_zones = kwargs.get("availability_zones")
if not self.availability_zones: if not self.availability_zones:
self.availability_zones = [ self.availability_zones = [
@ -113,12 +106,40 @@ class Cluster:
self.enabled_cloudwatch_logs_exports = ( self.enabled_cloudwatch_logs_exports = (
kwargs.get("enable_cloudwatch_logs_exports") or [] kwargs.get("enable_cloudwatch_logs_exports") or []
) )
self.enable_http_endpoint = False self.enable_http_endpoint = kwargs.get("enable_http_endpoint")
@property
def db_cluster_arn(self):
return f"arn:aws:rds:{self.region_name}:{self.account_id}:cluster:{self.db_cluster_identifier}"
@property
def master_user_password(self):
return self._master_user_password
@master_user_password.setter
def master_user_password(self, val):
if not val:
raise InvalidParameterValue(
"The parameter MasterUserPassword must be provided and must not be blank."
)
if len(val) < 8:
raise InvalidParameterValue(
"The parameter MasterUserPassword is not a valid password because it is shorter than 8 characters."
)
self._master_user_password = val
@property
def enable_http_endpoint(self):
return self._enable_http_endpoint
@enable_http_endpoint.setter
def enable_http_endpoint(self, val):
# instead of raising an error on aws rds create-db-cluster commands with # instead of raising an error on aws rds create-db-cluster commands with
# incompatible configurations with enable_http_endpoint # incompatible configurations with enable_http_endpoint
# (e.g. engine_mode is not set to "serverless"), the API # (e.g. engine_mode is not set to "serverless"), the API
# automatically sets the enable_http_endpoint parameter to False # automatically sets the enable_http_endpoint parameter to False
if kwargs.get("enable_http_endpoint"): self._enable_http_endpoint = False
if val is not None:
if self.engine_mode == "serverless": if self.engine_mode == "serverless":
if self.engine == "aurora-mysql" and self.engine_version in [ if self.engine == "aurora-mysql" and self.engine_version in [
"5.6.10a", "5.6.10a",
@ -126,22 +147,20 @@ class Cluster:
"2.07.1", "2.07.1",
"5.7.2", "5.7.2",
]: ]:
self.enable_http_endpoint = kwargs.get( self._enable_http_endpoint = val
"enable_http_endpoint", False
)
elif self.engine == "aurora-postgresql" and self.engine_version in [ elif self.engine == "aurora-postgresql" and self.engine_version in [
"10.12", "10.12",
"10.14", "10.14",
"10.18", "10.18",
"11.13", "11.13",
]: ]:
self.enable_http_endpoint = kwargs.get( self._enable_http_endpoint = val
"enable_http_endpoint", False
)
@property def get_cfg(self):
def db_cluster_arn(self): cfg = self.__dict__
return f"arn:aws:rds:{self.region_name}:{self.account_id}:cluster:{self.db_cluster_identifier}" cfg["master_user_password"] = cfg.pop("_master_user_password")
cfg["enable_http_endpoint"] = cfg.pop("_enable_http_endpoint")
return cfg
def to_xml(self): def to_xml(self):
template = Template( template = Template(
@ -1794,6 +1813,24 @@ class RDSBackend(BaseBackend):
cluster.status = "available" # Already set the final status in the background cluster.status = "available" # Already set the final status in the background
return initial_state return initial_state
def modify_db_cluster(self, kwargs):
cluster_id = kwargs["db_cluster_identifier"]
cluster = self.clusters[cluster_id]
del self.clusters[cluster_id]
kwargs["db_cluster_identifier"] = kwargs.pop("new_db_cluster_identifier")
for k, v in kwargs.items():
if v is not None:
setattr(cluster, k, v)
cluster_id = kwargs.get("new_db_cluster_identifier", cluster_id)
self.clusters[cluster_id] = cluster
initial_state = copy.deepcopy(cluster) # Return status=creating
cluster.status = "available" # Already set the final status in the background
return initial_state
def create_db_cluster_snapshot( def create_db_cluster_snapshot(
self, db_cluster_identifier, db_snapshot_identifier, tags=None self, db_cluster_identifier, db_snapshot_identifier, tags=None
): ):
@ -1894,7 +1931,7 @@ class RDSBackend(BaseBackend):
db_cluster_identifier=None, db_snapshot_identifier=from_snapshot_id db_cluster_identifier=None, db_snapshot_identifier=from_snapshot_id
)[0] )[0]
original_cluster = snapshot.cluster original_cluster = snapshot.cluster
new_cluster_props = copy.deepcopy(original_cluster.__dict__) new_cluster_props = copy.deepcopy(original_cluster.get_cfg())
for key, value in overrides.items(): for key, value in overrides.items():
if value: if value:
new_cluster_props[key] = value new_cluster_props[key] = value

View File

@ -63,6 +63,56 @@ class RDSResponse(BaseResponse):
args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value"))
return args return args
def _get_modify_db_cluster_kwargs(self):
args = {
"auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"),
"allocated_storage": self._get_int_param("AllocatedStorage"),
"availability_zone": self._get_param("AvailabilityZone"),
"backup_retention_period": self._get_param("BackupRetentionPeriod"),
"copy_tags_to_snapshot": self._get_param("CopyTagsToSnapshot"),
"db_instance_class": self._get_param("DBInstanceClass"),
"db_cluster_identifier": self._get_param("DBClusterIdentifier"),
"new_db_cluster_identifier": self._get_param("NewDBClusterIdentifier"),
"db_instance_identifier": self._get_param("DBInstanceIdentifier"),
"db_name": self._get_param("DBName"),
"db_parameter_group_name": self._get_param("DBParameterGroupName"),
"db_snapshot_identifier": self._get_param("DBSnapshotIdentifier"),
"db_subnet_group_name": self._get_param("DBSubnetGroupName"),
"engine": self._get_param("Engine"),
"engine_version": self._get_param("EngineVersion"),
"enable_cloudwatch_logs_exports": self._get_params().get(
"EnableCloudwatchLogsExports"
),
"enable_iam_database_authentication": self._get_bool_param(
"EnableIAMDatabaseAuthentication"
),
"license_model": self._get_param("LicenseModel"),
"iops": self._get_int_param("Iops"),
"kms_key_id": self._get_param("KmsKeyId"),
"master_user_password": self._get_param("MasterUserPassword"),
"master_username": self._get_param("MasterUsername"),
"multi_az": self._get_bool_param("MultiAZ"),
"option_group_name": self._get_param("OptionGroupName"),
"port": self._get_param("Port"),
# PreferredBackupWindow
# PreferredMaintenanceWindow
"publicly_accessible": self._get_param("PubliclyAccessible"),
"account_id": self.current_account,
"region": self.region,
"security_groups": self._get_multi_param(
"DBSecurityGroups.DBSecurityGroupName"
),
"storage_encrypted": self._get_param("StorageEncrypted"),
"storage_type": self._get_param("StorageType", None),
"vpc_security_group_ids": self._get_multi_param(
"VpcSecurityGroupIds.VpcSecurityGroupId"
),
"tags": list(),
"deletion_protection": self._get_bool_param("DeletionProtection"),
}
args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value"))
return args
def _get_db_replica_kwargs(self): def _get_db_replica_kwargs(self):
return { return {
"auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"),
@ -515,6 +565,12 @@ class RDSResponse(BaseResponse):
template = self.response_template(CREATE_DB_CLUSTER_TEMPLATE) template = self.response_template(CREATE_DB_CLUSTER_TEMPLATE)
return template.render(cluster=cluster) return template.render(cluster=cluster)
def modify_db_cluster(self):
kwargs = self._get_modify_db_cluster_kwargs()
cluster = self.backend.modify_db_cluster(kwargs)
template = self.response_template(MODIFY_DB_CLUSTER_TEMPLATE)
return template.render(cluster=cluster)
def describe_db_clusters(self): def describe_db_clusters(self):
_id = self._get_param("DBClusterIdentifier") _id = self._get_param("DBClusterIdentifier")
clusters = self.backend.describe_db_clusters(cluster_identifier=_id) clusters = self.backend.describe_db_clusters(cluster_identifier=_id)
@ -988,6 +1044,15 @@ CREATE_DB_CLUSTER_TEMPLATE = """<CreateDBClusterResponse xmlns="http://rds.amazo
</ResponseMetadata> </ResponseMetadata>
</CreateDBClusterResponse>""" </CreateDBClusterResponse>"""
MODIFY_DB_CLUSTER_TEMPLATE = """<ModifyDBClusterResponse xmlns="http://rds.amazonaws.com/doc/2014-10-31/">
<ModifyDBClusterResult>
{{ cluster.to_xml() }}
</ModifyDBClusterResult>
<ResponseMetadata>
<RequestId>69673d54-e48e-4ba4-9333-c5a6c1e7526a</RequestId>
</ResponseMetadata>
</ModifyDBClusterResponse>"""
DESCRIBE_CLUSTERS_TEMPLATE = """<DescribeDBClustersResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/"> DESCRIBE_CLUSTERS_TEMPLATE = """<DescribeDBClustersResponse xmlns="http://rds.amazonaws.com/doc/2014-09-01/">
<DescribeDBClustersResult> <DescribeDBClustersResult>
<DBClusters> <DBClusters>

View File

@ -74,6 +74,58 @@ def test_create_db_cluster_needs_long_master_user_password():
) )
@mock_rds
def test_modify_db_cluster_needs_long_master_user_password():
client = boto3.client("rds", region_name="eu-north-1")
client.create_db_cluster(
DBClusterIdentifier="cluster-id",
Engine="aurora",
MasterUsername="root",
MasterUserPassword="hunter21",
)
with pytest.raises(ClientError) as ex:
client.modify_db_cluster(
DBClusterIdentifier="cluster-id",
MasterUserPassword="hunter2",
)
err = ex.value.response["Error"]
err["Code"].should.equal("InvalidParameterValue")
err["Message"].should.equal(
"The parameter MasterUserPassword is not a valid password because it is shorter than 8 characters."
)
@mock_rds
def test_modify_db_cluster_new_cluster_identifier():
client = boto3.client("rds", region_name="eu-north-1")
old_id = "cluster-id"
new_id = "new-cluster-id"
client.create_db_cluster(
DBClusterIdentifier=old_id,
Engine="aurora",
MasterUsername="root",
MasterUserPassword="hunter21",
)
resp = client.modify_db_cluster(
DBClusterIdentifier=old_id,
NewDBClusterIdentifier=new_id,
MasterUserPassword="hunter21",
)
resp["DBCluster"].should.have.key("DBClusterIdentifier").equal(new_id)
clusters = [
cluster["DBClusterIdentifier"]
for cluster in client.describe_db_clusters()["DBClusters"]
]
assert old_id not in clusters
@mock_rds @mock_rds
def test_create_db_cluster__verify_default_properties(): def test_create_db_cluster__verify_default_properties():
client = boto3.client("rds", region_name="eu-north-1") client = boto3.client("rds", region_name="eu-north-1")