RDS: Automated snapshots now have the appropriate SnapshotType (#6444)

This commit is contained in:
Bert Blommers 2023-06-25 17:36:42 +00:00 committed by GitHub
parent bec9130d4c
commit 5c0827547d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 157 additions and 77 deletions

View File

@ -5318,7 +5318,7 @@
## rds
<details>
<summary>38% implemented</summary>
<summary>39% implemented</summary>
- [ ] add_role_to_db_cluster
- [ ] add_role_to_db_instance
@ -5329,9 +5329,9 @@
- [ ] backtrack_db_cluster
- [X] cancel_export_task
- [ ] copy_db_cluster_parameter_group
- [ ] copy_db_cluster_snapshot
- [X] copy_db_cluster_snapshot
- [ ] copy_db_parameter_group
- [ ] copy_db_snapshot
- [X] copy_db_snapshot
- [ ] copy_option_group
- [ ] create_blue_green_deployment
- [ ] create_custom_db_engine_version

View File

@ -34,9 +34,9 @@ rds
- [ ] backtrack_db_cluster
- [X] cancel_export_task
- [ ] copy_db_cluster_parameter_group
- [ ] copy_db_cluster_snapshot
- [X] copy_db_cluster_snapshot
- [ ] copy_db_parameter_group
- [ ] copy_db_snapshot
- [X] copy_db_snapshot
- [ ] copy_option_group
- [ ] create_blue_green_deployment
- [ ] create_custom_db_engine_version

View File

@ -436,13 +436,20 @@ class ClusterSnapshot(BaseModel):
"db-cluster-snapshot-id": FilterDef(
["snapshot_id"], "DB Cluster Snapshot Identifiers"
),
"snapshot-type": FilterDef(None, "Snapshot Types"),
"snapshot-type": FilterDef(["snapshot_type"], "Snapshot Types"),
"engine": FilterDef(["cluster.engine"], "Engine Names"),
}
def __init__(self, cluster: Cluster, snapshot_id: str, tags: List[Dict[str, str]]):
def __init__(
self,
cluster: Cluster,
snapshot_id: str,
snapshot_type: str,
tags: List[Dict[str, str]],
):
self.cluster = cluster
self.snapshot_id = snapshot_id
self.snapshot_type = snapshot_type
self.tags = tags
self.status = "available"
self.created_at = iso_8601_datetime_with_milliseconds(
@ -467,7 +474,7 @@ class ClusterSnapshot(BaseModel):
<Port>{{ cluster.port }}</Port>
<Engine>{{ cluster.engine }}</Engine>
<Status>{{ snapshot.status }}</Status>
<SnapshotType>manual</SnapshotType>
<SnapshotType>{{ snapshot.snapshot_type }}</SnapshotType>
<DBClusterSnapshotArn>{{ snapshot.snapshot_arn }}</DBClusterSnapshotArn>
<SourceRegion>{{ cluster.region }}</SourceRegion>
{% if cluster.iops %}
@ -551,9 +558,7 @@ class Database(CloudFormationModel):
)
self.db_cluster_identifier: Optional[str] = kwargs.get("db_cluster_identifier")
self.db_instance_identifier = kwargs.get("db_instance_identifier")
self.source_db_identifier: Optional[str] = kwargs.get(
"source_db_ide.db_cluster_identifierntifier"
)
self.source_db_identifier: Optional[str] = kwargs.get("source_db_identifier")
self.db_instance_class = kwargs.get("db_instance_class")
self.port = kwargs.get("port")
if self.port is None:
@ -1056,15 +1061,20 @@ class DatabaseSnapshot(BaseModel):
),
"db-snapshot-id": FilterDef(["snapshot_id"], "DB Snapshot Identifiers"),
"dbi-resource-id": FilterDef(["database.dbi_resource_id"], "Dbi Resource Ids"),
"snapshot-type": FilterDef(None, "Snapshot Types"),
"snapshot-type": FilterDef(["snapshot_type"], "Snapshot Types"),
"engine": FilterDef(["database.engine"], "Engine Names"),
}
def __init__(
self, database: Database, snapshot_id: str, tags: List[Dict[str, str]]
self,
database: Database,
snapshot_id: str,
snapshot_type: str,
tags: List[Dict[str, str]],
):
self.database = database
self.snapshot_id = snapshot_id
self.snapshot_type = snapshot_type
self.tags = tags
self.status = "available"
self.created_at = iso_8601_datetime_with_milliseconds(
@ -1092,7 +1102,7 @@ class DatabaseSnapshot(BaseModel):
<MasterUsername>{{ database.master_username }}</MasterUsername>
<EngineVersion>{{ database.engine_version }}</EngineVersion>
<LicenseModel>{{ database.license_model }}</LicenseModel>
<SnapshotType>manual</SnapshotType>
<SnapshotType>{{ snapshot.snapshot_type }}</SnapshotType>
{% if database.iops %}
<Iops>{{ database.iops }}</Iops>
<StorageType>io1</StorageType>
@ -1566,10 +1576,20 @@ class RDSBackend(BaseBackend):
self.databases[database_id] = database
return database
def create_auto_snapshot(
self,
db_instance_identifier: str,
db_snapshot_identifier: str,
) -> DatabaseSnapshot:
return self.create_db_snapshot(
db_instance_identifier, db_snapshot_identifier, snapshot_type="automated"
)
def create_db_snapshot(
self,
db_instance_identifier: str,
db_snapshot_identifier: str,
snapshot_type: str = "manual",
tags: Optional[List[Dict[str, str]]] = None,
) -> DatabaseSnapshot:
database = self.databases.get(db_instance_identifier)
@ -1585,11 +1605,13 @@ class RDSBackend(BaseBackend):
tags = list()
if database.copy_tags_to_snapshot and not tags:
tags = database.get_tags()
snapshot = DatabaseSnapshot(database, db_snapshot_identifier, tags)
snapshot = DatabaseSnapshot(
database, db_snapshot_identifier, snapshot_type, tags
)
self.database_snapshots[db_snapshot_identifier] = snapshot
return snapshot
def copy_database_snapshot(
def copy_db_snapshot(
self,
source_snapshot_identifier: str,
target_snapshot_identifier: str,
@ -1597,24 +1619,17 @@ class RDSBackend(BaseBackend):
) -> DatabaseSnapshot:
if source_snapshot_identifier not in self.database_snapshots:
raise DBSnapshotNotFoundError(source_snapshot_identifier)
if target_snapshot_identifier in self.database_snapshots:
raise DBSnapshotAlreadyExistsError(target_snapshot_identifier)
if len(self.database_snapshots) >= int(
os.environ.get("MOTO_RDS_SNAPSHOT_LIMIT", "100")
):
raise SnapshotQuotaExceededError()
source_snapshot = self.database_snapshots[source_snapshot_identifier]
if tags is None:
tags = source_snapshot.tags
else:
tags = self._merge_tags(source_snapshot.tags, tags)
target_snapshot = DatabaseSnapshot(
source_snapshot.database, target_snapshot_identifier, tags
return self.create_db_snapshot(
db_instance_identifier=source_snapshot.database.db_instance_identifier, # type: ignore
db_snapshot_identifier=target_snapshot_identifier,
tags=tags,
)
self.database_snapshots[target_snapshot_identifier] = target_snapshot
return target_snapshot
def delete_db_snapshot(self, db_snapshot_identifier: str) -> DatabaseSnapshot:
if db_snapshot_identifier not in self.database_snapshots:
@ -1738,7 +1753,7 @@ class RDSBackend(BaseBackend):
if database.status != "available":
raise InvalidDBInstanceStateError(db_instance_identifier, "stop")
if db_snapshot_identifier:
self.create_db_snapshot(db_instance_identifier, db_snapshot_identifier)
self.create_auto_snapshot(db_instance_identifier, db_snapshot_identifier)
database.status = "stopped"
return database
@ -1771,7 +1786,7 @@ class RDSBackend(BaseBackend):
"Can't delete Instance with protection enabled"
)
if db_snapshot_name:
self.create_db_snapshot(db_instance_identifier, db_snapshot_name)
self.create_auto_snapshot(db_instance_identifier, db_snapshot_name)
database = self.databases.pop(db_instance_identifier)
if database.is_replica:
primary = self.find_db_from_id(database.source_db_identifier) # type: ignore
@ -2197,10 +2212,18 @@ class RDSBackend(BaseBackend):
cluster.replication_source_identifier = None
return cluster
def create_auto_cluster_snapshot(
self, db_cluster_identifier: str, db_snapshot_identifier: str
) -> ClusterSnapshot:
return self.create_db_cluster_snapshot(
db_cluster_identifier, db_snapshot_identifier, snapshot_type="automated"
)
def create_db_cluster_snapshot(
self,
db_cluster_identifier: str,
db_snapshot_identifier: str,
snapshot_type: str = "manual",
tags: Optional[List[Dict[str, str]]] = None,
) -> ClusterSnapshot:
cluster = self.clusters.get(db_cluster_identifier)
@ -2216,11 +2239,11 @@ class RDSBackend(BaseBackend):
tags = list()
if cluster.copy_tags_to_snapshot:
tags += cluster.get_tags()
snapshot = ClusterSnapshot(cluster, db_snapshot_identifier, tags)
snapshot = ClusterSnapshot(cluster, db_snapshot_identifier, snapshot_type, tags)
self.cluster_snapshots[db_snapshot_identifier] = snapshot
return snapshot
def copy_cluster_snapshot(
def copy_db_cluster_snapshot(
self,
source_snapshot_identifier: str,
target_snapshot_identifier: str,
@ -2228,22 +2251,17 @@ class RDSBackend(BaseBackend):
) -> ClusterSnapshot:
if source_snapshot_identifier not in self.cluster_snapshots:
raise DBClusterSnapshotNotFoundError(source_snapshot_identifier)
if target_snapshot_identifier in self.cluster_snapshots:
raise DBClusterSnapshotAlreadyExistsError(target_snapshot_identifier)
if len(self.cluster_snapshots) >= int(
os.environ.get("MOTO_RDS_SNAPSHOT_LIMIT", "100")
):
raise SnapshotQuotaExceededError()
source_snapshot = self.cluster_snapshots[source_snapshot_identifier]
if tags is None:
tags = source_snapshot.tags
else:
tags = self._merge_tags(source_snapshot.tags, tags) # type: ignore
target_snapshot = ClusterSnapshot(
source_snapshot.cluster, target_snapshot_identifier, tags
return self.create_db_cluster_snapshot(
db_cluster_identifier=source_snapshot.cluster.db_cluster_identifier, # type: ignore
db_snapshot_identifier=target_snapshot_identifier,
tags=tags,
)
self.cluster_snapshots[target_snapshot_identifier] = target_snapshot
return target_snapshot
def delete_db_cluster_snapshot(
self, db_snapshot_identifier: str
@ -2303,7 +2321,7 @@ class RDSBackend(BaseBackend):
self.remove_from_global_cluster(global_id, cluster_identifier)
if snapshot_name:
self.create_db_cluster_snapshot(cluster_identifier, snapshot_name)
self.create_auto_cluster_snapshot(cluster_identifier, snapshot_name)
return self.clusters.pop(cluster_identifier)
if cluster_identifier in self.neptune.clusters:
return self.neptune.delete_db_cluster(cluster_identifier) # type: ignore

View File

@ -312,7 +312,7 @@ class RDSResponse(BaseResponse):
db_snapshot_identifier = self._get_param("DBSnapshotIdentifier")
tags = self.unpack_list_params("Tags", "Tag")
snapshot = self.backend.create_db_snapshot(
db_instance_identifier, db_snapshot_identifier, tags
db_instance_identifier, db_snapshot_identifier, tags=tags
)
template = self.response_template(CREATE_SNAPSHOT_TEMPLATE)
return template.render(snapshot=snapshot)
@ -321,7 +321,7 @@ class RDSResponse(BaseResponse):
source_snapshot_identifier = self._get_param("SourceDBSnapshotIdentifier")
target_snapshot_identifier = self._get_param("TargetDBSnapshotIdentifier")
tags = self.unpack_list_params("Tags", "Tag")
snapshot = self.backend.copy_database_snapshot(
snapshot = self.backend.copy_db_snapshot(
source_snapshot_identifier, target_snapshot_identifier, tags
)
template = self.response_template(COPY_SNAPSHOT_TEMPLATE)
@ -642,7 +642,7 @@ class RDSResponse(BaseResponse):
db_snapshot_identifier = self._get_param("DBClusterSnapshotIdentifier")
tags = self.unpack_list_params("Tags", "Tag")
snapshot = self.backend.create_db_cluster_snapshot(
db_cluster_identifier, db_snapshot_identifier, tags
db_cluster_identifier, db_snapshot_identifier, tags=tags
)
template = self.response_template(CREATE_CLUSTER_SNAPSHOT_TEMPLATE)
return template.render(snapshot=snapshot)
@ -655,7 +655,7 @@ class RDSResponse(BaseResponse):
"TargetDBClusterSnapshotIdentifier"
)
tags = self.unpack_list_params("Tags", "Tag")
snapshot = self.backend.copy_cluster_snapshot(
snapshot = self.backend.copy_db_cluster_snapshot(
source_snapshot_identifier, target_snapshot_identifier, tags
)
template = self.response_template(COPY_CLUSTER_SNAPSHOT_TEMPLATE)

View File

@ -6,8 +6,7 @@ from botocore.exceptions import ClientError
from moto import mock_rds
class TestDBInstanceFilters(object):
class TestDBInstanceFilters:
mock = mock_rds()
@classmethod
@ -189,8 +188,7 @@ class TestDBInstanceFilters(object):
)
class TestDBSnapshotFilters(object):
class TestDBSnapshotFilters:
mock = mock_rds()
@classmethod
@ -289,6 +287,18 @@ class TestDBSnapshotFilters(object):
).get("DBSnapshots")
snapshots.should.have.length_of(0)
def test_snapshot_type_filter(self):
snapshots = self.client.describe_db_snapshots(
Filters=[{"Name": "snapshot-type", "Values": ["manual"]}]
)["DBSnapshots"]
for snapshot in snapshots:
assert snapshot["SnapshotType"] == "manual"
snapshots = self.client.describe_db_snapshots(
Filters=[{"Name": "snapshot-type", "Values": ["automated"]}]
)["DBSnapshots"]
assert len(snapshots) == 0
def test_multiple_filters(self):
snapshots = self.client.describe_db_snapshots(
Filters=[
@ -373,3 +383,66 @@ class TestDBSnapshotFilters(object):
ex.value.response["Error"]["Message"].should.equal(
"DBSnapshot db-instance-0-snapshot-0 not found."
)
class TestDBClusterSnapshotFilters:
mock = mock_rds()
@classmethod
def setup_class(cls):
cls.mock.start()
client = boto3.client("rds", region_name="us-west-2")
# We'll set up two instances (one postgres, one mysql)
# with two snapshots each.
for i in range(2):
_id = f"db-cluster-{i}"
client.create_db_cluster(
DBClusterIdentifier=_id,
Engine="postgres",
MasterUsername="root",
MasterUserPassword="hunter2000",
)
for j in range(2):
client.create_db_cluster_snapshot(
DBClusterIdentifier=_id,
DBClusterSnapshotIdentifier=f"snapshot-{i}-{j}",
)
cls.client = client
@classmethod
def teardown_class(cls):
try:
cls.mock.stop()
except RuntimeError:
pass
def test_invalid_filter_name_raises_error(self):
with pytest.raises(ClientError) as ex:
self.client.describe_db_cluster_snapshots(
Filters=[{"Name": "invalid-filter-name", "Values": []}]
)
ex.value.response["Error"]["Code"].should.equal("InvalidParameterValue")
ex.value.response["Error"]["Message"].should.equal(
"Unrecognized filter name: invalid-filter-name"
)
def test_empty_filter_values_raises_error(self):
with pytest.raises(ClientError) as ex:
self.client.describe_db_cluster_snapshots(
Filters=[{"Name": "snapshot-type", "Values": []}]
)
ex.value.response["Error"]["Code"].should.equal("InvalidParameterCombination")
ex.value.response["Error"]["Message"].should.contain("must not be empty")
def test_snapshot_type_filter(self):
snapshots = self.client.describe_db_cluster_snapshots(
Filters=[{"Name": "snapshot-type", "Values": ["manual"]}]
)["DBClusterSnapshots"]
for snapshot in snapshots:
assert snapshot["SnapshotType"] == "manual"
snapshots = self.client.describe_db_cluster_snapshots(
Filters=[{"Name": "snapshot-type", "Values": ["automated"]}]
)["DBClusterSnapshots"]
assert len(snapshots) == 0

View File

@ -914,7 +914,7 @@ def test_delete_database():
instances = conn.describe_db_instances()
list(instances["DBInstances"]).should.have.length_of(0)
conn.create_db_instance(
DBInstanceIdentifier="db-primary-1",
DBInstanceIdentifier="db-1",
AllocatedStorage=10,
Engine="postgres",
DBInstanceClass="db.m1.small",
@ -927,7 +927,7 @@ def test_delete_database():
list(instances["DBInstances"]).should.have.length_of(1)
conn.delete_db_instance(
DBInstanceIdentifier="db-primary-1",
DBInstanceIdentifier="db-1",
FinalDBSnapshotIdentifier="primary-1-snapshot",
)
@ -935,10 +935,9 @@ def test_delete_database():
list(instances["DBInstances"]).should.have.length_of(0)
# Saved the snapshot
snapshots = conn.describe_db_snapshots(DBInstanceIdentifier="db-primary-1").get(
"DBSnapshots"
)
snapshots[0].get("Engine").should.equal("postgres")
snapshot = conn.describe_db_snapshots(DBInstanceIdentifier="db-1")["DBSnapshots"][0]
assert snapshot["Engine"] == "postgres"
assert snapshot["SnapshotType"] == "automated"
@mock_rds
@ -1085,7 +1084,8 @@ def test_describe_db_snapshots():
DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1"
).get("DBSnapshot")
created.get("Engine").should.equal("postgres")
assert created["Engine"] == "postgres"
assert created["SnapshotType"] == "manual"
by_database_id = conn.describe_db_snapshots(
DBInstanceIdentifier="db-primary-1"
@ -1096,8 +1096,7 @@ def test_describe_db_snapshots():
by_snapshot_id.should.equal(by_database_id)
snapshot = by_snapshot_id[0]
snapshot.should.equal(created)
snapshot.get("Engine").should.equal("postgres")
assert snapshot == created
conn.create_db_snapshot(
DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-2"
@ -2508,18 +2507,6 @@ def test_create_db_with_iam_authentication():
db_instance = database["DBInstance"]
db_instance["IAMDatabaseAuthenticationEnabled"].should.equal(True)
@mock_rds
def test_create_db_snapshot_with_iam_authentication():
conn = boto3.client("rds", region_name="us-west-2")
conn.create_db_instance(
DBInstanceIdentifier="rds",
DBInstanceClass="db.t1.micro",
Engine="postgres",
EnableIAMDatabaseAuthentication=True,
)
snapshot = conn.create_db_snapshot(
DBInstanceIdentifier="rds", DBSnapshotIdentifier="snapshot"
).get("DBSnapshot")

View File

@ -309,9 +309,10 @@ def test_delete_db_cluster_do_snapshot():
DBClusterIdentifier="cluster-id", FinalDBSnapshotIdentifier="final-snapshot"
)
client.describe_db_clusters()["DBClusters"].should.have.length_of(0)
snapshots = client.describe_db_cluster_snapshots()["DBClusterSnapshots"]
snapshots[0]["DBClusterIdentifier"].should.equal("cluster-id")
snapshots[0]["DBClusterSnapshotIdentifier"].should.equal("final-snapshot")
snapshot = client.describe_db_cluster_snapshots()["DBClusterSnapshots"][0]
assert snapshot["DBClusterIdentifier"] == "cluster-id"
assert snapshot["DBClusterSnapshotIdentifier"] == "final-snapshot"
assert snapshot["SnapshotType"] == "automated"
@mock_rds
@ -470,9 +471,10 @@ def test_create_db_cluster_snapshot():
DBClusterIdentifier="db-primary-1", DBClusterSnapshotIdentifier="g-1"
).get("DBClusterSnapshot")
snapshot.get("Engine").should.equal("postgres")
snapshot.get("DBClusterIdentifier").should.equal("db-primary-1")
snapshot.get("DBClusterSnapshotIdentifier").should.equal("g-1")
assert snapshot["Engine"] == "postgres"
assert snapshot["DBClusterIdentifier"] == "db-primary-1"
assert snapshot["DBClusterSnapshotIdentifier"] == "g-1"
assert snapshot["SnapshotType"] == "manual"
result = conn.list_tags_for_resource(ResourceName=snapshot["DBClusterSnapshotArn"])
result["TagList"].should.equal([])