From 5c0827547dccefd75f3bcff02ca64897819f845c Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 25 Jun 2023 17:36:42 +0000 Subject: [PATCH] RDS: Automated snapshots now have the appropriate SnapshotType (#6444) --- IMPLEMENTATION_COVERAGE.md | 6 +- docs/docs/services/rds.rst | 4 +- moto/rds/models.py | 92 +++++++++++++++++------------ moto/rds/responses.py | 8 +-- tests/test_rds/test_filters.py | 81 +++++++++++++++++++++++-- tests/test_rds/test_rds.py | 29 +++------ tests/test_rds/test_rds_clusters.py | 14 +++-- 7 files changed, 157 insertions(+), 77 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index f2baabe83..faf335e50 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -5318,7 +5318,7 @@ ## rds
-38% implemented +39% implemented - [ ] add_role_to_db_cluster - [ ] add_role_to_db_instance @@ -5329,9 +5329,9 @@ - [ ] backtrack_db_cluster - [X] cancel_export_task - [ ] copy_db_cluster_parameter_group -- [ ] copy_db_cluster_snapshot +- [X] copy_db_cluster_snapshot - [ ] copy_db_parameter_group -- [ ] copy_db_snapshot +- [X] copy_db_snapshot - [ ] copy_option_group - [ ] create_blue_green_deployment - [ ] create_custom_db_engine_version diff --git a/docs/docs/services/rds.rst b/docs/docs/services/rds.rst index 8b258f899..8741432ee 100644 --- a/docs/docs/services/rds.rst +++ b/docs/docs/services/rds.rst @@ -34,9 +34,9 @@ rds - [ ] backtrack_db_cluster - [X] cancel_export_task - [ ] copy_db_cluster_parameter_group -- [ ] copy_db_cluster_snapshot +- [X] copy_db_cluster_snapshot - [ ] copy_db_parameter_group -- [ ] copy_db_snapshot +- [X] copy_db_snapshot - [ ] copy_option_group - [ ] create_blue_green_deployment - [ ] create_custom_db_engine_version diff --git a/moto/rds/models.py b/moto/rds/models.py index a2811bd32..9fb2af047 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -436,13 +436,20 @@ class ClusterSnapshot(BaseModel): "db-cluster-snapshot-id": FilterDef( ["snapshot_id"], "DB Cluster Snapshot Identifiers" ), - "snapshot-type": FilterDef(None, "Snapshot Types"), + "snapshot-type": FilterDef(["snapshot_type"], "Snapshot Types"), "engine": FilterDef(["cluster.engine"], "Engine Names"), } - def __init__(self, cluster: Cluster, snapshot_id: str, tags: List[Dict[str, str]]): + def __init__( + self, + cluster: Cluster, + snapshot_id: str, + snapshot_type: str, + tags: List[Dict[str, str]], + ): self.cluster = cluster self.snapshot_id = snapshot_id + self.snapshot_type = snapshot_type self.tags = tags self.status = "available" self.created_at = iso_8601_datetime_with_milliseconds( @@ -467,7 +474,7 @@ class ClusterSnapshot(BaseModel): {{ cluster.port }} {{ cluster.engine }} {{ snapshot.status }} - manual + {{ snapshot.snapshot_type }} {{ snapshot.snapshot_arn }} {{ cluster.region }} {% if cluster.iops %} @@ -551,9 +558,7 @@ class Database(CloudFormationModel): ) self.db_cluster_identifier: Optional[str] = kwargs.get("db_cluster_identifier") self.db_instance_identifier = kwargs.get("db_instance_identifier") - self.source_db_identifier: Optional[str] = kwargs.get( - "source_db_ide.db_cluster_identifierntifier" - ) + self.source_db_identifier: Optional[str] = kwargs.get("source_db_identifier") self.db_instance_class = kwargs.get("db_instance_class") self.port = kwargs.get("port") if self.port is None: @@ -1056,15 +1061,20 @@ class DatabaseSnapshot(BaseModel): ), "db-snapshot-id": FilterDef(["snapshot_id"], "DB Snapshot Identifiers"), "dbi-resource-id": FilterDef(["database.dbi_resource_id"], "Dbi Resource Ids"), - "snapshot-type": FilterDef(None, "Snapshot Types"), + "snapshot-type": FilterDef(["snapshot_type"], "Snapshot Types"), "engine": FilterDef(["database.engine"], "Engine Names"), } def __init__( - self, database: Database, snapshot_id: str, tags: List[Dict[str, str]] + self, + database: Database, + snapshot_id: str, + snapshot_type: str, + tags: List[Dict[str, str]], ): self.database = database self.snapshot_id = snapshot_id + self.snapshot_type = snapshot_type self.tags = tags self.status = "available" self.created_at = iso_8601_datetime_with_milliseconds( @@ -1092,7 +1102,7 @@ class DatabaseSnapshot(BaseModel): {{ database.master_username }} {{ database.engine_version }} {{ database.license_model }} - manual + {{ snapshot.snapshot_type }} {% if database.iops %} {{ database.iops }} io1 @@ -1566,10 +1576,20 @@ class RDSBackend(BaseBackend): self.databases[database_id] = database return database + def create_auto_snapshot( + self, + db_instance_identifier: str, + db_snapshot_identifier: str, + ) -> DatabaseSnapshot: + return self.create_db_snapshot( + db_instance_identifier, db_snapshot_identifier, snapshot_type="automated" + ) + def create_db_snapshot( self, db_instance_identifier: str, db_snapshot_identifier: str, + snapshot_type: str = "manual", tags: Optional[List[Dict[str, str]]] = None, ) -> DatabaseSnapshot: database = self.databases.get(db_instance_identifier) @@ -1585,11 +1605,13 @@ class RDSBackend(BaseBackend): tags = list() if database.copy_tags_to_snapshot and not tags: tags = database.get_tags() - snapshot = DatabaseSnapshot(database, db_snapshot_identifier, tags) + snapshot = DatabaseSnapshot( + database, db_snapshot_identifier, snapshot_type, tags + ) self.database_snapshots[db_snapshot_identifier] = snapshot return snapshot - def copy_database_snapshot( + def copy_db_snapshot( self, source_snapshot_identifier: str, target_snapshot_identifier: str, @@ -1597,24 +1619,17 @@ class RDSBackend(BaseBackend): ) -> DatabaseSnapshot: if source_snapshot_identifier not in self.database_snapshots: raise DBSnapshotNotFoundError(source_snapshot_identifier) - if target_snapshot_identifier in self.database_snapshots: - raise DBSnapshotAlreadyExistsError(target_snapshot_identifier) - if len(self.database_snapshots) >= int( - os.environ.get("MOTO_RDS_SNAPSHOT_LIMIT", "100") - ): - raise SnapshotQuotaExceededError() source_snapshot = self.database_snapshots[source_snapshot_identifier] if tags is None: tags = source_snapshot.tags else: tags = self._merge_tags(source_snapshot.tags, tags) - target_snapshot = DatabaseSnapshot( - source_snapshot.database, target_snapshot_identifier, tags + return self.create_db_snapshot( + db_instance_identifier=source_snapshot.database.db_instance_identifier, # type: ignore + db_snapshot_identifier=target_snapshot_identifier, + tags=tags, ) - self.database_snapshots[target_snapshot_identifier] = target_snapshot - - return target_snapshot def delete_db_snapshot(self, db_snapshot_identifier: str) -> DatabaseSnapshot: if db_snapshot_identifier not in self.database_snapshots: @@ -1738,7 +1753,7 @@ class RDSBackend(BaseBackend): if database.status != "available": raise InvalidDBInstanceStateError(db_instance_identifier, "stop") if db_snapshot_identifier: - self.create_db_snapshot(db_instance_identifier, db_snapshot_identifier) + self.create_auto_snapshot(db_instance_identifier, db_snapshot_identifier) database.status = "stopped" return database @@ -1771,7 +1786,7 @@ class RDSBackend(BaseBackend): "Can't delete Instance with protection enabled" ) if db_snapshot_name: - self.create_db_snapshot(db_instance_identifier, db_snapshot_name) + self.create_auto_snapshot(db_instance_identifier, db_snapshot_name) database = self.databases.pop(db_instance_identifier) if database.is_replica: primary = self.find_db_from_id(database.source_db_identifier) # type: ignore @@ -2197,10 +2212,18 @@ class RDSBackend(BaseBackend): cluster.replication_source_identifier = None return cluster + def create_auto_cluster_snapshot( + self, db_cluster_identifier: str, db_snapshot_identifier: str + ) -> ClusterSnapshot: + return self.create_db_cluster_snapshot( + db_cluster_identifier, db_snapshot_identifier, snapshot_type="automated" + ) + def create_db_cluster_snapshot( self, db_cluster_identifier: str, db_snapshot_identifier: str, + snapshot_type: str = "manual", tags: Optional[List[Dict[str, str]]] = None, ) -> ClusterSnapshot: cluster = self.clusters.get(db_cluster_identifier) @@ -2216,11 +2239,11 @@ class RDSBackend(BaseBackend): tags = list() if cluster.copy_tags_to_snapshot: tags += cluster.get_tags() - snapshot = ClusterSnapshot(cluster, db_snapshot_identifier, tags) + snapshot = ClusterSnapshot(cluster, db_snapshot_identifier, snapshot_type, tags) self.cluster_snapshots[db_snapshot_identifier] = snapshot return snapshot - def copy_cluster_snapshot( + def copy_db_cluster_snapshot( self, source_snapshot_identifier: str, target_snapshot_identifier: str, @@ -2228,22 +2251,17 @@ class RDSBackend(BaseBackend): ) -> ClusterSnapshot: if source_snapshot_identifier not in self.cluster_snapshots: raise DBClusterSnapshotNotFoundError(source_snapshot_identifier) - if target_snapshot_identifier in self.cluster_snapshots: - raise DBClusterSnapshotAlreadyExistsError(target_snapshot_identifier) - if len(self.cluster_snapshots) >= int( - os.environ.get("MOTO_RDS_SNAPSHOT_LIMIT", "100") - ): - raise SnapshotQuotaExceededError() + source_snapshot = self.cluster_snapshots[source_snapshot_identifier] if tags is None: tags = source_snapshot.tags else: tags = self._merge_tags(source_snapshot.tags, tags) # type: ignore - target_snapshot = ClusterSnapshot( - source_snapshot.cluster, target_snapshot_identifier, tags + return self.create_db_cluster_snapshot( + db_cluster_identifier=source_snapshot.cluster.db_cluster_identifier, # type: ignore + db_snapshot_identifier=target_snapshot_identifier, + tags=tags, ) - self.cluster_snapshots[target_snapshot_identifier] = target_snapshot - return target_snapshot def delete_db_cluster_snapshot( self, db_snapshot_identifier: str @@ -2303,7 +2321,7 @@ class RDSBackend(BaseBackend): self.remove_from_global_cluster(global_id, cluster_identifier) if snapshot_name: - self.create_db_cluster_snapshot(cluster_identifier, snapshot_name) + self.create_auto_cluster_snapshot(cluster_identifier, snapshot_name) return self.clusters.pop(cluster_identifier) if cluster_identifier in self.neptune.clusters: return self.neptune.delete_db_cluster(cluster_identifier) # type: ignore diff --git a/moto/rds/responses.py b/moto/rds/responses.py index c22d63748..cd80fef6a 100644 --- a/moto/rds/responses.py +++ b/moto/rds/responses.py @@ -312,7 +312,7 @@ class RDSResponse(BaseResponse): db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") tags = self.unpack_list_params("Tags", "Tag") snapshot = self.backend.create_db_snapshot( - db_instance_identifier, db_snapshot_identifier, tags + db_instance_identifier, db_snapshot_identifier, tags=tags ) template = self.response_template(CREATE_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) @@ -321,7 +321,7 @@ class RDSResponse(BaseResponse): source_snapshot_identifier = self._get_param("SourceDBSnapshotIdentifier") target_snapshot_identifier = self._get_param("TargetDBSnapshotIdentifier") tags = self.unpack_list_params("Tags", "Tag") - snapshot = self.backend.copy_database_snapshot( + snapshot = self.backend.copy_db_snapshot( source_snapshot_identifier, target_snapshot_identifier, tags ) template = self.response_template(COPY_SNAPSHOT_TEMPLATE) @@ -642,7 +642,7 @@ class RDSResponse(BaseResponse): db_snapshot_identifier = self._get_param("DBClusterSnapshotIdentifier") tags = self.unpack_list_params("Tags", "Tag") snapshot = self.backend.create_db_cluster_snapshot( - db_cluster_identifier, db_snapshot_identifier, tags + db_cluster_identifier, db_snapshot_identifier, tags=tags ) template = self.response_template(CREATE_CLUSTER_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) @@ -655,7 +655,7 @@ class RDSResponse(BaseResponse): "TargetDBClusterSnapshotIdentifier" ) tags = self.unpack_list_params("Tags", "Tag") - snapshot = self.backend.copy_cluster_snapshot( + snapshot = self.backend.copy_db_cluster_snapshot( source_snapshot_identifier, target_snapshot_identifier, tags ) template = self.response_template(COPY_CLUSTER_SNAPSHOT_TEMPLATE) diff --git a/tests/test_rds/test_filters.py b/tests/test_rds/test_filters.py index 488c87c9e..529b32516 100644 --- a/tests/test_rds/test_filters.py +++ b/tests/test_rds/test_filters.py @@ -6,8 +6,7 @@ from botocore.exceptions import ClientError from moto import mock_rds -class TestDBInstanceFilters(object): - +class TestDBInstanceFilters: mock = mock_rds() @classmethod @@ -189,8 +188,7 @@ class TestDBInstanceFilters(object): ) -class TestDBSnapshotFilters(object): - +class TestDBSnapshotFilters: mock = mock_rds() @classmethod @@ -289,6 +287,18 @@ class TestDBSnapshotFilters(object): ).get("DBSnapshots") snapshots.should.have.length_of(0) + def test_snapshot_type_filter(self): + snapshots = self.client.describe_db_snapshots( + Filters=[{"Name": "snapshot-type", "Values": ["manual"]}] + )["DBSnapshots"] + for snapshot in snapshots: + assert snapshot["SnapshotType"] == "manual" + + snapshots = self.client.describe_db_snapshots( + Filters=[{"Name": "snapshot-type", "Values": ["automated"]}] + )["DBSnapshots"] + assert len(snapshots) == 0 + def test_multiple_filters(self): snapshots = self.client.describe_db_snapshots( Filters=[ @@ -373,3 +383,66 @@ class TestDBSnapshotFilters(object): ex.value.response["Error"]["Message"].should.equal( "DBSnapshot db-instance-0-snapshot-0 not found." ) + + +class TestDBClusterSnapshotFilters: + mock = mock_rds() + + @classmethod + def setup_class(cls): + cls.mock.start() + client = boto3.client("rds", region_name="us-west-2") + # We'll set up two instances (one postgres, one mysql) + # with two snapshots each. + for i in range(2): + _id = f"db-cluster-{i}" + client.create_db_cluster( + DBClusterIdentifier=_id, + Engine="postgres", + MasterUsername="root", + MasterUserPassword="hunter2000", + ) + + for j in range(2): + client.create_db_cluster_snapshot( + DBClusterIdentifier=_id, + DBClusterSnapshotIdentifier=f"snapshot-{i}-{j}", + ) + cls.client = client + + @classmethod + def teardown_class(cls): + try: + cls.mock.stop() + except RuntimeError: + pass + + def test_invalid_filter_name_raises_error(self): + with pytest.raises(ClientError) as ex: + self.client.describe_db_cluster_snapshots( + Filters=[{"Name": "invalid-filter-name", "Values": []}] + ) + ex.value.response["Error"]["Code"].should.equal("InvalidParameterValue") + ex.value.response["Error"]["Message"].should.equal( + "Unrecognized filter name: invalid-filter-name" + ) + + def test_empty_filter_values_raises_error(self): + with pytest.raises(ClientError) as ex: + self.client.describe_db_cluster_snapshots( + Filters=[{"Name": "snapshot-type", "Values": []}] + ) + ex.value.response["Error"]["Code"].should.equal("InvalidParameterCombination") + ex.value.response["Error"]["Message"].should.contain("must not be empty") + + def test_snapshot_type_filter(self): + snapshots = self.client.describe_db_cluster_snapshots( + Filters=[{"Name": "snapshot-type", "Values": ["manual"]}] + )["DBClusterSnapshots"] + for snapshot in snapshots: + assert snapshot["SnapshotType"] == "manual" + + snapshots = self.client.describe_db_cluster_snapshots( + Filters=[{"Name": "snapshot-type", "Values": ["automated"]}] + )["DBClusterSnapshots"] + assert len(snapshots) == 0 diff --git a/tests/test_rds/test_rds.py b/tests/test_rds/test_rds.py index f9b3b5d1e..8b66bfea7 100644 --- a/tests/test_rds/test_rds.py +++ b/tests/test_rds/test_rds.py @@ -914,7 +914,7 @@ def test_delete_database(): instances = conn.describe_db_instances() list(instances["DBInstances"]).should.have.length_of(0) conn.create_db_instance( - DBInstanceIdentifier="db-primary-1", + DBInstanceIdentifier="db-1", AllocatedStorage=10, Engine="postgres", DBInstanceClass="db.m1.small", @@ -927,7 +927,7 @@ def test_delete_database(): list(instances["DBInstances"]).should.have.length_of(1) conn.delete_db_instance( - DBInstanceIdentifier="db-primary-1", + DBInstanceIdentifier="db-1", FinalDBSnapshotIdentifier="primary-1-snapshot", ) @@ -935,10 +935,9 @@ def test_delete_database(): list(instances["DBInstances"]).should.have.length_of(0) # Saved the snapshot - snapshots = conn.describe_db_snapshots(DBInstanceIdentifier="db-primary-1").get( - "DBSnapshots" - ) - snapshots[0].get("Engine").should.equal("postgres") + snapshot = conn.describe_db_snapshots(DBInstanceIdentifier="db-1")["DBSnapshots"][0] + assert snapshot["Engine"] == "postgres" + assert snapshot["SnapshotType"] == "automated" @mock_rds @@ -1085,7 +1084,8 @@ def test_describe_db_snapshots(): DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1" ).get("DBSnapshot") - created.get("Engine").should.equal("postgres") + assert created["Engine"] == "postgres" + assert created["SnapshotType"] == "manual" by_database_id = conn.describe_db_snapshots( DBInstanceIdentifier="db-primary-1" @@ -1096,8 +1096,7 @@ def test_describe_db_snapshots(): by_snapshot_id.should.equal(by_database_id) snapshot = by_snapshot_id[0] - snapshot.should.equal(created) - snapshot.get("Engine").should.equal("postgres") + assert snapshot == created conn.create_db_snapshot( DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-2" @@ -2508,18 +2507,6 @@ def test_create_db_with_iam_authentication(): db_instance = database["DBInstance"] db_instance["IAMDatabaseAuthenticationEnabled"].should.equal(True) - -@mock_rds -def test_create_db_snapshot_with_iam_authentication(): - conn = boto3.client("rds", region_name="us-west-2") - - conn.create_db_instance( - DBInstanceIdentifier="rds", - DBInstanceClass="db.t1.micro", - Engine="postgres", - EnableIAMDatabaseAuthentication=True, - ) - snapshot = conn.create_db_snapshot( DBInstanceIdentifier="rds", DBSnapshotIdentifier="snapshot" ).get("DBSnapshot") diff --git a/tests/test_rds/test_rds_clusters.py b/tests/test_rds/test_rds_clusters.py index 4ead4e19d..60aa83dc5 100644 --- a/tests/test_rds/test_rds_clusters.py +++ b/tests/test_rds/test_rds_clusters.py @@ -309,9 +309,10 @@ def test_delete_db_cluster_do_snapshot(): DBClusterIdentifier="cluster-id", FinalDBSnapshotIdentifier="final-snapshot" ) client.describe_db_clusters()["DBClusters"].should.have.length_of(0) - snapshots = client.describe_db_cluster_snapshots()["DBClusterSnapshots"] - snapshots[0]["DBClusterIdentifier"].should.equal("cluster-id") - snapshots[0]["DBClusterSnapshotIdentifier"].should.equal("final-snapshot") + snapshot = client.describe_db_cluster_snapshots()["DBClusterSnapshots"][0] + assert snapshot["DBClusterIdentifier"] == "cluster-id" + assert snapshot["DBClusterSnapshotIdentifier"] == "final-snapshot" + assert snapshot["SnapshotType"] == "automated" @mock_rds @@ -470,9 +471,10 @@ def test_create_db_cluster_snapshot(): DBClusterIdentifier="db-primary-1", DBClusterSnapshotIdentifier="g-1" ).get("DBClusterSnapshot") - snapshot.get("Engine").should.equal("postgres") - snapshot.get("DBClusterIdentifier").should.equal("db-primary-1") - snapshot.get("DBClusterSnapshotIdentifier").should.equal("g-1") + assert snapshot["Engine"] == "postgres" + assert snapshot["DBClusterIdentifier"] == "db-primary-1" + assert snapshot["DBClusterSnapshotIdentifier"] == "g-1" + assert snapshot["SnapshotType"] == "manual" result = conn.list_tags_for_resource(ResourceName=snapshot["DBClusterSnapshotArn"]) result["TagList"].should.equal([])