From 97011ba19b185ecb6f688ad629ec997f25cc0327 Mon Sep 17 00:00:00 2001 From: Marshall Mamiya <44485531+marshall7m@users.noreply.github.com> Date: Tue, 11 Oct 2022 07:51:17 -0700 Subject: [PATCH] Add RDS modify_db_cluster() (#5550) --- IMPLEMENTATION_COVERAGE.md | 2 +- moto/rds/models.py | 77 +++++++++++++++++++++-------- moto/rds/responses.py | 65 ++++++++++++++++++++++++ tests/test_rds/test_rds_clusters.py | 52 +++++++++++++++++++ 4 files changed, 175 insertions(+), 21 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index 9bb99f522..dad452d09 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -4806,7 +4806,7 @@ - [ ] modify_certificates - [ ] modify_current_db_cluster_capacity - [ ] modify_custom_db_engine_version -- [ ] modify_db_cluster +- [X] modify_db_cluster - [ ] modify_db_cluster_endpoint - [ ] modify_db_cluster_parameter_group - [ ] modify_db_cluster_snapshot_attribute diff --git a/moto/rds/models.py b/moto/rds/models.py index 871e0b926..86132e977 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -73,14 +73,7 @@ class Cluster: "The parameter MasterUsername must be provided and must not be blank." ) self.master_user_password = kwargs.get("master_user_password") - if not self.master_user_password: - raise InvalidParameterValue( - "The parameter MasterUserPassword must be provided and must not be blank." - ) - if len(self.master_user_password) < 8: - raise InvalidParameterValue( - "The parameter MasterUserPassword is not a valid password because it is shorter than 8 characters." - ) + self.availability_zones = kwargs.get("availability_zones") if not self.availability_zones: self.availability_zones = [ @@ -113,12 +106,40 @@ class Cluster: self.enabled_cloudwatch_logs_exports = ( kwargs.get("enable_cloudwatch_logs_exports") or [] ) - self.enable_http_endpoint = False + self.enable_http_endpoint = kwargs.get("enable_http_endpoint") + + @property + def db_cluster_arn(self): + return f"arn:aws:rds:{self.region_name}:{self.account_id}:cluster:{self.db_cluster_identifier}" + + @property + def master_user_password(self): + return self._master_user_password + + @master_user_password.setter + def master_user_password(self, val): + if not val: + raise InvalidParameterValue( + "The parameter MasterUserPassword must be provided and must not be blank." + ) + if len(val) < 8: + raise InvalidParameterValue( + "The parameter MasterUserPassword is not a valid password because it is shorter than 8 characters." + ) + self._master_user_password = val + + @property + def enable_http_endpoint(self): + return self._enable_http_endpoint + + @enable_http_endpoint.setter + def enable_http_endpoint(self, val): # instead of raising an error on aws rds create-db-cluster commands with # incompatible configurations with enable_http_endpoint # (e.g. engine_mode is not set to "serverless"), the API # automatically sets the enable_http_endpoint parameter to False - if kwargs.get("enable_http_endpoint"): + self._enable_http_endpoint = False + if val is not None: if self.engine_mode == "serverless": if self.engine == "aurora-mysql" and self.engine_version in [ "5.6.10a", @@ -126,22 +147,20 @@ class Cluster: "2.07.1", "5.7.2", ]: - self.enable_http_endpoint = kwargs.get( - "enable_http_endpoint", False - ) + self._enable_http_endpoint = val elif self.engine == "aurora-postgresql" and self.engine_version in [ "10.12", "10.14", "10.18", "11.13", ]: - self.enable_http_endpoint = kwargs.get( - "enable_http_endpoint", False - ) + self._enable_http_endpoint = val - @property - def db_cluster_arn(self): - return f"arn:aws:rds:{self.region_name}:{self.account_id}:cluster:{self.db_cluster_identifier}" + def get_cfg(self): + cfg = self.__dict__ + cfg["master_user_password"] = cfg.pop("_master_user_password") + cfg["enable_http_endpoint"] = cfg.pop("_enable_http_endpoint") + return cfg def to_xml(self): template = Template( @@ -1794,6 +1813,24 @@ class RDSBackend(BaseBackend): cluster.status = "available" # Already set the final status in the background return initial_state + def modify_db_cluster(self, kwargs): + cluster_id = kwargs["db_cluster_identifier"] + + cluster = self.clusters[cluster_id] + del self.clusters[cluster_id] + + kwargs["db_cluster_identifier"] = kwargs.pop("new_db_cluster_identifier") + for k, v in kwargs.items(): + if v is not None: + setattr(cluster, k, v) + + cluster_id = kwargs.get("new_db_cluster_identifier", cluster_id) + self.clusters[cluster_id] = cluster + + initial_state = copy.deepcopy(cluster) # Return status=creating + cluster.status = "available" # Already set the final status in the background + return initial_state + def create_db_cluster_snapshot( self, db_cluster_identifier, db_snapshot_identifier, tags=None ): @@ -1894,7 +1931,7 @@ class RDSBackend(BaseBackend): db_cluster_identifier=None, db_snapshot_identifier=from_snapshot_id )[0] original_cluster = snapshot.cluster - new_cluster_props = copy.deepcopy(original_cluster.__dict__) + new_cluster_props = copy.deepcopy(original_cluster.get_cfg()) for key, value in overrides.items(): if value: new_cluster_props[key] = value diff --git a/moto/rds/responses.py b/moto/rds/responses.py index 2c9423716..70487c369 100644 --- a/moto/rds/responses.py +++ b/moto/rds/responses.py @@ -63,6 +63,56 @@ class RDSResponse(BaseResponse): args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) return args + def _get_modify_db_cluster_kwargs(self): + args = { + "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), + "allocated_storage": self._get_int_param("AllocatedStorage"), + "availability_zone": self._get_param("AvailabilityZone"), + "backup_retention_period": self._get_param("BackupRetentionPeriod"), + "copy_tags_to_snapshot": self._get_param("CopyTagsToSnapshot"), + "db_instance_class": self._get_param("DBInstanceClass"), + "db_cluster_identifier": self._get_param("DBClusterIdentifier"), + "new_db_cluster_identifier": self._get_param("NewDBClusterIdentifier"), + "db_instance_identifier": self._get_param("DBInstanceIdentifier"), + "db_name": self._get_param("DBName"), + "db_parameter_group_name": self._get_param("DBParameterGroupName"), + "db_snapshot_identifier": self._get_param("DBSnapshotIdentifier"), + "db_subnet_group_name": self._get_param("DBSubnetGroupName"), + "engine": self._get_param("Engine"), + "engine_version": self._get_param("EngineVersion"), + "enable_cloudwatch_logs_exports": self._get_params().get( + "EnableCloudwatchLogsExports" + ), + "enable_iam_database_authentication": self._get_bool_param( + "EnableIAMDatabaseAuthentication" + ), + "license_model": self._get_param("LicenseModel"), + "iops": self._get_int_param("Iops"), + "kms_key_id": self._get_param("KmsKeyId"), + "master_user_password": self._get_param("MasterUserPassword"), + "master_username": self._get_param("MasterUsername"), + "multi_az": self._get_bool_param("MultiAZ"), + "option_group_name": self._get_param("OptionGroupName"), + "port": self._get_param("Port"), + # PreferredBackupWindow + # PreferredMaintenanceWindow + "publicly_accessible": self._get_param("PubliclyAccessible"), + "account_id": self.current_account, + "region": self.region, + "security_groups": self._get_multi_param( + "DBSecurityGroups.DBSecurityGroupName" + ), + "storage_encrypted": self._get_param("StorageEncrypted"), + "storage_type": self._get_param("StorageType", None), + "vpc_security_group_ids": self._get_multi_param( + "VpcSecurityGroupIds.VpcSecurityGroupId" + ), + "tags": list(), + "deletion_protection": self._get_bool_param("DeletionProtection"), + } + args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) + return args + def _get_db_replica_kwargs(self): return { "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), @@ -515,6 +565,12 @@ class RDSResponse(BaseResponse): template = self.response_template(CREATE_DB_CLUSTER_TEMPLATE) return template.render(cluster=cluster) + def modify_db_cluster(self): + kwargs = self._get_modify_db_cluster_kwargs() + cluster = self.backend.modify_db_cluster(kwargs) + template = self.response_template(MODIFY_DB_CLUSTER_TEMPLATE) + return template.render(cluster=cluster) + def describe_db_clusters(self): _id = self._get_param("DBClusterIdentifier") clusters = self.backend.describe_db_clusters(cluster_identifier=_id) @@ -988,6 +1044,15 @@ CREATE_DB_CLUSTER_TEMPLATE = """ + + {{ cluster.to_xml() }} + + + 69673d54-e48e-4ba4-9333-c5a6c1e7526a + +""" + DESCRIBE_CLUSTERS_TEMPLATE = """ diff --git a/tests/test_rds/test_rds_clusters.py b/tests/test_rds/test_rds_clusters.py index 2078688f4..200d0a152 100644 --- a/tests/test_rds/test_rds_clusters.py +++ b/tests/test_rds/test_rds_clusters.py @@ -74,6 +74,58 @@ def test_create_db_cluster_needs_long_master_user_password(): ) +@mock_rds +def test_modify_db_cluster_needs_long_master_user_password(): + client = boto3.client("rds", region_name="eu-north-1") + + client.create_db_cluster( + DBClusterIdentifier="cluster-id", + Engine="aurora", + MasterUsername="root", + MasterUserPassword="hunter21", + ) + + with pytest.raises(ClientError) as ex: + client.modify_db_cluster( + DBClusterIdentifier="cluster-id", + MasterUserPassword="hunter2", + ) + err = ex.value.response["Error"] + err["Code"].should.equal("InvalidParameterValue") + err["Message"].should.equal( + "The parameter MasterUserPassword is not a valid password because it is shorter than 8 characters." + ) + + +@mock_rds +def test_modify_db_cluster_new_cluster_identifier(): + client = boto3.client("rds", region_name="eu-north-1") + old_id = "cluster-id" + new_id = "new-cluster-id" + + client.create_db_cluster( + DBClusterIdentifier=old_id, + Engine="aurora", + MasterUsername="root", + MasterUserPassword="hunter21", + ) + + resp = client.modify_db_cluster( + DBClusterIdentifier=old_id, + NewDBClusterIdentifier=new_id, + MasterUserPassword="hunter21", + ) + + resp["DBCluster"].should.have.key("DBClusterIdentifier").equal(new_id) + + clusters = [ + cluster["DBClusterIdentifier"] + for cluster in client.describe_db_clusters()["DBClusters"] + ] + + assert old_id not in clusters + + @mock_rds def test_create_db_cluster__verify_default_properties(): client = boto3.client("rds", region_name="eu-north-1")