From 7ecc0161af29e9f390c6a92803c3cf47f58b9350 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 12 Apr 2023 19:27:40 +0000 Subject: [PATCH] Techdebt: MyPy RDS (#6203) --- moto/rds/exceptions.py | 46 +-- moto/rds/models.py | 675 ++++++++++++++++++++++++----------------- moto/rds/responses.py | 175 +++++------ moto/rds/utils.py | 56 ++-- setup.cfg | 2 +- 5 files changed, 537 insertions(+), 417 deletions(-) diff --git a/moto/rds/exceptions.py b/moto/rds/exceptions.py index e56511d26..faf1cd842 100644 --- a/moto/rds/exceptions.py +++ b/moto/rds/exceptions.py @@ -3,7 +3,7 @@ from moto.core.exceptions import RESTError class RDSClientError(RESTError): - def __init__(self, code, message): + def __init__(self, code: str, message: str): super().__init__(error_type=code, message=message) template = Template( """ @@ -20,21 +20,21 @@ class RDSClientError(RESTError): class DBInstanceNotFoundError(RDSClientError): - def __init__(self, database_identifier): + def __init__(self, database_identifier: str): super().__init__( "DBInstanceNotFound", f"DBInstance {database_identifier} not found." ) class DBSnapshotNotFoundError(RDSClientError): - def __init__(self, snapshot_identifier): + def __init__(self, snapshot_identifier: str): super().__init__( "DBSnapshotNotFound", f"DBSnapshot {snapshot_identifier} not found." ) class DBSecurityGroupNotFoundError(RDSClientError): - def __init__(self, security_group_name): + def __init__(self, security_group_name: str): super().__init__( "DBSecurityGroupNotFound", f"Security Group {security_group_name} not found.", @@ -44,14 +44,14 @@ class DBSecurityGroupNotFoundError(RDSClientError): class DBSubnetGroupNotFoundError(RDSClientError): code = 404 - def __init__(self, subnet_group_name): + def __init__(self, subnet_group_name: str): super().__init__( "DBSubnetGroupNotFoundFault", f"Subnet Group {subnet_group_name} not found." ) class DBParameterGroupNotFoundError(RDSClientError): - def __init__(self, db_parameter_group_name): + def __init__(self, db_parameter_group_name: str): super().__init__( "DBParameterGroupNotFound", f"DB Parameter Group {db_parameter_group_name} not found.", @@ -59,7 +59,7 @@ class DBParameterGroupNotFoundError(RDSClientError): class DBClusterParameterGroupNotFoundError(RDSClientError): - def __init__(self, group_name): + def __init__(self, group_name: str): super().__init__( "DBParameterGroupNotFound", f"DBClusterParameterGroup not found: {group_name}", @@ -67,7 +67,7 @@ class DBClusterParameterGroupNotFoundError(RDSClientError): class OptionGroupNotFoundFaultError(RDSClientError): - def __init__(self, option_group_name): + def __init__(self, option_group_name: str): super().__init__( "OptionGroupNotFoundFault", f"Specified OptionGroupName: {option_group_name} not found.", @@ -75,7 +75,7 @@ class OptionGroupNotFoundFaultError(RDSClientError): class InvalidDBClusterStateFaultError(RDSClientError): - def __init__(self, database_identifier): + def __init__(self, database_identifier: str): super().__init__( "InvalidDBClusterStateFault", f"Invalid DB type, when trying to perform StopDBInstance on {database_identifier}e. See AWS RDS documentation on rds.stop_db_instance", @@ -83,7 +83,7 @@ class InvalidDBClusterStateFaultError(RDSClientError): class InvalidDBInstanceStateError(RDSClientError): - def __init__(self, database_identifier, istate): + def __init__(self, database_identifier: str, istate: str): estate = ( "in available state" if istate == "stop" @@ -95,7 +95,7 @@ class InvalidDBInstanceStateError(RDSClientError): class SnapshotQuotaExceededError(RDSClientError): - def __init__(self): + def __init__(self) -> None: super().__init__( "SnapshotQuotaExceeded", "The request cannot be processed because it would exceed the maximum number of snapshots.", @@ -103,7 +103,7 @@ class SnapshotQuotaExceededError(RDSClientError): class DBSnapshotAlreadyExistsError(RDSClientError): - def __init__(self, database_snapshot_identifier): + def __init__(self, database_snapshot_identifier: str): super().__init__( "DBSnapshotAlreadyExists", f"Cannot create the snapshot because a snapshot with the identifier {database_snapshot_identifier} already exists.", @@ -111,29 +111,29 @@ class DBSnapshotAlreadyExistsError(RDSClientError): class InvalidParameterValue(RDSClientError): - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidParameterValue", message) class InvalidParameterCombination(RDSClientError): - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidParameterCombination", message) class InvalidDBClusterStateFault(RDSClientError): - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidDBClusterStateFault", message) class DBClusterNotFoundError(RDSClientError): - def __init__(self, cluster_identifier): + def __init__(self, cluster_identifier: str): super().__init__( "DBClusterNotFoundFault", f"DBCluster {cluster_identifier} not found." ) class DBClusterSnapshotNotFoundError(RDSClientError): - def __init__(self, snapshot_identifier): + def __init__(self, snapshot_identifier: str): super().__init__( "DBClusterSnapshotNotFoundFault", f"DBClusterSnapshot {snapshot_identifier} not found.", @@ -141,7 +141,7 @@ class DBClusterSnapshotNotFoundError(RDSClientError): class DBClusterSnapshotAlreadyExistsError(RDSClientError): - def __init__(self, database_snapshot_identifier): + def __init__(self, database_snapshot_identifier: str): super().__init__( "DBClusterSnapshotAlreadyExistsFault", f"Cannot create the snapshot because a snapshot with the identifier {database_snapshot_identifier} already exists.", @@ -149,7 +149,7 @@ class DBClusterSnapshotAlreadyExistsError(RDSClientError): class ExportTaskAlreadyExistsError(RDSClientError): - def __init__(self, export_task_identifier): + def __init__(self, export_task_identifier: str): super().__init__( "ExportTaskAlreadyExistsFault", f"Cannot start export task because a task with the identifier {export_task_identifier} already exists.", @@ -157,7 +157,7 @@ class ExportTaskAlreadyExistsError(RDSClientError): class ExportTaskNotFoundError(RDSClientError): - def __init__(self, export_task_identifier): + def __init__(self, export_task_identifier: str): super().__init__( "ExportTaskNotFoundFault", f"Cannot cancel export task because a task with the identifier {export_task_identifier} is not exist.", @@ -165,7 +165,7 @@ class ExportTaskNotFoundError(RDSClientError): class InvalidExportSourceStateError(RDSClientError): - def __init__(self, status): + def __init__(self, status: str): super().__init__( "InvalidExportSourceStateFault", f"Export source should be 'available' but current status is {status}.", @@ -173,7 +173,7 @@ class InvalidExportSourceStateError(RDSClientError): class SubscriptionAlreadyExistError(RDSClientError): - def __init__(self, subscription_name): + def __init__(self, subscription_name: str): super().__init__( "SubscriptionAlreadyExistFault", f"Subscription {subscription_name} already exists.", @@ -181,7 +181,7 @@ class SubscriptionAlreadyExistError(RDSClientError): class SubscriptionNotFoundError(RDSClientError): - def __init__(self, subscription_name): + def __init__(self, subscription_name: str): super().__init__( "SubscriptionNotFoundFault", f"Subscription {subscription_name} not found." ) diff --git a/moto/rds/models.py b/moto/rds/models.py index c080794cc..bdc823243 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -7,7 +7,7 @@ from collections import defaultdict from jinja2 import Template from re import compile as re_compile from collections import OrderedDict -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Iterable, Tuple, Union from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.ec2.models import ec2_backends @@ -49,7 +49,7 @@ from .utils import ( ) -def find_cluster(cluster_arn): +def find_cluster(cluster_arn: str) -> "Cluster": arn_parts = cluster_arn.split(":") region, account = arn_parts[3], arn_parts[4] return rds_backends[account][region].describe_db_clusters(cluster_arn)[0] @@ -61,9 +61,9 @@ class GlobalCluster(BaseModel): account_id: str, global_cluster_identifier: str, engine: str, - engine_version, - storage_encrypted, - deletion_protection, + engine_version: Optional[str], + storage_encrypted: Optional[str], + deletion_protection: Optional[str], ): self.global_cluster_identifier = global_cluster_identifier self.global_cluster_resource_id = "cluster-" + random.get_random_hex(8) @@ -78,7 +78,7 @@ class GlobalCluster(BaseModel): self.deletion_protection = ( deletion_protection and deletion_protection.lower() == "true" ) - self.members = [] + self.members: List[str] = [] def to_xml(self) -> str: template = Template( @@ -112,15 +112,13 @@ class Cluster: "engine": FilterDef(["engine"], "Engine Names"), } - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.db_name = kwargs.get("db_name") self.db_cluster_identifier = kwargs.get("db_cluster_identifier") self.db_cluster_instance_class = kwargs.get("db_cluster_instance_class") self.deletion_protection = kwargs.get("deletion_protection") self.engine = kwargs.get("engine") - self.engine_version = kwargs.get("engine_version") - if not self.engine_version: - self.engine_version = Cluster.default_engine_version(self.engine) + self.engine_version = kwargs.get("engine_version") or Cluster.default_engine_version(self.engine) # type: ignore self.engine_mode = kwargs.get("engine_mode") or "provisioned" self.iops = kwargs.get("iops") self.kms_key_id = kwargs.get("kms_key_id") @@ -140,14 +138,14 @@ class Cluster: self.allocated_storage = kwargs.get("allocated_storage") if self.allocated_storage is None: self.allocated_storage = Cluster.default_allocated_storage( - engine=self.engine, storage_type=self.storage_type + engine=self.engine, storage_type=self.storage_type # type: ignore ) self.master_username = kwargs.get("master_username") if not self.master_username: raise InvalidParameterValue( "The parameter MasterUsername must be provided and must not be blank." ) - self.master_user_password = kwargs.get("master_user_password") + self.master_user_password = kwargs.get("master_user_password") # type: ignore self.availability_zones = kwargs.get("availability_zones") if not self.availability_zones: @@ -164,13 +162,13 @@ class Cluster: ) self.endpoint = f"{self.db_cluster_identifier}.cluster-{self.url_identifier}.{self.region_name}.rds.amazonaws.com" self.reader_endpoint = f"{self.db_cluster_identifier}.cluster-ro-{self.url_identifier}.{self.region_name}.rds.amazonaws.com" - self.port = kwargs.get("port") + self.port: int = kwargs.get("port") # type: ignore if self.port is None: self.port = Cluster.default_port(self.engine) self.preferred_backup_window = "01:37-02:07" self.preferred_maintenance_window = "wed:02:40-wed:03:10" # This should default to the default security group - self.vpc_security_groups = [] + self.vpc_security_groups: List[str] = [] self.hosted_zone_id = "".join( random.choice(string.ascii_uppercase + string.digits) for _ in range(14) ) @@ -181,7 +179,7 @@ class Cluster: self.enabled_cloudwatch_logs_exports = ( kwargs.get("enable_cloudwatch_logs_exports") or [] ) - self.enable_http_endpoint = kwargs.get("enable_http_endpoint") + self.enable_http_endpoint = kwargs.get("enable_http_endpoint") # type: ignore self.earliest_restorable_time = iso_8601_datetime_with_milliseconds( datetime.datetime.utcnow() ) @@ -197,27 +195,27 @@ class Cluster: "seconds_before_timeout": 300, } self.global_cluster_identifier = kwargs.get("global_cluster_identifier") - self.cluster_members = list() + self.cluster_members: List[str] = list() self.replication_source_identifier = kwargs.get("replication_source_identifier") - self.read_replica_identifiers = list() + self.read_replica_identifiers: List[str] = list() @property - def is_multi_az(self): + def is_multi_az(self) -> bool: return ( len(self.read_replica_identifiers) > 0 or self.replication_source_identifier is not None ) @property - def db_cluster_arn(self): + def db_cluster_arn(self) -> str: return f"arn:aws:rds:{self.region_name}:{self.account_id}:cluster:{self.db_cluster_identifier}" @property - def master_user_password(self): + def master_user_password(self) -> str: return self._master_user_password @master_user_password.setter - def master_user_password(self, val): + def master_user_password(self, val: str) -> None: if not val: raise InvalidParameterValue( "The parameter MasterUserPassword must be provided and must not be blank." @@ -229,11 +227,11 @@ class Cluster: self._master_user_password = val @property - def enable_http_endpoint(self): + def enable_http_endpoint(self) -> bool: return self._enable_http_endpoint @enable_http_endpoint.setter - def enable_http_endpoint(self, val): + def enable_http_endpoint(self, val: Optional[bool]) -> None: # instead of raising an error on aws rds create-db-cluster commands with # incompatible configurations with enable_http_endpoint # (e.g. engine_mode is not set to "serverless"), the API @@ -260,13 +258,13 @@ class Cluster: ]: self._enable_http_endpoint = val - def get_cfg(self): + def get_cfg(self) -> Dict[str, Any]: cfg = self.__dict__ cfg["master_user_password"] = cfg.pop("_master_user_password") cfg["enable_http_endpoint"] = cfg.pop("_enable_http_endpoint") return cfg - def to_xml(self): + def to_xml(self) -> str: template = Template( """ {{ cluster.allocated_storage }} @@ -368,7 +366,7 @@ class Cluster: return template.render(cluster=self) @staticmethod - def default_engine_version(engine): + def default_engine_version(engine: str) -> str: return { "aurora": "5.6.mysql_aurora.1.22.5", "aurora-mysql": "5.7.mysql_aurora.2.07.2", @@ -378,7 +376,7 @@ class Cluster: }[engine] @staticmethod - def default_port(engine): + def default_port(engine: str) -> int: return { "aurora": 3306, "aurora-mysql": 3306, @@ -388,14 +386,14 @@ class Cluster: }[engine] @staticmethod - def default_storage_type(iops): + def default_storage_type(iops: Any) -> str: # type: ignore[misc] if iops is None: return "gp2" else: return "io1" @staticmethod - def default_allocated_storage(engine, storage_type): + def default_allocated_storage(engine: str, storage_type: str) -> int: return { "aurora": {"gp2": 0, "io1": 0, "standard": 0}, "aurora-mysql": {"gp2": 20, "io1": 100, "standard": 10}, @@ -404,16 +402,16 @@ class Cluster: "postgres": {"gp2": 20, "io1": 100, "standard": 5}, }[engine][storage_type] - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return self.tags - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> None: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] @@ -431,7 +429,7 @@ class ClusterSnapshot(BaseModel): "engine": FilterDef(["cluster.engine"], "Engine Names"), } - def __init__(self, cluster, snapshot_id, tags): + def __init__(self, cluster: Cluster, snapshot_id: str, tags: List[Dict[str, str]]): self.cluster = cluster self.snapshot_id = snapshot_id self.tags = tags @@ -441,10 +439,10 @@ class ClusterSnapshot(BaseModel): ) @property - def snapshot_arn(self): + def snapshot_arn(self) -> str: return f"arn:aws:rds:{self.cluster.region_name}:{self.cluster.account_id}:cluster-snapshot:{self.snapshot_id}" - def to_xml(self): + def to_xml(self) -> str: template = Template( """ @@ -474,16 +472,16 @@ class ClusterSnapshot(BaseModel): ) return template.render(snapshot=self, cluster=self.cluster) - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return self.tags - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> None: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] @@ -512,12 +510,12 @@ class Database(CloudFormationModel): "postgres": "9.3.3", } - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.status = "available" self.is_replica = False - self.replicas = [] - self.account_id = kwargs.get("account_id") - self.region_name = kwargs.get("region") + self.replicas: List[str] = [] + self.account_id: str = kwargs["account_id"] + self.region_name: str = kwargs["region"] self.engine = kwargs.get("engine") self.engine_version = kwargs.get("engine_version", None) if not self.engine_version and self.engine in self.default_engine_versions: @@ -539,17 +537,17 @@ class Database(CloudFormationModel): self.allocated_storage = kwargs.get("allocated_storage") if self.allocated_storage is None: self.allocated_storage = Database.default_allocated_storage( - engine=self.engine, storage_type=self.storage_type + engine=self.engine, storage_type=self.storage_type # type: ignore ) - self.db_cluster_identifier = kwargs.get("db_cluster_identifier") + self.db_cluster_identifier: Optional[str] = kwargs.get("db_cluster_identifier") self.db_instance_identifier = kwargs.get("db_instance_identifier") - self.source_db_identifier = kwargs.get( + self.source_db_identifier: Optional[str] = kwargs.get( "source_db_ide.db_cluster_identifierntifier" ) self.db_instance_class = kwargs.get("db_instance_class") self.port = kwargs.get("port") if self.port is None: - self.port = Database.default_port(self.engine) + self.port = Database.default_port(self.engine) # type: ignore self.db_instance_identifier = kwargs.get("db_instance_identifier") self.db_name = kwargs.get("db_name") self.instance_create_time = iso_8601_datetime_with_milliseconds( @@ -623,14 +621,14 @@ class Database(CloudFormationModel): ) @property - def db_instance_arn(self): + def db_instance_arn(self) -> str: return f"arn:aws:rds:{self.region_name}:{self.account_id}:db:{self.db_instance_identifier}" @property - def physical_resource_id(self): + def physical_resource_id(self) -> Optional[str]: return self.db_instance_identifier - def db_parameter_groups(self): + def db_parameter_groups(self) -> List["DBParameterGroup"]: if not self.db_parameter_group_name or self.is_default_parameter_group( self.db_parameter_group_name ): @@ -645,7 +643,7 @@ class Database(CloudFormationModel): name=db_parameter_group_name, family=db_family, description=description, - tags={}, + tags=[], region=self.region_name, ) ] @@ -656,19 +654,19 @@ class Database(CloudFormationModel): return [backend.db_parameter_groups[self.db_parameter_group_name]] - def is_default_parameter_group(self, param_group_name): - return param_group_name.startswith(f"default.{self.engine.lower()}") + def is_default_parameter_group(self, param_group_name: str) -> bool: + return param_group_name.startswith(f"default.{self.engine.lower()}") # type: ignore - def default_db_parameter_group_details(self): + def default_db_parameter_group_details(self) -> Tuple[Optional[str], Optional[str]]: if not self.engine_version: return (None, None) minor_engine_version = ".".join(str(self.engine_version).rsplit(".")[:-1]) - db_family = f"{self.engine.lower()}{minor_engine_version}" + db_family = f"{self.engine.lower()}{minor_engine_version}" # type: ignore return db_family, f"default.{db_family}" - def to_xml(self): + def to_xml(self) -> str: template = Template( """ {{ database.availability_zone }} @@ -795,33 +793,33 @@ class Database(CloudFormationModel): return template.render(database=self) @property - def address(self): + def address(self) -> str: return f"{self.db_instance_identifier}.aaaaaaaaaa.{self.region_name}.rds.amazonaws.com" - def add_replica(self, replica): + def add_replica(self, replica: "Database") -> None: if self.region_name != replica.region_name: # Cross Region replica self.replicas.append(replica.db_instance_arn) else: - self.replicas.append(replica.db_instance_identifier) + self.replicas.append(replica.db_instance_identifier) # type: ignore - def remove_replica(self, replica): - self.replicas.remove(replica.db_instance_identifier) + def remove_replica(self, replica: "Database") -> None: + self.replicas.remove(replica.db_instance_identifier) # type: ignore - def set_as_replica(self): + def set_as_replica(self) -> None: self.is_replica = True self.replicas = [] - def update(self, db_kwargs): + def update(self, db_kwargs: Dict[str, Any]) -> None: for key, value in db_kwargs.items(): if value is not None: setattr(self, key, value) @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Endpoint.Address", "Endpoint.Port"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> Any: # Local import to avoid circular dependency with cloudformation.parsing from moto.cloudformation.exceptions import UnformattedGetAttTemplateException @@ -832,7 +830,7 @@ class Database(CloudFormationModel): raise UnformattedGetAttTemplateException() @staticmethod - def default_port(engine): + def default_port(engine: str) -> int: return { "aurora": 3306, "aurora-mysql": 3306, @@ -851,14 +849,14 @@ class Database(CloudFormationModel): }[engine] @staticmethod - def default_storage_type(iops): + def default_storage_type(iops: Any) -> str: # type: ignore[misc] if iops is None: return "gp2" else: return "io1" @staticmethod - def default_allocated_storage(engine, storage_type): + def default_allocated_storage(engine: str, storage_type: str) -> int: return { "aurora": {"gp2": 0, "io1": 0, "standard": 0}, "aurora-mysql": {"gp2": 20, "io1": 100, "standard": 10}, @@ -877,18 +875,23 @@ class Database(CloudFormationModel): }[engine][storage_type] @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "DBInstanceIdentifier" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-rds-dbinstance.html return "AWS::RDS::DBInstance" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Database": properties = cloudformation_json["Properties"] db_security_groups = properties.get("DBSecurityGroups") @@ -942,7 +945,7 @@ class Database(CloudFormationModel): database = rds_backend.create_db_instance(db_kwargs) return database - def to_json(self): + def to_json(self) -> str: template = Template( """{ "AllocatedStorage": 10, @@ -1018,19 +1021,19 @@ class Database(CloudFormationModel): ) return template.render(database=self) - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return self.tags - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> None: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] - def delete(self, account_id, region_name): + def delete(self, account_id: str, region_name: str) -> None: backend = rds_backends[account_id][region_name] backend.delete_db_instance(self.db_instance_identifier) @@ -1048,7 +1051,9 @@ class DatabaseSnapshot(BaseModel): "engine": FilterDef(["database.engine"], "Engine Names"), } - def __init__(self, database, snapshot_id, tags): + def __init__( + self, database: Database, snapshot_id: str, tags: List[Dict[str, str]] + ): self.database = database self.snapshot_id = snapshot_id self.tags = tags @@ -1056,10 +1061,10 @@ class DatabaseSnapshot(BaseModel): self.created_at = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) @property - def snapshot_arn(self): + def snapshot_arn(self) -> str: return f"arn:aws:rds:{self.database.region_name}:{self.database.account_id}:snapshot:{self.snapshot_id}" - def to_xml(self): + def to_xml(self) -> str: template = Template( """ {{ snapshot.snapshot_id }} @@ -1102,21 +1107,23 @@ class DatabaseSnapshot(BaseModel): ) return template.render(snapshot=self, database=self.database) - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return self.tags - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> None: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] class ExportTask(BaseModel): - def __init__(self, snapshot, kwargs): + def __init__( + self, snapshot: Union[DatabaseSnapshot, ClusterSnapshot], kwargs: Dict[str, Any] + ): self.snapshot = snapshot self.export_task_identifier = kwargs.get("export_task_identifier") @@ -1130,7 +1137,7 @@ class ExportTask(BaseModel): self.status = "complete" self.created_at = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) - def to_xml(self): + def to_xml(self) -> str: template = Template( """ {{ task.export_task_identifier }} @@ -1160,7 +1167,7 @@ class ExportTask(BaseModel): class EventSubscription(BaseModel): - def __init__(self, kwargs): + def __init__(self, kwargs: Dict[str, Any]): self.subscription_name = kwargs.get("subscription_name") self.sns_topic_arn = kwargs.get("sns_topic_arn") self.source_type = kwargs.get("source_type") @@ -1175,10 +1182,10 @@ class EventSubscription(BaseModel): self.created_at = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) @property - def es_arn(self): + def es_arn(self) -> str: return f"arn:aws:rds:{self.region_name}:{self.customer_aws_id}:es:{self.subscription_name}" - def to_xml(self): + def to_xml(self) -> str: template = Template( """ @@ -1210,31 +1217,37 @@ class EventSubscription(BaseModel): ) return template.render(subscription=self) - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return self.tags - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> None: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] class SecurityGroup(CloudFormationModel): - def __init__(self, account_id, group_name, description, tags): + def __init__( + self, + account_id: str, + group_name: str, + description: str, + tags: List[Dict[str, str]], + ): self.group_name = group_name self.description = description self.status = "authorized" - self.ip_ranges = [] - self.ec2_security_groups = [] + self.ip_ranges: List[Any] = [] + self.ec2_security_groups: List[Any] = [] self.tags = tags self.owner_id = account_id self.vpc_id = None - def to_xml(self): + def to_xml(self) -> str: template = Template( """ @@ -1263,7 +1276,7 @@ class SecurityGroup(CloudFormationModel): ) return template.render(security_group=self) - def to_json(self): + def to_json(self) -> str: template = Template( """{ "DBSecurityGroupDescription": "{{ security_group.description }}", @@ -1280,25 +1293,30 @@ class SecurityGroup(CloudFormationModel): ) return template.render(security_group=self) - def authorize_cidr(self, cidr_ip): + def authorize_cidr(self, cidr_ip: str) -> None: self.ip_ranges.append(cidr_ip) - def authorize_security_group(self, security_group): + def authorize_security_group(self, security_group: str) -> None: self.ec2_security_groups.append(security_group) @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-rds-dbsecuritygroup.html return "AWS::RDS::DBSecurityGroup" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "SecurityGroup": properties = cloudformation_json["Properties"] group_name = resource_name.lower() description = properties["GroupDescription"] @@ -1322,25 +1340,33 @@ class SecurityGroup(CloudFormationModel): security_group.authorize_security_group(subnet) return security_group - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return self.tags - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> None: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] - def delete(self, account_id, region_name): + def delete(self, account_id: str, region_name: str) -> None: backend = rds_backends[account_id][region_name] backend.delete_security_group(self.group_name) class SubnetGroup(CloudFormationModel): - def __init__(self, subnet_name, description, subnets, tags, region, account_id): + def __init__( + self, + subnet_name: str, + description: str, + subnets: List[Any], + tags: List[Dict[str, str]], + region: str, + account_id: str, + ): self.subnet_name = subnet_name self.description = description self.subnets = subnets @@ -1351,10 +1377,10 @@ class SubnetGroup(CloudFormationModel): self.account_id = account_id @property - def sg_arn(self): + def sg_arn(self) -> str: return f"arn:aws:rds:{self.region}:{self.account_id}:subgrp:{self.subnet_name}" - def to_xml(self): + def to_xml(self) -> str: template = Template( """ {{ subnet_group.vpc_id }} @@ -1378,7 +1404,7 @@ class SubnetGroup(CloudFormationModel): ) return template.render(subnet_group=self) - def to_json(self): + def to_json(self) -> str: template = Template( """"DBSubnetGroup": { "VpcId": "{{ subnet_group.vpc_id }}", @@ -1402,18 +1428,23 @@ class SubnetGroup(CloudFormationModel): return template.render(subnet_group=self) @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "DBSubnetGroupName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-rds-dbsubnetgroup.html return "AWS::RDS::DBSubnetGroup" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "SubnetGroup": properties = cloudformation_json["Properties"] description = properties["DBSubnetGroupDescription"] @@ -1431,44 +1462,44 @@ class SubnetGroup(CloudFormationModel): ) return subnet_group - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return self.tags - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> None: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] - def delete(self, account_id, region_name): + def delete(self, account_id: str, region_name: str) -> None: backend = rds_backends[account_id][region_name] backend.delete_subnet_group(self.subnet_name) class RDSBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self.arn_regex = re_compile( r"^arn:aws:rds:.*:[0-9]*:(db|cluster|es|og|pg|ri|secgrp|snapshot|cluster-snapshot|subgrp):.*$" ) self.clusters: Dict[str, Cluster] = OrderedDict() - self.global_clusters = OrderedDict() - self.databases = OrderedDict() - self.database_snapshots = OrderedDict() - self.cluster_snapshots = OrderedDict() - self.export_tasks = OrderedDict() - self.event_subscriptions = OrderedDict() - self.db_parameter_groups = {} - self.db_cluster_parameter_groups = {} - self.option_groups = {} - self.security_groups = {} - self.subnet_groups = {} - self._db_cluster_options = None + self.global_clusters: Dict[str, GlobalCluster] = OrderedDict() + self.databases: Dict[str, Database] = OrderedDict() + self.database_snapshots: Dict[str, DatabaseSnapshot] = OrderedDict() + self.cluster_snapshots: Dict[str, ClusterSnapshot] = OrderedDict() + self.export_tasks: Dict[str, ExportTask] = OrderedDict() + self.event_subscriptions: Dict[str, EventSubscription] = OrderedDict() + self.db_parameter_groups: Dict[str, DBParameterGroup] = {} + self.db_cluster_parameter_groups: Dict[str, DBClusterParameterGroup] = {} + self.option_groups: Dict[str, OptionGroup] = {} + self.security_groups: Dict[str, SecurityGroup] = {} + self.subnet_groups: Dict[str, SubnetGroup] = {} + self._db_cluster_options: Optional[List[Dict[str, Any]]] = None - def reset(self): + def reset(self) -> None: self.neptune.reset() super().reset() @@ -1477,7 +1508,7 @@ class RDSBackend(BaseBackend): return neptune_backends[self.account_id][self.region_name] @property - def db_cluster_options(self): + def db_cluster_options(self) -> List[Dict[str, Any]]: # type: ignore if self._db_cluster_options is None: from moto.rds.utils import decode_orderable_db_instance @@ -1490,7 +1521,9 @@ class RDSBackend(BaseBackend): return self._db_cluster_options @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "rds" @@ -1498,7 +1531,7 @@ class RDSBackend(BaseBackend): service_region, zones, "rds-data" ) - def create_db_instance(self, db_kwargs): + def create_db_instance(self, db_kwargs: Dict[str, Any]) -> Database: database_id = db_kwargs["db_instance_identifier"] database = Database(**db_kwargs) @@ -1515,8 +1548,11 @@ class RDSBackend(BaseBackend): return database def create_db_snapshot( - self, db_instance_identifier, db_snapshot_identifier, tags=None - ): + self, + db_instance_identifier: str, + db_snapshot_identifier: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> DatabaseSnapshot: database = self.databases.get(db_instance_identifier) if not database: raise DBInstanceNotFoundError(db_instance_identifier) @@ -1535,8 +1571,11 @@ class RDSBackend(BaseBackend): return snapshot def copy_database_snapshot( - self, source_snapshot_identifier, target_snapshot_identifier, tags=None - ): + self, + source_snapshot_identifier: str, + target_snapshot_identifier: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> DatabaseSnapshot: if source_snapshot_identifier not in self.database_snapshots: raise DBSnapshotNotFoundError(source_snapshot_identifier) if target_snapshot_identifier in self.database_snapshots: @@ -1558,22 +1597,22 @@ class RDSBackend(BaseBackend): return target_snapshot - def delete_db_snapshot(self, db_snapshot_identifier): + def delete_db_snapshot(self, db_snapshot_identifier: str) -> DatabaseSnapshot: if db_snapshot_identifier not in self.database_snapshots: raise DBSnapshotNotFoundError(db_snapshot_identifier) return self.database_snapshots.pop(db_snapshot_identifier) - def promote_read_replica(self, db_kwargs): + def promote_read_replica(self, db_kwargs: Dict[str, Any]) -> Database: database_id = db_kwargs["db_instance_identifier"] - database = self.databases.get(database_id) + database = self.databases[database_id] if database.is_replica: database.is_replica = False database.update(db_kwargs) return database - def create_db_instance_read_replica(self, db_kwargs): + def create_db_instance_read_replica(self, db_kwargs: Dict[str, Any]) -> Database: database_id = db_kwargs["db_instance_identifier"] source_database_id = db_kwargs["source_db_identifier"] primary = self.find_db_from_id(source_database_id) @@ -1589,7 +1628,9 @@ class RDSBackend(BaseBackend): primary.add_replica(replica) return replica - def describe_db_instances(self, db_instance_identifier=None, filters=None): + def describe_db_instances( + self, db_instance_identifier: Optional[str] = None, filters: Any = None + ) -> List[Database]: databases = self.databases if db_instance_identifier: filters = merge_filters( @@ -1602,8 +1643,11 @@ class RDSBackend(BaseBackend): return list(databases.values()) def describe_db_snapshots( - self, db_instance_identifier, db_snapshot_identifier, filters=None - ): + self, + db_instance_identifier: Optional[str], + db_snapshot_identifier: str, + filters: Optional[Dict[str, Any]] = None, + ) -> List[DatabaseSnapshot]: snapshots = self.database_snapshots if db_instance_identifier: filters = merge_filters( @@ -1619,7 +1663,9 @@ class RDSBackend(BaseBackend): raise DBSnapshotNotFoundError(db_snapshot_identifier) return list(snapshots.values()) - def modify_db_instance(self, db_instance_identifier, db_kwargs): + def modify_db_instance( + self, db_instance_identifier: str, db_kwargs: Dict[str, Any] + ) -> Database: database = self.describe_db_instances(db_instance_identifier)[0] if "new_db_instance_identifier" in db_kwargs: del self.databases[db_instance_identifier] @@ -1630,18 +1676,19 @@ class RDSBackend(BaseBackend): preferred_backup_window = db_kwargs.get("preferred_backup_window") preferred_maintenance_window = db_kwargs.get("preferred_maintenance_window") msg = valid_preferred_maintenance_window( - preferred_maintenance_window, preferred_backup_window + preferred_maintenance_window, preferred_backup_window # type: ignore ) if msg: raise RDSClientError("InvalidParameterValue", msg) database.update(db_kwargs) return database - def reboot_db_instance(self, db_instance_identifier): - database = self.describe_db_instances(db_instance_identifier)[0] - return database + def reboot_db_instance(self, db_instance_identifier: str) -> Database: + return self.describe_db_instances(db_instance_identifier)[0] - def restore_db_instance_from_db_snapshot(self, from_snapshot_id, overrides): + def restore_db_instance_from_db_snapshot( + self, from_snapshot_id: str, overrides: Dict[str, Any] + ) -> Database: snapshot = self.describe_db_snapshots( db_instance_identifier=None, db_snapshot_identifier=from_snapshot_id )[0] @@ -1658,12 +1705,14 @@ class RDSBackend(BaseBackend): return self.create_db_instance(new_instance_props) - def stop_db_instance(self, db_instance_identifier, db_snapshot_identifier=None): + def stop_db_instance( + self, db_instance_identifier: str, db_snapshot_identifier: Optional[str] = None + ) -> Database: database = self.describe_db_instances(db_instance_identifier)[0] # todo: certain rds types not allowed to be stopped at this time. # https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_StopInstance.html#USER_StopInstance.Limitations if database.is_replica or ( - database.multi_az and database.engine.lower().startswith("sqlserver") + database.multi_az and database.engine.lower().startswith("sqlserver") # type: ignore ): # todo: more db types not supported by stop/start instance api raise InvalidDBClusterStateFaultError(db_instance_identifier) @@ -1674,7 +1723,7 @@ class RDSBackend(BaseBackend): database.status = "stopped" return database - def start_db_instance(self, db_instance_identifier): + def start_db_instance(self, db_instance_identifier: str) -> Database: database = self.describe_db_instances(db_instance_identifier)[0] # todo: bunch of different error messages to be generated from this api call if database.status != "stopped": @@ -1682,7 +1731,7 @@ class RDSBackend(BaseBackend): database.status = "available" return database - def find_db_from_id(self, db_id): + def find_db_from_id(self, db_id: str) -> Database: if self.arn_regex.match(db_id): arn_breakdown = db_id.split(":") region = arn_breakdown[3] @@ -1694,7 +1743,9 @@ class RDSBackend(BaseBackend): return backend.describe_db_instances(db_name)[0] - def delete_db_instance(self, db_instance_identifier, db_snapshot_name=None): + def delete_db_instance( + self, db_instance_identifier: str, db_snapshot_name: Optional[str] = None + ) -> Database: if db_instance_identifier in self.databases: if self.databases[db_instance_identifier].deletion_protection: raise InvalidParameterValue( @@ -1704,7 +1755,7 @@ class RDSBackend(BaseBackend): self.create_db_snapshot(db_instance_identifier, db_snapshot_name) database = self.databases.pop(db_instance_identifier) if database.is_replica: - primary = self.find_db_from_id(database.source_db_identifier) + primary = self.find_db_from_id(database.source_db_identifier) # type: ignore primary.remove_replica(database) if database.db_cluster_identifier in self.clusters: self.clusters[database.db_cluster_identifier].cluster_members.remove( @@ -1715,50 +1766,58 @@ class RDSBackend(BaseBackend): else: raise DBInstanceNotFoundError(db_instance_identifier) - def create_db_security_group(self, group_name, description, tags): + def create_db_security_group( + self, group_name: str, description: str, tags: List[Dict[str, str]] + ) -> SecurityGroup: security_group = SecurityGroup(self.account_id, group_name, description, tags) self.security_groups[group_name] = security_group return security_group - def describe_security_groups(self, security_group_name): + def describe_security_groups(self, security_group_name: str) -> List[SecurityGroup]: if security_group_name: if security_group_name in self.security_groups: return [self.security_groups[security_group_name]] else: raise DBSecurityGroupNotFoundError(security_group_name) - return self.security_groups.values() + return list(self.security_groups.values()) - def delete_security_group(self, security_group_name): + def delete_security_group(self, security_group_name: str) -> SecurityGroup: if security_group_name in self.security_groups: return self.security_groups.pop(security_group_name) else: raise DBSecurityGroupNotFoundError(security_group_name) - def delete_db_parameter_group(self, db_parameter_group_name): + def delete_db_parameter_group( + self, db_parameter_group_name: str + ) -> "DBParameterGroup": if db_parameter_group_name in self.db_parameter_groups: return self.db_parameter_groups.pop(db_parameter_group_name) else: raise DBParameterGroupNotFoundError(db_parameter_group_name) - def authorize_security_group(self, security_group_name, cidr_ip): + def authorize_security_group( + self, security_group_name: str, cidr_ip: str + ) -> SecurityGroup: security_group = self.describe_security_groups(security_group_name)[0] security_group.authorize_cidr(cidr_ip) return security_group def create_subnet_group( self, - subnet_name, - description, - subnets, - tags, - ): + subnet_name: str, + description: str, + subnets: List[Any], + tags: List[Dict[str, str]], + ) -> SubnetGroup: subnet_group = SubnetGroup( subnet_name, description, subnets, tags, self.region_name, self.account_id ) self.subnet_groups[subnet_name] = subnet_group return subnet_group - def describe_db_subnet_groups(self, subnet_group_name): + def describe_db_subnet_groups( + self, subnet_group_name: str + ) -> Iterable[SubnetGroup]: if subnet_group_name: if subnet_group_name in self.subnet_groups: return [self.subnet_groups[subnet_group_name]] @@ -1766,7 +1825,9 @@ class RDSBackend(BaseBackend): raise DBSubnetGroupNotFoundError(subnet_group_name) return self.subnet_groups.values() - def modify_db_subnet_group(self, subnet_name, description, subnets): + def modify_db_subnet_group( + self, subnet_name: str, description: str, subnets: List[str] + ) -> SubnetGroup: subnet_group = self.subnet_groups.pop(subnet_name) if not subnet_group: raise DBSubnetGroupNotFoundError(subnet_name) @@ -1776,13 +1837,13 @@ class RDSBackend(BaseBackend): subnet_group.description = description return subnet_group - def delete_subnet_group(self, subnet_name): + def delete_subnet_group(self, subnet_name: str) -> SubnetGroup: if subnet_name in self.subnet_groups: return self.subnet_groups.pop(subnet_name) else: raise DBSubnetGroupNotFoundError(subnet_name) - def create_option_group(self, option_group_kwargs): + def create_option_group(self, option_group_kwargs: Dict[str, Any]) -> "OptionGroup": option_group_id = option_group_kwargs["name"] # This list was verified against the AWS Console on 14 Dec 2022 # Having an automated way (using the CLI) would be nice, but AFAICS that's not possible @@ -1826,7 +1887,7 @@ class RDSBackend(BaseBackend): "InvalidParameterValue", "Invalid DB engine: non-existent" ) if ( - option_group_kwargs["major_engine_version"] + option_group_kwargs["major_engine_version"] # type: ignore not in valid_option_group_engines[option_group_kwargs["engine_name"]] ): raise RDSClientError( @@ -1853,13 +1914,15 @@ class RDSBackend(BaseBackend): self.option_groups[option_group_id] = option_group return option_group - def delete_option_group(self, option_group_name): + def delete_option_group(self, option_group_name: str) -> "OptionGroup": if option_group_name in self.option_groups: return self.option_groups.pop(option_group_name) else: raise OptionGroupNotFoundFaultError(option_group_name) - def describe_option_groups(self, option_group_kwargs): + def describe_option_groups( + self, option_group_kwargs: Dict[str, Any] + ) -> List["OptionGroup"]: option_group_list = [] if option_group_kwargs["marker"]: @@ -1903,7 +1966,9 @@ class RDSBackend(BaseBackend): return option_group_list[marker : max_records + marker] @staticmethod - def describe_option_group_options(engine_name, major_engine_version=None): + def describe_option_group_options( + engine_name: str, major_engine_version: Optional[str] = None + ) -> str: default_option_group_options = { "mysql": { "5.6": '\n \n \n \n 5.611211TrueInnodb Memcached for MySQLMEMCACHED1-4294967295STATIC1TrueSpecifies how many memcached read operations (get) to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_R_BATCH_SIZE1-4294967295STATIC1TrueSpecifies how many memcached write operations, such as add, set, or incr, to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_W_BATCH_SIZE1-1073741824DYNAMIC5TrueSpecifies how often to auto-commit idle connections that use the InnoDB memcached interface.INNODB_API_BK_COMMIT_INTERVAL0,1STATIC0TrueDisables the use of row locks when using the InnoDB memcached interface.INNODB_API_DISABLE_ROWLOCK0,1STATIC0TrueLocks the table used by the InnoDB memcached plugin, so that it cannot be dropped or altered by DDL through the SQL interface.INNODB_API_ENABLE_MDL0-3STATIC0TrueLets you control the transaction isolation level on queries processed by the memcached interface.INNODB_API_TRX_LEVELauto,ascii,binarySTATICautoTrueThe binding protocol to use which can be either auto, ascii, or binary. The default is auto which means the server automatically negotiates the protocol with the client.BINDING_PROTOCOL1-2048STATIC1024TrueThe backlog queue configures how many network connections can be waiting to be processed by memcachedBACKLOG_QUEUE_LIMIT0,1STATIC0TrueDisable the use of compare and swap (CAS) which reduces the per-item size by 8 bytes.CAS_DISABLED1-48STATIC48TrueMinimum chunk size in bytes to allocate for the smallest item\'s key, value, and flags. The default is 48 and you can get a significant memory efficiency gain with a lower value.CHUNK_SIZE1-2STATIC1.25TrueChunk size growth factor that controls the size of each successive chunk with each chunk growing times this amount larger than the previous chunk.CHUNK_SIZE_GROWTH_FACTOR0,1STATIC0TrueIf enabled when there is no more memory to store items, memcached will return an error rather than evicting items.ERROR_ON_MEMORY_EXHAUSTED10-1024STATIC1024TrueMaximum number of concurrent connections. Setting this value to anything less than 10 prevents MySQL from starting.MAX_SIMULTANEOUS_CONNECTIONSv,vv,vvvSTATICvTrueVerbose level for memcached.VERBOSITYmysql\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', @@ -1945,8 +2010,11 @@ class RDSBackend(BaseBackend): return default_option_group_options[engine_name]["all"] def modify_option_group( - self, option_group_name, options_to_include=None, options_to_remove=None - ): + self, + option_group_name: str, + options_to_include: Optional[List[Dict[str, Any]]] = None, + options_to_remove: Optional[List[Dict[str, Any]]] = None, + ) -> "OptionGroup": if option_group_name not in self.option_groups: raise OptionGroupNotFoundFaultError(option_group_name) if not options_to_include and not options_to_remove: @@ -1960,7 +2028,9 @@ class RDSBackend(BaseBackend): self.option_groups[option_group_name].add_options(options_to_include) return self.option_groups[option_group_name] - def create_db_parameter_group(self, db_parameter_group_kwargs): + def create_db_parameter_group( + self, db_parameter_group_kwargs: Dict[str, Any] + ) -> "DBParameterGroup": db_parameter_group_id = db_parameter_group_kwargs["name"] if db_parameter_group_id in self.db_parameter_groups: raise RDSClientError( @@ -1983,7 +2053,9 @@ class RDSBackend(BaseBackend): self.db_parameter_groups[db_parameter_group_id] = db_parameter_group return db_parameter_group - def describe_db_parameter_groups(self, db_parameter_group_kwargs): + def describe_db_parameter_groups( + self, db_parameter_group_kwargs: Dict[str, Any] + ) -> List["DBParameterGroup"]: db_parameter_group_list = [] if db_parameter_group_kwargs.get("marker"): @@ -2014,8 +2086,10 @@ class RDSBackend(BaseBackend): return db_parameter_group_list[marker : max_records + marker] def modify_db_parameter_group( - self, db_parameter_group_name, db_parameter_group_parameters - ): + self, + db_parameter_group_name: str, + db_parameter_group_parameters: Iterable[Dict[str, Any]], + ) -> "DBParameterGroup": if db_parameter_group_name not in self.db_parameter_groups: raise DBParameterGroupNotFoundError(db_parameter_group_name) @@ -2024,16 +2098,19 @@ class RDSBackend(BaseBackend): return db_parameter_group - def describe_db_cluster_parameters(self): + def describe_db_cluster_parameters(self) -> List[Dict[str, Any]]: return [] - def create_db_cluster(self, kwargs): + def create_db_cluster(self, kwargs: Dict[str, Any]) -> Cluster: cluster_id = kwargs["db_cluster_identifier"] kwargs["account_id"] = self.account_id cluster = Cluster(**kwargs) self.clusters[cluster_id] = cluster - if (cluster.global_cluster_identifier or "") in self.global_clusters: + if ( + cluster.global_cluster_identifier + and cluster.global_cluster_identifier in self.global_clusters + ): global_cluster = self.global_clusters[cluster.global_cluster_identifier] global_cluster.members.append(cluster.db_cluster_arn) @@ -2046,11 +2123,11 @@ class RDSBackend(BaseBackend): cluster.status = "available" # Already set the final status in the background return initial_state - def modify_db_cluster(self, kwargs): + def modify_db_cluster(self, kwargs: Dict[str, Any]) -> Cluster: cluster_id = kwargs["db_cluster_identifier"] if cluster_id in self.neptune.clusters: - return self.neptune.modify_db_cluster(kwargs) + return self.neptune.modify_db_cluster(kwargs) # type: ignore cluster = self.clusters[cluster_id] del self.clusters[cluster_id] @@ -2076,14 +2153,17 @@ class RDSBackend(BaseBackend): def promote_read_replica_db_cluster(self, db_cluster_identifier: str) -> Cluster: cluster = self.clusters[db_cluster_identifier] - source_cluster = find_cluster(cluster.replication_source_identifier) + source_cluster = find_cluster(cluster.replication_source_identifier) # type: ignore source_cluster.read_replica_identifiers.remove(cluster.db_cluster_arn) cluster.replication_source_identifier = None return cluster def create_db_cluster_snapshot( - self, db_cluster_identifier, db_snapshot_identifier, tags=None - ): + self, + db_cluster_identifier: str, + db_snapshot_identifier: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> ClusterSnapshot: cluster = self.clusters.get(db_cluster_identifier) if cluster is None: raise DBClusterNotFoundError(db_cluster_identifier) @@ -2102,8 +2182,11 @@ class RDSBackend(BaseBackend): return snapshot def copy_cluster_snapshot( - self, source_snapshot_identifier, target_snapshot_identifier, tags=None - ): + self, + source_snapshot_identifier: str, + target_snapshot_identifier: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> ClusterSnapshot: if source_snapshot_identifier not in self.cluster_snapshots: raise DBClusterSnapshotNotFoundError(source_snapshot_identifier) if target_snapshot_identifier in self.cluster_snapshots: @@ -2116,21 +2199,23 @@ class RDSBackend(BaseBackend): if tags is None: tags = source_snapshot.tags else: - tags = self._merge_tags(source_snapshot.tags, tags) + tags = self._merge_tags(source_snapshot.tags, tags) # type: ignore target_snapshot = ClusterSnapshot( source_snapshot.cluster, target_snapshot_identifier, tags ) self.cluster_snapshots[target_snapshot_identifier] = target_snapshot return target_snapshot - def delete_db_cluster_snapshot(self, db_snapshot_identifier): + def delete_db_cluster_snapshot( + self, db_snapshot_identifier: str + ) -> ClusterSnapshot: if db_snapshot_identifier not in self.cluster_snapshots: raise DBClusterSnapshotNotFoundError(db_snapshot_identifier) return self.cluster_snapshots.pop(db_snapshot_identifier) def describe_db_clusters( - self, cluster_identifier=None, filters=None + self, cluster_identifier: Optional[str] = None, filters: Any = None ) -> List[Cluster]: clusters = self.clusters clusters_neptune = self.neptune.clusters @@ -2143,11 +2228,14 @@ class RDSBackend(BaseBackend): ) if cluster_identifier and not (clusters or clusters_neptune): raise DBClusterNotFoundError(cluster_identifier) - return list(clusters.values()) + list(clusters_neptune.values()) + return list(clusters.values()) + list(clusters_neptune.values()) # type: ignore def describe_db_cluster_snapshots( - self, db_cluster_identifier, db_snapshot_identifier, filters=None - ): + self, + db_cluster_identifier: Optional[str], + db_snapshot_identifier: str, + filters: Any = None, + ) -> List[ClusterSnapshot]: snapshots = self.cluster_snapshots if db_cluster_identifier: filters = merge_filters(filters, {"db-cluster-id": [db_cluster_identifier]}) @@ -2161,7 +2249,9 @@ class RDSBackend(BaseBackend): raise DBClusterSnapshotNotFoundError(db_snapshot_identifier) return list(snapshots.values()) - def delete_db_cluster(self, cluster_identifier, snapshot_name=None): + def delete_db_cluster( + self, cluster_identifier: str, snapshot_name: Optional[str] = None + ) -> Cluster: if cluster_identifier in self.clusters: cluster = self.clusters[cluster_identifier] if cluster.deletion_protection: @@ -2177,12 +2267,12 @@ class RDSBackend(BaseBackend): self.create_db_cluster_snapshot(cluster_identifier, snapshot_name) return self.clusters.pop(cluster_identifier) if cluster_identifier in self.neptune.clusters: - return self.neptune.delete_db_cluster(cluster_identifier) + return self.neptune.delete_db_cluster(cluster_identifier) # type: ignore raise DBClusterNotFoundError(cluster_identifier) - def start_db_cluster(self, cluster_identifier): + def start_db_cluster(self, cluster_identifier: str) -> Cluster: if cluster_identifier not in self.clusters: - return self.neptune.start_db_cluster(cluster_identifier) + return self.neptune.start_db_cluster(cluster_identifier) # type: ignore raise DBClusterNotFoundError(cluster_identifier) cluster = self.clusters[cluster_identifier] if cluster.status != "stopped": @@ -2194,7 +2284,9 @@ class RDSBackend(BaseBackend): cluster.status = "available" # This is the final status - already setting it in the background return temp_state - def restore_db_cluster_from_snapshot(self, from_snapshot_id, overrides): + def restore_db_cluster_from_snapshot( + self, from_snapshot_id: str, overrides: Dict[str, Any] + ) -> Cluster: snapshot = self.describe_db_cluster_snapshots( db_cluster_identifier=None, db_snapshot_identifier=from_snapshot_id )[0] @@ -2206,7 +2298,7 @@ class RDSBackend(BaseBackend): return self.create_db_cluster(new_cluster_props) - def stop_db_cluster(self, cluster_identifier): + def stop_db_cluster(self, cluster_identifier: str) -> "Cluster": if cluster_identifier not in self.clusters: raise DBClusterNotFoundError(cluster_identifier) cluster = self.clusters[cluster_identifier] @@ -2218,7 +2310,7 @@ class RDSBackend(BaseBackend): cluster.status = "stopped" return previous_state - def start_export_task(self, kwargs): + def start_export_task(self, kwargs: Dict[str, Any]) -> ExportTask: export_task_id = kwargs["export_task_identifier"] source_arn = kwargs["source_arn"] snapshot_id = source_arn.split(":")[-1] @@ -2235,7 +2327,9 @@ class RDSBackend(BaseBackend): raise DBClusterSnapshotNotFoundError(snapshot_id) if snapshot_type == "snapshot": - snapshot = self.database_snapshots[snapshot_id] + snapshot: Union[ + DatabaseSnapshot, ClusterSnapshot + ] = self.database_snapshots[snapshot_id] else: snapshot = self.cluster_snapshots[snapshot_id] @@ -2247,7 +2341,7 @@ class RDSBackend(BaseBackend): return export_task - def cancel_export_task(self, export_task_identifier): + def cancel_export_task(self, export_task_identifier: str) -> ExportTask: if export_task_identifier in self.export_tasks: export_task = self.export_tasks[export_task_identifier] export_task.status = "canceled" @@ -2255,7 +2349,9 @@ class RDSBackend(BaseBackend): return export_task raise ExportTaskNotFoundError(export_task_identifier) - def describe_export_tasks(self, export_task_identifier): + def describe_export_tasks( + self, export_task_identifier: str + ) -> Iterable[ExportTask]: if export_task_identifier: if export_task_identifier in self.export_tasks: return [self.export_tasks[export_task_identifier]] @@ -2263,7 +2359,7 @@ class RDSBackend(BaseBackend): raise ExportTaskNotFoundError(export_task_identifier) return self.export_tasks.values() - def create_event_subscription(self, kwargs): + def create_event_subscription(self, kwargs: Any) -> EventSubscription: subscription_name = kwargs["subscription_name"] if subscription_name in self.event_subscriptions: @@ -2275,12 +2371,14 @@ class RDSBackend(BaseBackend): return subscription - def delete_event_subscription(self, subscription_name): + def delete_event_subscription(self, subscription_name: str) -> EventSubscription: if subscription_name in self.event_subscriptions: return self.event_subscriptions.pop(subscription_name) raise SubscriptionNotFoundError(subscription_name) - def describe_event_subscriptions(self, subscription_name): + def describe_event_subscriptions( + self, subscription_name: str + ) -> Iterable[EventSubscription]: if subscription_name: if subscription_name in self.event_subscriptions: return [self.event_subscriptions[subscription_name]] @@ -2288,7 +2386,7 @@ class RDSBackend(BaseBackend): raise SubscriptionNotFoundError(subscription_name) return self.event_subscriptions.values() - def list_tags_for_resource(self, arn): + def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]: if self.arn_regex.match(arn): arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] @@ -2332,48 +2430,48 @@ class RDSBackend(BaseBackend): ) return [] - def remove_tags_from_resource(self, arn, tag_keys): + def remove_tags_from_resource(self, arn: str, tag_keys: List[str]) -> None: if self.arn_regex.match(arn): arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] if resource_type == "db": # Database if resource_name in self.databases: - return self.databases[resource_name].remove_tags(tag_keys) + self.databases[resource_name].remove_tags(tag_keys) elif resource_type == "es": # Event Subscription if resource_name in self.event_subscriptions: - return self.event_subscriptions[resource_name].remove_tags(tag_keys) + self.event_subscriptions[resource_name].remove_tags(tag_keys) elif resource_type == "og": # Option Group if resource_name in self.option_groups: - return self.option_groups[resource_name].remove_tags(tag_keys) + self.option_groups[resource_name].remove_tags(tag_keys) elif resource_type == "pg": # Parameter Group if resource_name in self.db_parameter_groups: - return self.db_parameter_groups[resource_name].remove_tags(tag_keys) + self.db_parameter_groups[resource_name].remove_tags(tag_keys) elif resource_type == "ri": # Reserved DB instance return None elif resource_type == "secgrp": # DB security group if resource_name in self.security_groups: - return self.security_groups[resource_name].remove_tags(tag_keys) + self.security_groups[resource_name].remove_tags(tag_keys) elif resource_type == "snapshot": # DB Snapshot if resource_name in self.database_snapshots: - return self.database_snapshots[resource_name].remove_tags(tag_keys) + self.database_snapshots[resource_name].remove_tags(tag_keys) elif resource_type == "cluster": if resource_name in self.clusters: - return self.clusters[resource_name].remove_tags(tag_keys) + self.clusters[resource_name].remove_tags(tag_keys) if resource_name in self.neptune.clusters: - return self.neptune.clusters[resource_name].remove_tags(tag_keys) + self.neptune.clusters[resource_name].remove_tags(tag_keys) elif resource_type == "cluster-snapshot": # DB Cluster Snapshot if resource_name in self.cluster_snapshots: - return self.cluster_snapshots[resource_name].remove_tags(tag_keys) + self.cluster_snapshots[resource_name].remove_tags(tag_keys) elif resource_type == "subgrp": # DB subnet group if resource_name in self.subnet_groups: - return self.subnet_groups[resource_name].remove_tags(tag_keys) + self.subnet_groups[resource_name].remove_tags(tag_keys) else: raise RDSClientError( "InvalidParameterValue", f"Invalid resource name: {arn}" ) - def add_tags_to_resource(self, arn, tags): + def add_tags_to_resource(self, arn: str, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: # type: ignore[return] if self.arn_regex.match(arn): arn_breakdown = arn.split(":") resource_type = arn_breakdown[-2] @@ -2415,7 +2513,7 @@ class RDSBackend(BaseBackend): ) @staticmethod - def _filter_resources(resources, filters, resource_class): + def _filter_resources(resources: Any, filters: Any, resource_class: Any) -> Any: # type: ignore[misc] try: filter_defs = resource_class.SUPPORTED_FILTERS validate_filters(filters, filter_defs) @@ -2427,13 +2525,15 @@ class RDSBackend(BaseBackend): raise InvalidParameterCombination(str(e)) @staticmethod - def _merge_tags(old_tags: list, new_tags: list): + def _merge_tags(old_tags: List[Dict[str, Any]], new_tags: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # type: ignore[misc] tags_dict = dict() tags_dict.update({d["Key"]: d["Value"] for d in old_tags}) tags_dict.update({d["Key"]: d["Value"] for d in new_tags}) return [{"Key": k, "Value": v} for k, v in tags_dict.items()] - def describe_orderable_db_instance_options(self, engine, engine_version): + def describe_orderable_db_instance_options( + self, engine: str, engine_version: str + ) -> List[Dict[str, Any]]: """ Only the Aurora-Postgresql and Neptune-engine is currently implemented """ @@ -2451,10 +2551,10 @@ class RDSBackend(BaseBackend): def create_db_cluster_parameter_group( self, - group_name, - family, - description, - ): + group_name: str, + family: str, + description: str, + ) -> "DBClusterParameterGroup": group = DBClusterParameterGroup( account_id=self.account_id, region=self.region_name, @@ -2465,14 +2565,16 @@ class RDSBackend(BaseBackend): self.db_cluster_parameter_groups[group_name] = group return group - def describe_db_cluster_parameter_groups(self, group_name): + def describe_db_cluster_parameter_groups( + self, group_name: str + ) -> List["DBClusterParameterGroup"]: if group_name is not None: if group_name not in self.db_cluster_parameter_groups: raise DBClusterParameterGroupNotFoundError(group_name) return [self.db_cluster_parameter_groups[group_name]] return list(self.db_cluster_parameter_groups.values()) - def delete_db_cluster_parameter_group(self, group_name): + def delete_db_cluster_parameter_group(self, group_name: str) -> None: self.db_cluster_parameter_groups.pop(group_name) def create_global_cluster( @@ -2481,8 +2583,8 @@ class RDSBackend(BaseBackend): source_db_cluster_identifier: Optional[str], engine: Optional[str], engine_version: Optional[str], - storage_encrypted: Optional[bool], - deletion_protection: Optional[bool], + storage_encrypted: Optional[str], + deletion_protection: Optional[str], ) -> GlobalCluster: source_cluster = None if source_db_cluster_identifier is not None: @@ -2506,7 +2608,7 @@ class RDSBackend(BaseBackend): global_cluster = GlobalCluster( account_id=self.account_id, global_cluster_identifier=global_cluster_identifier, - engine=engine, + engine=engine, # type: ignore engine_version=engine_version, storage_encrypted=storage_encrypted, deletion_protection=deletion_protection, @@ -2517,15 +2619,15 @@ class RDSBackend(BaseBackend): global_cluster.members.append(source_cluster.db_cluster_arn) return global_cluster - def describe_global_clusters(self): + def describe_global_clusters(self) -> List[GlobalCluster]: return ( list(self.global_clusters.values()) - + self.neptune.describe_global_clusters() + + self.neptune.describe_global_clusters() # type: ignore ) def delete_global_cluster(self, global_cluster_identifier: str) -> GlobalCluster: try: - return self.neptune.delete_global_cluster(global_cluster_identifier) + return self.neptune.delete_global_cluster(global_cluster_identifier) # type: ignore except: # noqa: E722 Do not use bare except pass # It's not a Neptune Global Cluster - assume it's an RDS cluster instead global_cluster = self.global_clusters[global_cluster_identifier] @@ -2535,7 +2637,7 @@ class RDSBackend(BaseBackend): def remove_from_global_cluster( self, global_cluster_identifier: str, db_cluster_identifier: str - ) -> GlobalCluster: + ) -> Optional[GlobalCluster]: try: global_cluster = self.global_clusters[global_cluster_identifier] cluster = self.describe_db_clusters( @@ -2548,27 +2650,27 @@ class RDSBackend(BaseBackend): return None -class OptionGroup(object): +class OptionGroup: def __init__( self, - name, - engine_name, - major_engine_version, - region, - account_id, - description=None, + name: str, + engine_name: str, + major_engine_version: str, + region: str, + account_id: str, + description: Optional[str] = None, ): self.engine_name = engine_name self.major_engine_version = major_engine_version self.description = description self.name = name self.vpc_and_non_vpc_instance_memberships = False - self.options = {} + self.options: Dict[str, Any] = {} self.vpcId = "null" - self.tags = [] + self.tags: List[Dict[str, str]] = [] self.arn = f"arn:aws:rds:{region}:{account_id}:og:{name}" - def to_json(self): + def to_json(self) -> str: template = Template( """{ "VpcId": null, @@ -2583,7 +2685,7 @@ class OptionGroup(object): ) return template.render(option_group=self) - def to_xml(self): + def to_xml(self) -> str: template = Template( """ {{ option_group.name }} @@ -2597,39 +2699,51 @@ class OptionGroup(object): ) return template.render(option_group=self) - def remove_options(self, options_to_remove): # pylint: disable=unused-argument + def remove_options( + self, options_to_remove: Any # pylint: disable=unused-argument + ) -> None: # TODO: Check for option in self.options and remove if exists. Raise # error otherwise return - def add_options(self, options_to_add): # pylint: disable=unused-argument + def add_options( + self, options_to_add: Any # pylint: disable=unused-argument + ) -> None: # TODO: Validate option and add it to self.options. If invalid raise # error return - def get_tags(self): + def get_tags(self) -> List[Dict[str, str]]: return self.tags - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> None: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] class DBParameterGroup(CloudFormationModel): - def __init__(self, account_id, name, description, family, tags, region): + def __init__( + self, + account_id: str, + name: Optional[str], + description: str, + family: Optional[str], + tags: List[Dict[str, str]], + region: str, + ): self.name = name self.description = description self.family = family self.tags = tags - self.parameters = defaultdict(dict) + self.parameters: Dict[str, Any] = defaultdict(dict) self.arn = f"arn:aws:rds:{region}:{account_id}:pg:{name}" - def to_xml(self): + def to_xml(self) -> str: template = Template( """ {{ param_group.name }} @@ -2640,40 +2754,45 @@ class DBParameterGroup(CloudFormationModel): ) return template.render(param_group=self) - def get_tags(self): + def get_tags(self) -> List[Dict[str, Any]]: return self.tags - def add_tags(self, tags): + def add_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, Any]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def remove_tags(self, tag_keys): + def remove_tags(self, tag_keys: List[str]) -> None: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] - def update_parameters(self, new_parameters): + def update_parameters(self, new_parameters: Iterable[Dict[str, Any]]) -> None: for new_parameter in new_parameters: parameter = self.parameters[new_parameter["ParameterName"]] parameter.update(new_parameter) - def delete(self, account_id, region_name): + def delete(self, account_id: str, region_name: str) -> None: backend = rds_backends[account_id][region_name] backend.delete_db_parameter_group(self.name) @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-rds-dbparametergroup.html return "AWS::RDS::DBParameterGroup" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "DBParameterGroup": properties = cloudformation_json["Properties"] db_parameter_group_kwargs = { @@ -2699,14 +2818,16 @@ class DBParameterGroup(CloudFormationModel): class DBClusterParameterGroup(CloudFormationModel): - def __init__(self, account_id, region, name, description, family): + def __init__( + self, account_id: str, region: str, name: str, description: str, family: str + ): self.name = name self.description = description self.family = family - self.parameters = defaultdict(dict) + self.parameters: Dict[str, Any] = defaultdict(dict) self.arn = f"arn:aws:rds:{region}:{account_id}:cpg:{name}" - def to_xml(self): + def to_xml(self) -> str: template = Template( """ {{ param_group.name }} diff --git a/moto/rds/responses.py b/moto/rds/responses.py index dc8fcee45..c22d63748 100644 --- a/moto/rds/responses.py +++ b/moto/rds/responses.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any +from typing import Any, Dict, List, Iterable from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse @@ -16,7 +16,7 @@ from .exceptions import DBParameterGroupNotFoundError class RDSResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="rds") # Neptune and RDS share a HTTP endpoint RDS is the lucky guy that catches all requests # So we have to determine whether we can handle an incoming request here, or whether it needs redirecting to Neptune @@ -31,13 +31,13 @@ class RDSResponse(BaseResponse): self.neptune.setup_class(request, full_url, headers) return super()._dispatch(request, full_url, headers) - def __getattribute__(self, name: str): + def __getattribute__(self, name: str) -> Any: if name in ["create_db_cluster", "create_global_cluster"]: if self._get_param("Engine") == "neptune": return object.__getattribute__(self.neptune, name) return object.__getattribute__(self, name) - def _get_db_kwargs(self): + def _get_db_kwargs(self) -> Dict[str, Any]: args = { "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), "allocated_storage": self._get_int_param("AllocatedStorage"), @@ -90,7 +90,7 @@ class RDSResponse(BaseResponse): args["tags"] = self.unpack_list_params("Tags", "Tag") return args - def _get_modify_db_cluster_kwargs(self): + def _get_modify_db_cluster_kwargs(self) -> Dict[str, Any]: args = { "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), "allocated_storage": self._get_int_param("AllocatedStorage"), @@ -146,7 +146,7 @@ class RDSResponse(BaseResponse): args["tags"] = self.unpack_list_params("Tags", "Tag") return args - def _get_db_replica_kwargs(self): + def _get_db_replica_kwargs(self) -> Dict[str, Any]: return { "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), "availability_zone": self._get_param("AvailabilityZone"), @@ -161,7 +161,7 @@ class RDSResponse(BaseResponse): "storage_type": self._get_param("StorageType"), } - def _get_option_group_kwargs(self): + def _get_option_group_kwargs(self) -> Dict[str, Any]: return { "major_engine_version": self._get_param("MajorEngineVersion"), "description": self._get_param("OptionGroupDescription"), @@ -169,7 +169,7 @@ class RDSResponse(BaseResponse): "name": self._get_param("OptionGroupName"), } - def _get_db_parameter_group_kwargs(self): + def _get_db_parameter_group_kwargs(self) -> Dict[str, Any]: return { "description": self._get_param("Description"), "family": self._get_param("DBParameterGroupFamily"), @@ -177,7 +177,7 @@ class RDSResponse(BaseResponse): "tags": self.unpack_list_params("Tags", "Tag"), } - def _get_db_cluster_kwargs(self): + def _get_db_cluster_kwargs(self) -> Dict[str, Any]: return { "availability_zones": self._get_multi_param( "AvailabilityZones.AvailabilityZone" @@ -213,7 +213,7 @@ class RDSResponse(BaseResponse): ), } - def _get_export_task_kwargs(self): + def _get_export_task_kwargs(self) -> Dict[str, Any]: return { "export_task_identifier": self._get_param("ExportTaskIdentifier"), "source_arn": self._get_param("SourceArn"), @@ -224,7 +224,7 @@ class RDSResponse(BaseResponse): "export_only": self.unpack_list_params("ExportOnly", "member"), } - def _get_event_subscription_kwargs(self): + def _get_event_subscription_kwargs(self) -> Dict[str, Any]: return { "subscription_name": self._get_param("SubscriptionName"), "sns_topic_arn": self._get_param("SnsTopicArn"), @@ -237,29 +237,31 @@ class RDSResponse(BaseResponse): "tags": self.unpack_list_params("Tags", "Tag"), } - def unpack_list_params(self, label, child_label): + def unpack_list_params(self, label: str, child_label: str) -> List[Dict[str, Any]]: root = self._get_multi_param_dict(label) or {} return root.get(child_label, []) - def create_db_instance(self): + def create_db_instance(self) -> str: db_kwargs = self._get_db_kwargs() database = self.backend.create_db_instance(db_kwargs) template = self.response_template(CREATE_DATABASE_TEMPLATE) return template.render(database=database) - def create_db_instance_read_replica(self): + def create_db_instance_read_replica(self) -> str: db_kwargs = self._get_db_replica_kwargs() database = self.backend.create_db_instance_read_replica(db_kwargs) template = self.response_template(CREATE_DATABASE_REPLICA_TEMPLATE) return template.render(database=database) - def describe_db_instances(self): + def describe_db_instances(self) -> str: db_instance_identifier = self._get_param("DBInstanceIdentifier") filters = self._get_multi_param("Filters.Filter.") - filters = {f["Name"]: f["Values"] for f in filters} + filter_dict = {f["Name"]: f["Values"] for f in filters} all_instances = list( - self.backend.describe_db_instances(db_instance_identifier, filters=filters) + self.backend.describe_db_instances( + db_instance_identifier, filters=filter_dict + ) ) marker = self._get_param("Marker") all_ids = [instance.db_instance_identifier for instance in all_instances] @@ -278,7 +280,7 @@ class RDSResponse(BaseResponse): template = self.response_template(DESCRIBE_DATABASES_TEMPLATE) return template.render(databases=instances_resp, marker=next_marker) - def modify_db_instance(self): + def modify_db_instance(self) -> str: db_instance_identifier = self._get_param("DBInstanceIdentifier") db_kwargs = self._get_db_kwargs() # NOTE modify_db_instance does not support tags @@ -290,7 +292,7 @@ class RDSResponse(BaseResponse): template = self.response_template(MODIFY_DATABASE_TEMPLATE) return template.render(database=database) - def delete_db_instance(self): + def delete_db_instance(self) -> str: db_instance_identifier = self._get_param("DBInstanceIdentifier") db_snapshot_name = self._get_param("FinalDBSnapshotIdentifier") database = self.backend.delete_db_instance( @@ -299,13 +301,13 @@ class RDSResponse(BaseResponse): template = self.response_template(DELETE_DATABASE_TEMPLATE) return template.render(database=database) - def reboot_db_instance(self): + def reboot_db_instance(self) -> str: db_instance_identifier = self._get_param("DBInstanceIdentifier") database = self.backend.reboot_db_instance(db_instance_identifier) template = self.response_template(REBOOT_DATABASE_TEMPLATE) return template.render(database=database) - def create_db_snapshot(self): + def create_db_snapshot(self) -> str: db_instance_identifier = self._get_param("DBInstanceIdentifier") db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") tags = self.unpack_list_params("Tags", "Tag") @@ -315,7 +317,7 @@ class RDSResponse(BaseResponse): template = self.response_template(CREATE_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) - def copy_db_snapshot(self): + def copy_db_snapshot(self) -> str: source_snapshot_identifier = self._get_param("SourceDBSnapshotIdentifier") target_snapshot_identifier = self._get_param("TargetDBSnapshotIdentifier") tags = self.unpack_list_params("Tags", "Tag") @@ -325,18 +327,18 @@ class RDSResponse(BaseResponse): template = self.response_template(COPY_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) - def describe_db_snapshots(self): + def describe_db_snapshots(self) -> str: db_instance_identifier = self._get_param("DBInstanceIdentifier") db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") filters = self._get_multi_param("Filters.Filter.") - filters = {f["Name"]: f["Values"] for f in filters} + filter_dict = {f["Name"]: f["Values"] for f in filters} snapshots = self.backend.describe_db_snapshots( - db_instance_identifier, db_snapshot_identifier, filters + db_instance_identifier, db_snapshot_identifier, filter_dict ) template = self.response_template(DESCRIBE_SNAPSHOTS_TEMPLATE) return template.render(snapshots=snapshots) - def promote_read_replica(self): + def promote_read_replica(self) -> str: db_instance_identifier = self._get_param("DBInstanceIdentifier") db_kwargs = self._get_db_kwargs() database = self.backend.promote_read_replica(db_kwargs) @@ -344,13 +346,13 @@ class RDSResponse(BaseResponse): template = self.response_template(PROMOTE_REPLICA_TEMPLATE) return template.render(database=database) - def delete_db_snapshot(self): + def delete_db_snapshot(self) -> str: db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") snapshot = self.backend.delete_db_snapshot(db_snapshot_identifier) template = self.response_template(DELETE_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) - def restore_db_instance_from_db_snapshot(self): + def restore_db_instance_from_db_snapshot(self) -> str: db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") db_kwargs = self._get_db_kwargs() new_instance = self.backend.restore_db_instance_from_db_snapshot( @@ -359,27 +361,27 @@ class RDSResponse(BaseResponse): template = self.response_template(RESTORE_INSTANCE_FROM_SNAPSHOT_TEMPLATE) return template.render(database=new_instance) - def list_tags_for_resource(self): + def list_tags_for_resource(self) -> str: arn = self._get_param("ResourceName") template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE) tags = self.backend.list_tags_for_resource(arn) return template.render(tags=tags) - def add_tags_to_resource(self): + def add_tags_to_resource(self) -> str: arn = self._get_param("ResourceName") tags = self.unpack_list_params("Tags", "Tag") tags = self.backend.add_tags_to_resource(arn, tags) template = self.response_template(ADD_TAGS_TO_RESOURCE_TEMPLATE) return template.render(tags=tags) - def remove_tags_from_resource(self): + def remove_tags_from_resource(self) -> str: arn = self._get_param("ResourceName") tag_keys = self.unpack_list_params("TagKeys", "member") - self.backend.remove_tags_from_resource(arn, tag_keys) + self.backend.remove_tags_from_resource(arn, tag_keys) # type: ignore template = self.response_template(REMOVE_TAGS_FROM_RESOURCE_TEMPLATE) return template.render() - def stop_db_instance(self): + def stop_db_instance(self) -> str: db_instance_identifier = self._get_param("DBInstanceIdentifier") db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") database = self.backend.stop_db_instance( @@ -388,13 +390,13 @@ class RDSResponse(BaseResponse): template = self.response_template(STOP_DATABASE_TEMPLATE) return template.render(database=database) - def start_db_instance(self): + def start_db_instance(self) -> str: db_instance_identifier = self._get_param("DBInstanceIdentifier") database = self.backend.start_db_instance(db_instance_identifier) template = self.response_template(START_DATABASE_TEMPLATE) return template.render(database=database) - def create_db_security_group(self): + def create_db_security_group(self) -> str: group_name = self._get_param("DBSecurityGroupName") description = self._get_param("DBSecurityGroupDescription") tags = self.unpack_list_params("Tags", "Tag") @@ -404,19 +406,19 @@ class RDSResponse(BaseResponse): template = self.response_template(CREATE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) - def describe_db_security_groups(self): + def describe_db_security_groups(self) -> str: security_group_name = self._get_param("DBSecurityGroupName") security_groups = self.backend.describe_security_groups(security_group_name) template = self.response_template(DESCRIBE_SECURITY_GROUPS_TEMPLATE) return template.render(security_groups=security_groups) - def delete_db_security_group(self): + def delete_db_security_group(self) -> str: security_group_name = self._get_param("DBSecurityGroupName") security_group = self.backend.delete_security_group(security_group_name) template = self.response_template(DELETE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) - def authorize_db_security_group_ingress(self): + def authorize_db_security_group_ingress(self) -> str: security_group_name = self._get_param("DBSecurityGroupName") cidr_ip = self._get_param("CIDRIP") security_group = self.backend.authorize_security_group( @@ -425,7 +427,7 @@ class RDSResponse(BaseResponse): template = self.response_template(AUTHORIZE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) - def create_db_subnet_group(self): + def create_db_subnet_group(self) -> str: subnet_name = self._get_param("DBSubnetGroupName") description = self._get_param("DBSubnetGroupDescription") subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier") @@ -440,13 +442,13 @@ class RDSResponse(BaseResponse): template = self.response_template(CREATE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) - def describe_db_subnet_groups(self): + def describe_db_subnet_groups(self) -> str: subnet_name = self._get_param("DBSubnetGroupName") subnet_groups = self.backend.describe_db_subnet_groups(subnet_name) template = self.response_template(DESCRIBE_SUBNET_GROUPS_TEMPLATE) return template.render(subnet_groups=subnet_groups) - def modify_db_subnet_group(self): + def modify_db_subnet_group(self) -> str: subnet_name = self._get_param("DBSubnetGroupName") description = self._get_param("DBSubnetGroupDescription") subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier") @@ -460,25 +462,25 @@ class RDSResponse(BaseResponse): template = self.response_template(MODIFY_SUBNET_GROUPS_TEMPLATE) return template.render(subnet_group=subnet_group) - def delete_db_subnet_group(self): + def delete_db_subnet_group(self) -> str: subnet_name = self._get_param("DBSubnetGroupName") subnet_group = self.backend.delete_subnet_group(subnet_name) template = self.response_template(DELETE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) - def create_option_group(self): + def create_option_group(self) -> str: kwargs = self._get_option_group_kwargs() option_group = self.backend.create_option_group(kwargs) template = self.response_template(CREATE_OPTION_GROUP_TEMPLATE) return template.render(option_group=option_group) - def delete_option_group(self): + def delete_option_group(self) -> str: kwargs = self._get_option_group_kwargs() option_group = self.backend.delete_option_group(kwargs["name"]) template = self.response_template(DELETE_OPTION_GROUP_TEMPLATE) return template.render(option_group=option_group) - def describe_option_groups(self): + def describe_option_groups(self) -> str: kwargs = self._get_option_group_kwargs() kwargs["max_records"] = self._get_int_param("MaxRecords") kwargs["marker"] = self._get_param("Marker") @@ -486,15 +488,14 @@ class RDSResponse(BaseResponse): template = self.response_template(DESCRIBE_OPTION_GROUP_TEMPLATE) return template.render(option_groups=option_groups) - def describe_option_group_options(self): + def describe_option_group_options(self) -> str: engine_name = self._get_param("EngineName") major_engine_version = self._get_param("MajorEngineVersion") - option_group_options = self.backend.describe_option_group_options( + return self.backend.describe_option_group_options( engine_name, major_engine_version ) - return option_group_options - def modify_option_group(self): + def modify_option_group(self) -> str: option_group_name = self._get_param("OptionGroupName") count = 1 options_to_include = [] @@ -530,13 +531,13 @@ class RDSResponse(BaseResponse): template = self.response_template(MODIFY_OPTION_GROUP_TEMPLATE) return template.render(option_group=option_group) - def create_db_parameter_group(self): + def create_db_parameter_group(self) -> str: kwargs = self._get_db_parameter_group_kwargs() db_parameter_group = self.backend.create_db_parameter_group(kwargs) template = self.response_template(CREATE_DB_PARAMETER_GROUP_TEMPLATE) return template.render(db_parameter_group=db_parameter_group) - def describe_db_parameter_groups(self): + def describe_db_parameter_groups(self) -> str: kwargs = self._get_db_parameter_group_kwargs() kwargs["max_records"] = self._get_int_param("MaxRecords") kwargs["marker"] = self._get_param("Marker") @@ -544,7 +545,7 @@ class RDSResponse(BaseResponse): template = self.response_template(DESCRIBE_DB_PARAMETER_GROUPS_TEMPLATE) return template.render(db_parameter_groups=db_parameter_groups) - def modify_db_parameter_group(self): + def modify_db_parameter_group(self) -> str: db_parameter_group_name = self._get_param("DBParameterGroupName") db_parameter_group_parameters = self._get_db_parameter_group_parameters() db_parameter_group = self.backend.modify_db_parameter_group( @@ -553,8 +554,8 @@ class RDSResponse(BaseResponse): template = self.response_template(MODIFY_DB_PARAMETER_GROUP_TEMPLATE) return template.render(db_parameter_group=db_parameter_group) - def _get_db_parameter_group_parameters(self): - parameter_group_parameters = defaultdict(dict) + def _get_db_parameter_group_parameters(self) -> Iterable[Dict[str, Any]]: + parameter_group_parameters: Dict[str, Any] = defaultdict(dict) for param_name, value in self.querystring.items(): if not param_name.startswith("Parameters.Parameter"): continue @@ -567,7 +568,7 @@ class RDSResponse(BaseResponse): return parameter_group_parameters.values() - def describe_db_parameters(self): + def describe_db_parameters(self) -> str: db_parameter_group_name = self._get_param("DBParameterGroupName") db_parameter_groups = self.backend.describe_db_parameter_groups( {"name": db_parameter_group_name} @@ -578,13 +579,13 @@ class RDSResponse(BaseResponse): template = self.response_template(DESCRIBE_DB_PARAMETERS_TEMPLATE) return template.render(db_parameter_group=db_parameter_groups[0]) - def delete_db_parameter_group(self): + def delete_db_parameter_group(self) -> str: kwargs = self._get_db_parameter_group_kwargs() db_parameter_group = self.backend.delete_db_parameter_group(kwargs["name"]) template = self.response_template(DELETE_DB_PARAMETER_GROUP_TEMPLATE) return template.render(db_parameter_group=db_parameter_group) - def describe_db_cluster_parameters(self): + def describe_db_cluster_parameters(self) -> str: db_parameter_group_name = self._get_param("DBParameterGroupName") db_parameter_groups = self.backend.describe_db_cluster_parameters() if db_parameter_groups is None: @@ -593,29 +594,29 @@ class RDSResponse(BaseResponse): template = self.response_template(DESCRIBE_DB_CLUSTER_PARAMETERS_TEMPLATE) return template.render(db_parameter_group=db_parameter_groups) - def create_db_cluster(self): + def create_db_cluster(self) -> str: kwargs = self._get_db_cluster_kwargs() cluster = self.backend.create_db_cluster(kwargs) template = self.response_template(CREATE_DB_CLUSTER_TEMPLATE) return template.render(cluster=cluster) - def modify_db_cluster(self): + def modify_db_cluster(self) -> str: kwargs = self._get_modify_db_cluster_kwargs() cluster = self.backend.modify_db_cluster(kwargs) template = self.response_template(MODIFY_DB_CLUSTER_TEMPLATE) return template.render(cluster=cluster) - def describe_db_clusters(self): + def describe_db_clusters(self) -> str: _id = self._get_param("DBClusterIdentifier") filters = self._get_multi_param("Filters.Filter.") - filters = {f["Name"]: f["Values"] for f in filters} + filter_dict = {f["Name"]: f["Values"] for f in filters} clusters = self.backend.describe_db_clusters( - cluster_identifier=_id, filters=filters + cluster_identifier=_id, filters=filter_dict ) template = self.response_template(DESCRIBE_CLUSTERS_TEMPLATE) return template.render(clusters=clusters) - def delete_db_cluster(self): + def delete_db_cluster(self) -> str: _id = self._get_param("DBClusterIdentifier") snapshot_name = self._get_param("FinalDBSnapshotIdentifier") cluster = self.backend.delete_db_cluster( @@ -624,19 +625,19 @@ class RDSResponse(BaseResponse): template = self.response_template(DELETE_CLUSTER_TEMPLATE) return template.render(cluster=cluster) - def start_db_cluster(self): + def start_db_cluster(self) -> str: _id = self._get_param("DBClusterIdentifier") cluster = self.backend.start_db_cluster(cluster_identifier=_id) template = self.response_template(START_CLUSTER_TEMPLATE) return template.render(cluster=cluster) - def stop_db_cluster(self): + def stop_db_cluster(self) -> str: _id = self._get_param("DBClusterIdentifier") cluster = self.backend.stop_db_cluster(cluster_identifier=_id) template = self.response_template(STOP_CLUSTER_TEMPLATE) return template.render(cluster=cluster) - def create_db_cluster_snapshot(self): + def create_db_cluster_snapshot(self) -> str: db_cluster_identifier = self._get_param("DBClusterIdentifier") db_snapshot_identifier = self._get_param("DBClusterSnapshotIdentifier") tags = self.unpack_list_params("Tags", "Tag") @@ -646,7 +647,7 @@ class RDSResponse(BaseResponse): template = self.response_template(CREATE_CLUSTER_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) - def copy_db_cluster_snapshot(self): + def copy_db_cluster_snapshot(self) -> str: source_snapshot_identifier = self._get_param( "SourceDBClusterSnapshotIdentifier" ) @@ -660,24 +661,24 @@ class RDSResponse(BaseResponse): template = self.response_template(COPY_CLUSTER_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) - def describe_db_cluster_snapshots(self): + def describe_db_cluster_snapshots(self) -> str: db_cluster_identifier = self._get_param("DBClusterIdentifier") db_snapshot_identifier = self._get_param("DBClusterSnapshotIdentifier") filters = self._get_multi_param("Filters.Filter.") - filters = {f["Name"]: f["Values"] for f in filters} + filter_dict = {f["Name"]: f["Values"] for f in filters} snapshots = self.backend.describe_db_cluster_snapshots( - db_cluster_identifier, db_snapshot_identifier, filters + db_cluster_identifier, db_snapshot_identifier, filter_dict ) template = self.response_template(DESCRIBE_CLUSTER_SNAPSHOTS_TEMPLATE) return template.render(snapshots=snapshots) - def delete_db_cluster_snapshot(self): + def delete_db_cluster_snapshot(self) -> str: db_snapshot_identifier = self._get_param("DBClusterSnapshotIdentifier") snapshot = self.backend.delete_db_cluster_snapshot(db_snapshot_identifier) template = self.response_template(DELETE_CLUSTER_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) - def restore_db_cluster_from_snapshot(self): + def restore_db_cluster_from_snapshot(self) -> str: db_snapshot_identifier = self._get_param("SnapshotIdentifier") db_kwargs = self._get_db_cluster_kwargs() new_cluster = self.backend.restore_db_cluster_from_snapshot( @@ -686,43 +687,43 @@ class RDSResponse(BaseResponse): template = self.response_template(RESTORE_CLUSTER_FROM_SNAPSHOT_TEMPLATE) return template.render(cluster=new_cluster) - def start_export_task(self): + def start_export_task(self) -> str: kwargs = self._get_export_task_kwargs() export_task = self.backend.start_export_task(kwargs) template = self.response_template(START_EXPORT_TASK_TEMPLATE) return template.render(task=export_task) - def cancel_export_task(self): + def cancel_export_task(self) -> str: export_task_identifier = self._get_param("ExportTaskIdentifier") export_task = self.backend.cancel_export_task(export_task_identifier) template = self.response_template(CANCEL_EXPORT_TASK_TEMPLATE) return template.render(task=export_task) - def describe_export_tasks(self): + def describe_export_tasks(self) -> str: export_task_identifier = self._get_param("ExportTaskIdentifier") tasks = self.backend.describe_export_tasks(export_task_identifier) template = self.response_template(DESCRIBE_EXPORT_TASKS_TEMPLATE) return template.render(tasks=tasks) - def create_event_subscription(self): + def create_event_subscription(self) -> str: kwargs = self._get_event_subscription_kwargs() subscription = self.backend.create_event_subscription(kwargs) template = self.response_template(CREATE_EVENT_SUBSCRIPTION_TEMPLATE) return template.render(subscription=subscription) - def delete_event_subscription(self): + def delete_event_subscription(self) -> str: subscription_name = self._get_param("SubscriptionName") subscription = self.backend.delete_event_subscription(subscription_name) template = self.response_template(DELETE_EVENT_SUBSCRIPTION_TEMPLATE) return template.render(subscription=subscription) - def describe_event_subscriptions(self): + def describe_event_subscriptions(self) -> str: subscription_name = self._get_param("SubscriptionName") subscriptions = self.backend.describe_event_subscriptions(subscription_name) template = self.response_template(DESCRIBE_EVENT_SUBSCRIPTIONS_TEMPLATE) return template.render(subscriptions=subscriptions) - def describe_orderable_db_instance_options(self): + def describe_orderable_db_instance_options(self) -> str: engine = self._get_param("Engine") engine_version = self._get_param("EngineVersion") options = self.backend.describe_orderable_db_instance_options( @@ -731,12 +732,12 @@ class RDSResponse(BaseResponse): template = self.response_template(DESCRIBE_ORDERABLE_CLUSTER_OPTIONS) return template.render(options=options, marker=None) - def describe_global_clusters(self): + def describe_global_clusters(self) -> str: clusters = self.backend.describe_global_clusters() template = self.response_template(DESCRIBE_GLOBAL_CLUSTERS_TEMPLATE) return template.render(clusters=clusters) - def create_global_cluster(self): + def create_global_cluster(self) -> str: params = self._get_params() cluster = self.backend.create_global_cluster( global_cluster_identifier=params["GlobalClusterIdentifier"], @@ -749,7 +750,7 @@ class RDSResponse(BaseResponse): template = self.response_template(CREATE_GLOBAL_CLUSTER_TEMPLATE) return template.render(cluster=cluster) - def delete_global_cluster(self): + def delete_global_cluster(self) -> str: params = self._get_params() cluster = self.backend.delete_global_cluster( global_cluster_identifier=params["GlobalClusterIdentifier"], @@ -757,7 +758,7 @@ class RDSResponse(BaseResponse): template = self.response_template(DELETE_GLOBAL_CLUSTER_TEMPLATE) return template.render(cluster=cluster) - def remove_from_global_cluster(self): + def remove_from_global_cluster(self) -> str: params = self._get_params() global_cluster = self.backend.remove_from_global_cluster( global_cluster_identifier=params["GlobalClusterIdentifier"], @@ -766,7 +767,7 @@ class RDSResponse(BaseResponse): template = self.response_template(REMOVE_FROM_GLOBAL_CLUSTER_TEMPLATE) return template.render(cluster=global_cluster) - def create_db_cluster_parameter_group(self): + def create_db_cluster_parameter_group(self) -> str: group_name = self._get_param("DBClusterParameterGroupName") family = self._get_param("DBParameterGroupFamily") desc = self._get_param("Description") @@ -778,7 +779,7 @@ class RDSResponse(BaseResponse): template = self.response_template(CREATE_DB_CLUSTER_PARAMETER_GROUP_TEMPLATE) return template.render(db_cluster_parameter_group=db_cluster_parameter_group) - def describe_db_cluster_parameter_groups(self): + def describe_db_cluster_parameter_groups(self) -> str: group_name = self._get_param("DBClusterParameterGroupName") db_parameter_groups = self.backend.describe_db_cluster_parameter_groups( group_name=group_name, @@ -786,7 +787,7 @@ class RDSResponse(BaseResponse): template = self.response_template(DESCRIBE_DB_CLUSTER_PARAMETER_GROUPS_TEMPLATE) return template.render(db_parameter_groups=db_parameter_groups) - def delete_db_cluster_parameter_group(self): + def delete_db_cluster_parameter_group(self) -> str: group_name = self._get_param("DBClusterParameterGroupName") self.backend.delete_db_cluster_parameter_group( group_name=group_name, @@ -794,7 +795,7 @@ class RDSResponse(BaseResponse): template = self.response_template(DELETE_DB_CLUSTER_PARAMETER_GROUP_TEMPLATE) return template.render() - def promote_read_replica_db_cluster(self): + def promote_read_replica_db_cluster(self) -> str: db_cluster_identifier = self._get_param("DBClusterIdentifier") cluster = self.backend.promote_read_replica_db_cluster(db_cluster_identifier) template = self.response_template(PROMOTE_READ_REPLICA_DB_CLUSTER_TEMPLATE) diff --git a/moto/rds/utils.py b/moto/rds/utils.py index b62a6dd83..75038d2e5 100644 --- a/moto/rds/utils.py +++ b/moto/rds/utils.py @@ -1,6 +1,6 @@ import copy from collections import namedtuple -from typing import Any, Dict +from typing import Any, Dict, Tuple, Optional from botocore.utils import merge_dicts @@ -22,7 +22,7 @@ FilterDef = namedtuple( ) -def get_object_value(obj, attr): +def get_object_value(obj: Any, attr: str) -> Any: """Retrieves an arbitrary attribute value from an object. Nested attributes can be specified using dot notation, @@ -47,7 +47,9 @@ def get_object_value(obj, attr): return val -def merge_filters(filters_to_update, filters_to_merge): +def merge_filters( + filters_to_update: Optional[Dict[str, Any]], filters_to_merge: Dict[str, Any] +) -> Dict[str, Any]: """Given two groups of filters, merge the second into the first. List values are appended instead of overwritten: @@ -76,7 +78,9 @@ def merge_filters(filters_to_update, filters_to_merge): return filters_to_update -def validate_filters(filters, filter_defs): +def validate_filters( + filters: Dict[str, Any], filter_defs: Dict[str, FilterDef] +) -> None: """Validates filters against a set of filter definitions. Raises standard Python exceptions which should be caught @@ -108,7 +112,7 @@ def validate_filters(filters, filter_defs): ) -def apply_filter(resources, filters, filter_defs): +def apply_filter(resources: Any, filters: Any, filter_defs: Any) -> Any: """Apply an arbitrary filter to a group of resources. :param dict[str, object] resources: @@ -140,7 +144,9 @@ def apply_filter(resources, filters, filter_defs): return resources_filtered -def get_start_date_end_date(base_date, window): +def get_start_date_end_date( + base_date: str, window: str +) -> Tuple[datetime.datetime, datetime.datetime]: """Gets the start date and end date given DDD:HH24:MM-DDD:HH24:MM. :param base_date: @@ -162,11 +168,11 @@ def get_start_date_end_date(base_date, window): return start, end -def get_start_date_end_date_from_time(base_date, window): +def get_start_date_end_date_from_time( + base_date: str, window: str +) -> Tuple[datetime.datetime, datetime.datetime, bool]: """Gets the start date and end date given HH24:MM-HH24:MM. - :param base_date: - type datetime :param window: HH24:MM-HH24:MM :returns: @@ -187,31 +193,23 @@ def get_start_date_end_date_from_time(base_date, window): def get_overlap_between_two_date_ranges( - start_time_1, end_time_1, start_time_2, end_time_2 -): - """Determines overlap between 2 date ranges. - - :param start_time_1: - type datetime - :param start_time_2: - type datetime - :param end_time_1: - type datetime - :param end_time_2: - type datetime - :returns: - overlap in seconds - :rtype: - int + start_time_1: datetime.datetime, + end_time_1: datetime.datetime, + start_time_2: datetime.datetime, + end_time_2: datetime.datetime, +) -> int: + """ + Determines overlap between 2 date ranges. Returns the overlap in seconds. """ latest_start = max(start_time_1, start_time_2) earliest_end = min(end_time_1, end_time_2) delta = earliest_end - latest_start - overlap = (delta.days * SECONDS_IN_ONE_DAY) + delta.seconds - return overlap + return (delta.days * SECONDS_IN_ONE_DAY) + delta.seconds -def valid_preferred_maintenance_window(maintenance_window, backup_window): +def valid_preferred_maintenance_window( + maintenance_window: Any, backup_window: Any +) -> Optional[str]: """Determines validity of preferred_maintenance_window :param maintenance_windown: @@ -283,7 +281,7 @@ def valid_preferred_maintenance_window(maintenance_window, backup_window): delta = maintenance_window_end - maintenance_window_start delta_seconds = delta.seconds + (delta.days * SECONDS_IN_ONE_DAY) if delta_seconds >= MINUTES_30 and delta_seconds <= HOURS_24: - return + return None elif delta_seconds >= 0 and delta_seconds <= MINUTES_30: return "The maintenance window must be at least 30 minutes." else: diff --git a/setup.cfg b/setup.cfg index 45de20cf8..0cb467bd5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -235,7 +235,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/ram,moto/rdsdata,moto/scheduler +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/ram,moto/rds,moto/rdsdata,moto/scheduler show_column_numbers=True show_error_codes = True disable_error_code=abstract