Techdebt: MyPy Redshift (#6217)

This commit is contained in:
Bert Blommers 2023-04-17 09:57:06 +00:00 committed by GitHub
parent a2abeb1039
commit f2b6384f28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 302 additions and 244 deletions

View File

@ -1,9 +1,10 @@
import json import json
from typing import List, Optional
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
class RedshiftClientError(JsonRESTError): class RedshiftClientError(JsonRESTError):
def __init__(self, code, message): def __init__(self, code: str, message: str):
super().__init__(error_type=code, message=message) super().__init__(error_type=code, message=message)
self.description = json.dumps( self.description = json.dumps(
{ {
@ -14,19 +15,19 @@ class RedshiftClientError(JsonRESTError):
class ClusterNotFoundError(RedshiftClientError): class ClusterNotFoundError(RedshiftClientError):
def __init__(self, cluster_identifier): def __init__(self, cluster_identifier: str):
super().__init__("ClusterNotFound", f"Cluster {cluster_identifier} not found.") super().__init__("ClusterNotFound", f"Cluster {cluster_identifier} not found.")
class ClusterSubnetGroupNotFoundError(RedshiftClientError): class ClusterSubnetGroupNotFoundError(RedshiftClientError):
def __init__(self, subnet_identifier): def __init__(self, subnet_identifier: str):
super().__init__( super().__init__(
"ClusterSubnetGroupNotFound", f"Subnet group {subnet_identifier} not found." "ClusterSubnetGroupNotFound", f"Subnet group {subnet_identifier} not found."
) )
class ClusterSecurityGroupNotFoundError(RedshiftClientError): class ClusterSecurityGroupNotFoundError(RedshiftClientError):
def __init__(self, group_identifier): def __init__(self, group_identifier: str):
super().__init__( super().__init__(
"ClusterSecurityGroupNotFound", "ClusterSecurityGroupNotFound",
f"Security group {group_identifier} not found.", f"Security group {group_identifier} not found.",
@ -34,7 +35,7 @@ class ClusterSecurityGroupNotFoundError(RedshiftClientError):
class ClusterParameterGroupNotFoundError(RedshiftClientError): class ClusterParameterGroupNotFoundError(RedshiftClientError):
def __init__(self, group_identifier): def __init__(self, group_identifier: str):
super().__init__( super().__init__(
"ClusterParameterGroupNotFound", "ClusterParameterGroupNotFound",
f"Parameter group {group_identifier} not found.", f"Parameter group {group_identifier} not found.",
@ -42,12 +43,12 @@ class ClusterParameterGroupNotFoundError(RedshiftClientError):
class InvalidSubnetError(RedshiftClientError): class InvalidSubnetError(RedshiftClientError):
def __init__(self, subnet_identifier): def __init__(self, subnet_identifier: List[str]):
super().__init__("InvalidSubnet", f"Subnet {subnet_identifier} not found.") super().__init__("InvalidSubnet", f"Subnet {subnet_identifier} not found.")
class SnapshotCopyGrantAlreadyExistsFaultError(RedshiftClientError): class SnapshotCopyGrantAlreadyExistsFaultError(RedshiftClientError):
def __init__(self, snapshot_copy_grant_name): def __init__(self, snapshot_copy_grant_name: str):
super().__init__( super().__init__(
"SnapshotCopyGrantAlreadyExistsFault", "SnapshotCopyGrantAlreadyExistsFault",
"Cannot create the snapshot copy grant because a grant " "Cannot create the snapshot copy grant because a grant "
@ -56,7 +57,7 @@ class SnapshotCopyGrantAlreadyExistsFaultError(RedshiftClientError):
class SnapshotCopyGrantNotFoundFaultError(RedshiftClientError): class SnapshotCopyGrantNotFoundFaultError(RedshiftClientError):
def __init__(self, snapshot_copy_grant_name): def __init__(self, snapshot_copy_grant_name: str):
super().__init__( super().__init__(
"SnapshotCopyGrantNotFoundFault", "SnapshotCopyGrantNotFoundFault",
f"Snapshot copy grant not found: {snapshot_copy_grant_name}", f"Snapshot copy grant not found: {snapshot_copy_grant_name}",
@ -64,14 +65,14 @@ class SnapshotCopyGrantNotFoundFaultError(RedshiftClientError):
class ClusterSnapshotNotFoundError(RedshiftClientError): class ClusterSnapshotNotFoundError(RedshiftClientError):
def __init__(self, snapshot_identifier): def __init__(self, snapshot_identifier: str):
super().__init__( super().__init__(
"ClusterSnapshotNotFound", f"Snapshot {snapshot_identifier} not found." "ClusterSnapshotNotFound", f"Snapshot {snapshot_identifier} not found."
) )
class ClusterSnapshotAlreadyExistsError(RedshiftClientError): class ClusterSnapshotAlreadyExistsError(RedshiftClientError):
def __init__(self, snapshot_identifier): def __init__(self, snapshot_identifier: str):
super().__init__( super().__init__(
"ClusterSnapshotAlreadyExists", "ClusterSnapshotAlreadyExists",
"Cannot create the snapshot because a snapshot with the " "Cannot create the snapshot because a snapshot with the "
@ -80,7 +81,7 @@ class ClusterSnapshotAlreadyExistsError(RedshiftClientError):
class InvalidParameterValueError(RedshiftClientError): class InvalidParameterValueError(RedshiftClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidParameterValue", message) super().__init__("InvalidParameterValue", message)
@ -88,7 +89,12 @@ class ResourceNotFoundFaultError(RedshiftClientError):
code = 404 code = 404
def __init__(self, resource_type=None, resource_name=None, message=None): def __init__(
self,
resource_type: Optional[str] = None,
resource_name: Optional[str] = None,
message: Optional[str] = None,
):
if resource_type and not resource_name: if resource_type and not resource_name:
msg = f"resource of type '{resource_type}' not found." msg = f"resource of type '{resource_type}' not found."
else: else:
@ -99,7 +105,7 @@ class ResourceNotFoundFaultError(RedshiftClientError):
class SnapshotCopyDisabledFaultError(RedshiftClientError): class SnapshotCopyDisabledFaultError(RedshiftClientError):
def __init__(self, cluster_identifier): def __init__(self, cluster_identifier: str):
super().__init__( super().__init__(
"SnapshotCopyDisabledFault", "SnapshotCopyDisabledFault",
f"Cannot modify retention period because snapshot copy is disabled on Cluster {cluster_identifier}.", f"Cannot modify retention period because snapshot copy is disabled on Cluster {cluster_identifier}.",
@ -107,7 +113,7 @@ class SnapshotCopyDisabledFaultError(RedshiftClientError):
class SnapshotCopyAlreadyDisabledFaultError(RedshiftClientError): class SnapshotCopyAlreadyDisabledFaultError(RedshiftClientError):
def __init__(self, cluster_identifier): def __init__(self, cluster_identifier: str):
super().__init__( super().__init__(
"SnapshotCopyAlreadyDisabledFault", "SnapshotCopyAlreadyDisabledFault",
f"Snapshot Copy is already disabled on Cluster {cluster_identifier}.", f"Snapshot Copy is already disabled on Cluster {cluster_identifier}.",
@ -115,7 +121,7 @@ class SnapshotCopyAlreadyDisabledFaultError(RedshiftClientError):
class SnapshotCopyAlreadyEnabledFaultError(RedshiftClientError): class SnapshotCopyAlreadyEnabledFaultError(RedshiftClientError):
def __init__(self, cluster_identifier): def __init__(self, cluster_identifier: str):
super().__init__( super().__init__(
"SnapshotCopyAlreadyEnabledFault", "SnapshotCopyAlreadyEnabledFault",
f"Snapshot Copy is already enabled on Cluster {cluster_identifier}.", f"Snapshot Copy is already enabled on Cluster {cluster_identifier}.",
@ -123,22 +129,22 @@ class SnapshotCopyAlreadyEnabledFaultError(RedshiftClientError):
class ClusterAlreadyExistsFaultError(RedshiftClientError): class ClusterAlreadyExistsFaultError(RedshiftClientError):
def __init__(self): def __init__(self) -> None:
super().__init__("ClusterAlreadyExists", "Cluster already exists") super().__init__("ClusterAlreadyExists", "Cluster already exists")
class InvalidParameterCombinationError(RedshiftClientError): class InvalidParameterCombinationError(RedshiftClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidParameterCombination", message) super().__init__("InvalidParameterCombination", message)
class UnknownSnapshotCopyRegionFaultError(RedshiftClientError): class UnknownSnapshotCopyRegionFaultError(RedshiftClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("UnknownSnapshotCopyRegionFault", message) super().__init__("UnknownSnapshotCopyRegionFault", message)
class ClusterSecurityGroupNotFoundFaultError(RedshiftClientError): class ClusterSecurityGroupNotFoundFaultError(RedshiftClientError):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
"ClusterSecurityGroupNotFoundFault", "ClusterSecurityGroupNotFoundFault",
"The cluster security group name does not refer to an existing cluster security group.", "The cluster security group name does not refer to an existing cluster security group.",
@ -146,7 +152,7 @@ class ClusterSecurityGroupNotFoundFaultError(RedshiftClientError):
class InvalidClusterSnapshotStateFaultError(RedshiftClientError): class InvalidClusterSnapshotStateFaultError(RedshiftClientError):
def __init__(self, snapshot_identifier): def __init__(self, snapshot_identifier: str):
super().__init__( super().__init__(
"InvalidClusterSnapshotStateFault", "InvalidClusterSnapshotStateFault",
f"Cannot delete the snapshot {snapshot_identifier} because only manual snapshots may be deleted", f"Cannot delete the snapshot {snapshot_identifier} because only manual snapshots may be deleted",

View File

@ -2,6 +2,7 @@ import copy
import datetime import datetime
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.ec2 import ec2_backends from moto.ec2 import ec2_backends
@ -29,30 +30,32 @@ from .exceptions import (
) )
class TaggableResourceMixin(object): class TaggableResourceMixin:
resource_type = None resource_type = ""
def __init__(self, account_id, region_name, tags): def __init__(
self, account_id: str, region_name: str, tags: Optional[List[Dict[str, Any]]]
):
self.account_id = account_id self.account_id = account_id
self.region = region_name self.region = region_name
self.tags = tags or [] self.tags = tags or []
@property @property
def resource_id(self): def resource_id(self) -> str:
return None return ""
@property @property
def arn(self): def arn(self) -> str:
return f"arn:aws:redshift:{self.region}:{self.account_id}:{self.resource_type}:{self.resource_id}" return f"arn:aws:redshift:{self.region}:{self.account_id}:{self.resource_type}:{self.resource_id}"
def create_tags(self, tags): def create_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
new_keys = [tag_set["Key"] for tag_set in tags] new_keys = [tag_set["Key"] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
self.tags.extend(tags) self.tags.extend(tags)
return self.tags return self.tags
def delete_tags(self, tag_keys): def delete_tags(self, tag_keys: List[str]) -> List[Dict[str, str]]:
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
return self.tags return self.tags
@ -63,32 +66,32 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
def __init__( def __init__(
self, self,
redshift_backend, redshift_backend: "RedshiftBackend",
cluster_identifier, cluster_identifier: str,
node_type, node_type: str,
master_username, master_username: str,
master_user_password, master_user_password: str,
db_name, db_name: str,
cluster_type, cluster_type: str,
cluster_security_groups, cluster_security_groups: List[str],
vpc_security_group_ids, vpc_security_group_ids: List[str],
cluster_subnet_group_name, cluster_subnet_group_name: str,
availability_zone, availability_zone: str,
preferred_maintenance_window, preferred_maintenance_window: str,
cluster_parameter_group_name, cluster_parameter_group_name: str,
automated_snapshot_retention_period, automated_snapshot_retention_period: str,
port, port: str,
cluster_version, cluster_version: str,
allow_version_upgrade, allow_version_upgrade: str,
number_of_nodes, number_of_nodes: str,
publicly_accessible, publicly_accessible: str,
encrypted, encrypted: str,
region_name, region_name: str,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
iam_roles_arn=None, iam_roles_arn: Optional[List[str]] = None,
enhanced_vpc_routing=None, enhanced_vpc_routing: Optional[str] = None,
restored_from_snapshot=False, restored_from_snapshot: bool = False,
kms_key_id=None, kms_key_id: Optional[str] = None,
): ):
super().__init__(redshift_backend.account_id, region_name, tags) super().__init__(redshift_backend.account_id, region_name, tags)
self.redshift_backend = redshift_backend self.redshift_backend = redshift_backend
@ -152,20 +155,26 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
self.iam_roles_arn = iam_roles_arn or [] self.iam_roles_arn = iam_roles_arn or []
self.restored_from_snapshot = restored_from_snapshot self.restored_from_snapshot = restored_from_snapshot
self.kms_key_id = kms_key_id self.kms_key_id = kms_key_id
self.cluster_snapshot_copy_status: Optional[Dict[str, Any]] = None
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-cluster.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-cluster.html
return "AWS::Redshift::Cluster" return "AWS::Redshift::Cluster"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Cluster":
redshift_backend = redshift_backends[account_id][region_name] redshift_backend = redshift_backends[account_id][region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -205,10 +214,10 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
return cluster return cluster
@classmethod @classmethod
def has_cfn_attr(cls, attr): def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["Endpoint.Address", "Endpoint.Port"] return attr in ["Endpoint.Address", "Endpoint.Port"]
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name: str) -> Any:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Endpoint.Address": if attribute_name == "Endpoint.Address":
@ -218,11 +227,11 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@property @property
def endpoint(self): def endpoint(self) -> str:
return f"{self.cluster_identifier}.cg034hpkmmjt.{self.region}.redshift.amazonaws.com" return f"{self.cluster_identifier}.cg034hpkmmjt.{self.region}.redshift.amazonaws.com"
@property @property
def security_groups(self): def security_groups(self) -> List["SecurityGroup"]:
return [ return [
security_group security_group
for security_group in self.redshift_backend.describe_cluster_security_groups() for security_group in self.redshift_backend.describe_cluster_security_groups()
@ -231,7 +240,7 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
] ]
@property @property
def vpc_security_groups(self): def vpc_security_groups(self) -> List["SecurityGroup"]:
return [ return [
security_group security_group
for security_group in self.redshift_backend.ec2_backend.describe_security_groups() for security_group in self.redshift_backend.ec2_backend.describe_security_groups()
@ -239,7 +248,7 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
] ]
@property @property
def parameter_groups(self): def parameter_groups(self) -> List["ParameterGroup"]:
return [ return [
parameter_group parameter_group
for parameter_group in self.redshift_backend.describe_cluster_parameter_groups() for parameter_group in self.redshift_backend.describe_cluster_parameter_groups()
@ -248,22 +257,22 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
] ]
@property @property
def resource_id(self): def resource_id(self) -> str:
return self.cluster_identifier return self.cluster_identifier
def pause(self): def pause(self) -> None:
self.status = "paused" self.status = "paused"
def resume(self): def resume(self) -> None:
self.status = "available" self.status = "available"
def to_json(self): def to_json(self) -> Dict[str, Any]:
json_response = { json_response = {
"MasterUsername": self.master_username, "MasterUsername": self.master_username,
"MasterUserPassword": "****", "MasterUserPassword": "****",
"ClusterVersion": self.cluster_version, "ClusterVersion": self.cluster_version,
"VpcSecurityGroups": [ "VpcSecurityGroups": [
{"Status": "active", "VpcSecurityGroupId": group.id} {"Status": "active", "VpcSecurityGroupId": group.id} # type: ignore
for group in self.vpc_security_groups for group in self.vpc_security_groups
], ],
"ClusterSubnetGroupName": self.cluster_subnet_group_name, "ClusterSubnetGroupName": self.cluster_subnet_group_name,
@ -313,12 +322,10 @@ class Cluster(TaggableResourceMixin, CloudFormationModel):
"ElapsedTimeInSeconds": 123, "ElapsedTimeInSeconds": 123,
"EstimatedTimeToCompletionInSeconds": 123, "EstimatedTimeToCompletionInSeconds": 123,
} }
try: if self.cluster_snapshot_copy_status is not None:
json_response[ json_response[
"ClusterSnapshotCopyStatus" "ClusterSnapshotCopyStatus"
] = self.cluster_snapshot_copy_status ] = self.cluster_snapshot_copy_status
except AttributeError:
pass
return json_response return json_response
@ -326,11 +333,11 @@ class SnapshotCopyGrant(TaggableResourceMixin, BaseModel):
resource_type = "snapshotcopygrant" resource_type = "snapshotcopygrant"
def __init__(self, snapshot_copy_grant_name, kms_key_id): def __init__(self, snapshot_copy_grant_name: str, kms_key_id: str):
self.snapshot_copy_grant_name = snapshot_copy_grant_name self.snapshot_copy_grant_name = snapshot_copy_grant_name
self.kms_key_id = kms_key_id self.kms_key_id = kms_key_id
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"SnapshotCopyGrantName": self.snapshot_copy_grant_name, "SnapshotCopyGrantName": self.snapshot_copy_grant_name,
"KmsKeyId": self.kms_key_id, "KmsKeyId": self.kms_key_id,
@ -343,12 +350,12 @@ class SubnetGroup(TaggableResourceMixin, CloudFormationModel):
def __init__( def __init__(
self, self,
ec2_backend, ec2_backend: Any,
cluster_subnet_group_name, cluster_subnet_group_name: str,
description, description: str,
subnet_ids, subnet_ids: List[str],
region_name, region_name: str,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
): ):
super().__init__(ec2_backend.account_id, region_name, tags) super().__init__(ec2_backend.account_id, region_name, tags)
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
@ -359,18 +366,23 @@ class SubnetGroup(TaggableResourceMixin, CloudFormationModel):
raise InvalidSubnetError(subnet_ids) raise InvalidSubnetError(subnet_ids)
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clustersubnetgroup.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clustersubnetgroup.html
return "AWS::Redshift::ClusterSubnetGroup" return "AWS::Redshift::ClusterSubnetGroup"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "SubnetGroup":
redshift_backend = redshift_backends[account_id][region_name] redshift_backend = redshift_backends[account_id][region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -383,18 +395,18 @@ class SubnetGroup(TaggableResourceMixin, CloudFormationModel):
return subnet_group return subnet_group
@property @property
def subnets(self): def subnets(self) -> Any: # type: ignore[misc]
return self.ec2_backend.describe_subnets(filters={"subnet-id": self.subnet_ids}) return self.ec2_backend.describe_subnets(filters={"subnet-id": self.subnet_ids})
@property @property
def vpc_id(self): def vpc_id(self) -> str:
return self.subnets[0].vpc_id return self.subnets[0].vpc_id
@property @property
def resource_id(self): def resource_id(self) -> str:
return self.cluster_subnet_group_name return self.cluster_subnet_group_name
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"VpcId": self.vpc_id, "VpcId": self.vpc_id,
"Description": self.description, "Description": self.description,
@ -418,22 +430,22 @@ class SecurityGroup(TaggableResourceMixin, BaseModel):
def __init__( def __init__(
self, self,
cluster_security_group_name, cluster_security_group_name: str,
description, description: str,
account_id, account_id: str,
region_name, region_name: str,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
): ):
super().__init__(account_id, region_name, tags) super().__init__(account_id, region_name, tags)
self.cluster_security_group_name = cluster_security_group_name self.cluster_security_group_name = cluster_security_group_name
self.description = description self.description = description
self.ingress_rules = [] self.ingress_rules: List[str] = []
@property @property
def resource_id(self): def resource_id(self) -> str:
return self.cluster_security_group_name return self.cluster_security_group_name
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"EC2SecurityGroups": [], "EC2SecurityGroups": [],
"IPRanges": [], "IPRanges": [],
@ -449,12 +461,12 @@ class ParameterGroup(TaggableResourceMixin, CloudFormationModel):
def __init__( def __init__(
self, self,
cluster_parameter_group_name, cluster_parameter_group_name: str,
group_family, group_family: str,
description, description: str,
account_id, account_id: str,
region_name, region_name: str,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
): ):
super().__init__(account_id, region_name, tags) super().__init__(account_id, region_name, tags)
self.cluster_parameter_group_name = cluster_parameter_group_name self.cluster_parameter_group_name = cluster_parameter_group_name
@ -462,18 +474,23 @@ class ParameterGroup(TaggableResourceMixin, CloudFormationModel):
self.description = description self.description = description
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clusterparametergroup.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clusterparametergroup.html
return "AWS::Redshift::ClusterParameterGroup" return "AWS::Redshift::ClusterParameterGroup"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "ParameterGroup":
redshift_backend = redshift_backends[account_id][region_name] redshift_backend = redshift_backends[account_id][region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -486,10 +503,10 @@ class ParameterGroup(TaggableResourceMixin, CloudFormationModel):
return parameter_group return parameter_group
@property @property
def resource_id(self): def resource_id(self) -> str:
return self.cluster_parameter_group_name return self.cluster_parameter_group_name
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"ParameterGroupFamily": self.group_family, "ParameterGroupFamily": self.group_family,
"Description": self.description, "Description": self.description,
@ -504,13 +521,13 @@ class Snapshot(TaggableResourceMixin, BaseModel):
def __init__( def __init__(
self, self,
cluster, cluster: Any,
snapshot_identifier, snapshot_identifier: str,
account_id, account_id: str,
region_name, region_name: str,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
iam_roles_arn=None, iam_roles_arn: Optional[List[str]] = None,
snapshot_type="manual", snapshot_type: str = "manual",
): ):
super().__init__(account_id, region_name, tags) super().__init__(account_id, region_name, tags)
self.cluster = copy.copy(cluster) self.cluster = copy.copy(cluster)
@ -521,10 +538,10 @@ class Snapshot(TaggableResourceMixin, BaseModel):
self.iam_roles_arn = iam_roles_arn or [] self.iam_roles_arn = iam_roles_arn or []
@property @property
def resource_id(self): def resource_id(self) -> str:
return f"{self.cluster.cluster_identifier}/{self.snapshot_identifier}" return f"{self.cluster.cluster_identifier}/{self.snapshot_identifier}"
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"SnapshotIdentifier": self.snapshot_identifier, "SnapshotIdentifier": self.snapshot_identifier,
"ClusterIdentifier": self.cluster.cluster_identifier, "ClusterIdentifier": self.cluster.cluster_identifier,
@ -548,16 +565,16 @@ class Snapshot(TaggableResourceMixin, BaseModel):
class RedshiftBackend(BaseBackend): class RedshiftBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.clusters = {} self.clusters: Dict[str, Cluster] = {}
self.subnet_groups = {} self.subnet_groups: Dict[str, SubnetGroup] = {}
self.security_groups = { self.security_groups: Dict[str, SecurityGroup] = {
"Default": SecurityGroup( "Default": SecurityGroup(
"Default", "Default Redshift Security Group", account_id, region_name "Default", "Default Redshift Security Group", account_id, region_name
) )
} }
self.parameter_groups = { self.parameter_groups: Dict[str, ParameterGroup] = {
"default.redshift-1.0": ParameterGroup( "default.redshift-1.0": ParameterGroup(
"default.redshift-1.0", "default.redshift-1.0",
"redshift-1.0", "redshift-1.0",
@ -567,18 +584,20 @@ class RedshiftBackend(BaseBackend):
) )
} }
self.ec2_backend = ec2_backends[self.account_id][self.region_name] self.ec2_backend = ec2_backends[self.account_id][self.region_name]
self.snapshots = OrderedDict() self.snapshots: Dict[str, Snapshot] = OrderedDict()
self.RESOURCE_TYPE_MAP = { self.RESOURCE_TYPE_MAP: Dict[str, Dict[str, TaggableResourceMixin]] = {
"cluster": self.clusters, "cluster": self.clusters, # type: ignore
"parametergroup": self.parameter_groups, "parametergroup": self.parameter_groups, # type: ignore
"securitygroup": self.security_groups, "securitygroup": self.security_groups, # type: ignore
"snapshot": self.snapshots, "snapshot": self.snapshots, # type: ignore
"subnetgroup": self.subnet_groups, "subnetgroup": self.subnet_groups, # type: ignore
} }
self.snapshot_copy_grants = {} self.snapshot_copy_grants: Dict[str, SnapshotCopyGrant] = {}
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "redshift" service_region, zones, "redshift"
@ -586,10 +605,10 @@ class RedshiftBackend(BaseBackend):
service_region, zones, "redshift-data", policy_supported=False service_region, zones, "redshift-data", policy_supported=False
) )
def enable_snapshot_copy(self, **kwargs): def enable_snapshot_copy(self, **kwargs: Any) -> Cluster:
cluster_identifier = kwargs["cluster_identifier"] cluster_identifier = kwargs["cluster_identifier"]
cluster = self.clusters[cluster_identifier] cluster = self.clusters[cluster_identifier]
if not hasattr(cluster, "cluster_snapshot_copy_status"): if cluster.cluster_snapshot_copy_status is None:
if ( if (
cluster.encrypted == "true" cluster.encrypted == "true"
and kwargs["snapshot_copy_grant_name"] is None and kwargs["snapshot_copy_grant_name"] is None
@ -611,26 +630,26 @@ class RedshiftBackend(BaseBackend):
else: else:
raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier) raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier)
def disable_snapshot_copy(self, **kwargs): def disable_snapshot_copy(self, **kwargs: Any) -> Cluster:
cluster_identifier = kwargs["cluster_identifier"] cluster_identifier = kwargs["cluster_identifier"]
cluster = self.clusters[cluster_identifier] cluster = self.clusters[cluster_identifier]
if hasattr(cluster, "cluster_snapshot_copy_status"): if cluster.cluster_snapshot_copy_status is not None:
del cluster.cluster_snapshot_copy_status cluster.cluster_snapshot_copy_status = None
return cluster return cluster
else: else:
raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier) raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier)
def modify_snapshot_copy_retention_period( def modify_snapshot_copy_retention_period(
self, cluster_identifier, retention_period self, cluster_identifier: str, retention_period: str
): ) -> Cluster:
cluster = self.clusters[cluster_identifier] cluster = self.clusters[cluster_identifier]
if hasattr(cluster, "cluster_snapshot_copy_status"): if cluster.cluster_snapshot_copy_status is not None:
cluster.cluster_snapshot_copy_status["RetentionPeriod"] = retention_period cluster.cluster_snapshot_copy_status["RetentionPeriod"] = retention_period
return cluster return cluster
else: else:
raise SnapshotCopyDisabledFaultError(cluster_identifier) raise SnapshotCopyDisabledFaultError(cluster_identifier)
def create_cluster(self, **cluster_kwargs): def create_cluster(self, **cluster_kwargs: Any) -> Cluster:
cluster_identifier = cluster_kwargs["cluster_identifier"] cluster_identifier = cluster_kwargs["cluster_identifier"]
if cluster_identifier in self.clusters: if cluster_identifier in self.clusters:
raise ClusterAlreadyExistsFaultError() raise ClusterAlreadyExistsFaultError()
@ -647,28 +666,29 @@ class RedshiftBackend(BaseBackend):
) )
return cluster return cluster
def pause_cluster(self, cluster_id): def pause_cluster(self, cluster_id: str) -> Cluster:
if cluster_id not in self.clusters: if cluster_id not in self.clusters:
raise ClusterNotFoundError(cluster_identifier=cluster_id) raise ClusterNotFoundError(cluster_identifier=cluster_id)
self.clusters[cluster_id].pause() self.clusters[cluster_id].pause()
return self.clusters[cluster_id] return self.clusters[cluster_id]
def resume_cluster(self, cluster_id): def resume_cluster(self, cluster_id: str) -> Cluster:
if cluster_id not in self.clusters: if cluster_id not in self.clusters:
raise ClusterNotFoundError(cluster_identifier=cluster_id) raise ClusterNotFoundError(cluster_identifier=cluster_id)
self.clusters[cluster_id].resume() self.clusters[cluster_id].resume()
return self.clusters[cluster_id] return self.clusters[cluster_id]
def describe_clusters(self, cluster_identifier=None): def describe_clusters(
clusters = self.clusters.values() self, cluster_identifier: Optional[str] = None
) -> List[Cluster]:
if cluster_identifier: if cluster_identifier:
if cluster_identifier in self.clusters: if cluster_identifier in self.clusters:
return [self.clusters[cluster_identifier]] return [self.clusters[cluster_identifier]]
else: else:
raise ClusterNotFoundError(cluster_identifier) raise ClusterNotFoundError(cluster_identifier)
return clusters return list(self.clusters.values())
def modify_cluster(self, **cluster_kwargs): def modify_cluster(self, **cluster_kwargs: Any) -> Cluster:
cluster_identifier = cluster_kwargs.pop("cluster_identifier") cluster_identifier = cluster_kwargs.pop("cluster_identifier")
new_cluster_identifier = cluster_kwargs.pop("new_cluster_identifier", None) new_cluster_identifier = cluster_kwargs.pop("new_cluster_identifier", None)
@ -703,7 +723,7 @@ class RedshiftBackend(BaseBackend):
return cluster return cluster
def delete_automated_snapshots(self, cluster_identifier): def delete_automated_snapshots(self, cluster_identifier: str) -> None:
snapshots = self.describe_cluster_snapshots( snapshots = self.describe_cluster_snapshots(
cluster_identifier=cluster_identifier cluster_identifier=cluster_identifier
) )
@ -711,7 +731,7 @@ class RedshiftBackend(BaseBackend):
if snapshot.snapshot_type == "automated": if snapshot.snapshot_type == "automated":
self.snapshots.pop(snapshot.snapshot_identifier) self.snapshots.pop(snapshot.snapshot_identifier)
def delete_cluster(self, **cluster_kwargs): def delete_cluster(self, **cluster_kwargs: Any) -> Cluster:
cluster_identifier = cluster_kwargs.pop("cluster_identifier") cluster_identifier = cluster_kwargs.pop("cluster_identifier")
cluster_skip_final_snapshot = cluster_kwargs.pop("skip_final_snapshot") cluster_skip_final_snapshot = cluster_kwargs.pop("skip_final_snapshot")
cluster_snapshot_identifer = cluster_kwargs.pop( cluster_snapshot_identifer = cluster_kwargs.pop(
@ -742,8 +762,13 @@ class RedshiftBackend(BaseBackend):
raise ClusterNotFoundError(cluster_identifier) raise ClusterNotFoundError(cluster_identifier)
def create_cluster_subnet_group( def create_cluster_subnet_group(
self, cluster_subnet_group_name, description, subnet_ids, region_name, tags=None self,
): cluster_subnet_group_name: str,
description: str,
subnet_ids: List[str],
region_name: str,
tags: Optional[List[Dict[str, str]]] = None,
) -> SubnetGroup:
subnet_group = SubnetGroup( subnet_group = SubnetGroup(
self.ec2_backend, self.ec2_backend,
cluster_subnet_group_name, cluster_subnet_group_name,
@ -755,23 +780,27 @@ class RedshiftBackend(BaseBackend):
self.subnet_groups[cluster_subnet_group_name] = subnet_group self.subnet_groups[cluster_subnet_group_name] = subnet_group
return subnet_group return subnet_group
def describe_cluster_subnet_groups(self, subnet_identifier=None): def describe_cluster_subnet_groups(
subnet_groups = self.subnet_groups.values() self, subnet_identifier: Optional[str] = None
) -> List[SubnetGroup]:
if subnet_identifier: if subnet_identifier:
if subnet_identifier in self.subnet_groups: if subnet_identifier in self.subnet_groups:
return [self.subnet_groups[subnet_identifier]] return [self.subnet_groups[subnet_identifier]]
else: else:
raise ClusterSubnetGroupNotFoundError(subnet_identifier) raise ClusterSubnetGroupNotFoundError(subnet_identifier)
return subnet_groups return list(self.subnet_groups.values())
def delete_cluster_subnet_group(self, subnet_identifier): def delete_cluster_subnet_group(self, subnet_identifier: str) -> SubnetGroup:
if subnet_identifier in self.subnet_groups: if subnet_identifier in self.subnet_groups:
return self.subnet_groups.pop(subnet_identifier) return self.subnet_groups.pop(subnet_identifier)
raise ClusterSubnetGroupNotFoundError(subnet_identifier) raise ClusterSubnetGroupNotFoundError(subnet_identifier)
def create_cluster_security_group( def create_cluster_security_group(
self, cluster_security_group_name, description, tags=None self,
): cluster_security_group_name: str,
description: str,
tags: Optional[List[Dict[str, str]]] = None,
) -> SecurityGroup:
security_group = SecurityGroup( security_group = SecurityGroup(
cluster_security_group_name, cluster_security_group_name,
description, description,
@ -782,21 +811,26 @@ class RedshiftBackend(BaseBackend):
self.security_groups[cluster_security_group_name] = security_group self.security_groups[cluster_security_group_name] = security_group
return security_group return security_group
def describe_cluster_security_groups(self, security_group_name=None): def describe_cluster_security_groups(
security_groups = self.security_groups.values() self, security_group_name: Optional[str] = None
) -> List[SecurityGroup]:
if security_group_name: if security_group_name:
if security_group_name in self.security_groups: if security_group_name in self.security_groups:
return [self.security_groups[security_group_name]] return [self.security_groups[security_group_name]]
else: else:
raise ClusterSecurityGroupNotFoundError(security_group_name) raise ClusterSecurityGroupNotFoundError(security_group_name)
return security_groups return list(self.security_groups.values())
def delete_cluster_security_group(self, security_group_identifier): def delete_cluster_security_group(
self, security_group_identifier: str
) -> SecurityGroup:
if security_group_identifier in self.security_groups: if security_group_identifier in self.security_groups:
return self.security_groups.pop(security_group_identifier) return self.security_groups.pop(security_group_identifier)
raise ClusterSecurityGroupNotFoundError(security_group_identifier) raise ClusterSecurityGroupNotFoundError(security_group_identifier)
def authorize_cluster_security_group_ingress(self, security_group_name, cidr_ip): def authorize_cluster_security_group_ingress(
self, security_group_name: str, cidr_ip: str
) -> SecurityGroup:
security_group = self.security_groups.get(security_group_name) security_group = self.security_groups.get(security_group_name)
if not security_group: if not security_group:
raise ClusterSecurityGroupNotFoundFaultError() raise ClusterSecurityGroupNotFoundFaultError()
@ -808,12 +842,12 @@ class RedshiftBackend(BaseBackend):
def create_cluster_parameter_group( def create_cluster_parameter_group(
self, self,
cluster_parameter_group_name, cluster_parameter_group_name: str,
group_family, group_family: str,
description, description: str,
region_name, region_name: str,
tags=None, tags: Optional[List[Dict[str, str]]] = None,
): ) -> ParameterGroup:
parameter_group = ParameterGroup( parameter_group = ParameterGroup(
cluster_parameter_group_name, cluster_parameter_group_name,
group_family, group_family,
@ -826,28 +860,31 @@ class RedshiftBackend(BaseBackend):
return parameter_group return parameter_group
def describe_cluster_parameter_groups(self, parameter_group_name=None): def describe_cluster_parameter_groups(
parameter_groups = self.parameter_groups.values() self, parameter_group_name: Optional[str] = None
) -> List[ParameterGroup]:
if parameter_group_name: if parameter_group_name:
if parameter_group_name in self.parameter_groups: if parameter_group_name in self.parameter_groups:
return [self.parameter_groups[parameter_group_name]] return [self.parameter_groups[parameter_group_name]]
else: else:
raise ClusterParameterGroupNotFoundError(parameter_group_name) raise ClusterParameterGroupNotFoundError(parameter_group_name)
return parameter_groups return list(self.parameter_groups.values())
def delete_cluster_parameter_group(self, parameter_group_name): def delete_cluster_parameter_group(
self, parameter_group_name: str
) -> ParameterGroup:
if parameter_group_name in self.parameter_groups: if parameter_group_name in self.parameter_groups:
return self.parameter_groups.pop(parameter_group_name) return self.parameter_groups.pop(parameter_group_name)
raise ClusterParameterGroupNotFoundError(parameter_group_name) raise ClusterParameterGroupNotFoundError(parameter_group_name)
def create_cluster_snapshot( def create_cluster_snapshot(
self, self,
cluster_identifier, cluster_identifier: str,
snapshot_identifier, snapshot_identifier: str,
region_name, region_name: str,
tags, tags: Optional[List[Dict[str, str]]],
snapshot_type="manual", snapshot_type: str = "manual",
): ) -> Snapshot:
cluster = self.clusters.get(cluster_identifier) cluster = self.clusters.get(cluster_identifier)
if not cluster: if not cluster:
raise ClusterNotFoundError(cluster_identifier) raise ClusterNotFoundError(cluster_identifier)
@ -865,8 +902,11 @@ class RedshiftBackend(BaseBackend):
return snapshot return snapshot
def describe_cluster_snapshots( def describe_cluster_snapshots(
self, cluster_identifier=None, snapshot_identifier=None, snapshot_type=None self,
): cluster_identifier: Optional[str] = None,
snapshot_identifier: Optional[str] = None,
snapshot_type: Optional[str] = None,
) -> List[Snapshot]:
snapshot_types = ( snapshot_types = (
["automated", "manual"] if snapshot_type is None else [snapshot_type] ["automated", "manual"] if snapshot_type is None else [snapshot_type]
) )
@ -885,9 +925,9 @@ class RedshiftBackend(BaseBackend):
return [self.snapshots[snapshot_identifier]] return [self.snapshots[snapshot_identifier]]
raise ClusterSnapshotNotFoundError(snapshot_identifier) raise ClusterSnapshotNotFoundError(snapshot_identifier)
return self.snapshots.values() return list(self.snapshots.values())
def delete_cluster_snapshot(self, snapshot_identifier): def delete_cluster_snapshot(self, snapshot_identifier: str) -> Snapshot:
if snapshot_identifier not in self.snapshots: if snapshot_identifier not in self.snapshots:
raise ClusterSnapshotNotFoundError(snapshot_identifier) raise ClusterSnapshotNotFoundError(snapshot_identifier)
@ -900,7 +940,7 @@ class RedshiftBackend(BaseBackend):
deleted_snapshot.status = "deleted" deleted_snapshot.status = "deleted"
return deleted_snapshot return deleted_snapshot
def restore_from_cluster_snapshot(self, **kwargs): def restore_from_cluster_snapshot(self, **kwargs: Any) -> Cluster:
snapshot_identifier = kwargs.pop("snapshot_identifier") snapshot_identifier = kwargs.pop("snapshot_identifier")
snapshot = self.describe_cluster_snapshots( snapshot = self.describe_cluster_snapshots(
snapshot_identifier=snapshot_identifier snapshot_identifier=snapshot_identifier
@ -925,7 +965,7 @@ class RedshiftBackend(BaseBackend):
create_kwargs.update(kwargs) create_kwargs.update(kwargs)
return self.create_cluster(**create_kwargs) return self.create_cluster(**create_kwargs)
def create_snapshot_copy_grant(self, **kwargs): def create_snapshot_copy_grant(self, **kwargs: Any) -> SnapshotCopyGrant:
snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
kms_key_id = kwargs["kms_key_id"] kms_key_id = kwargs["kms_key_id"]
if snapshot_copy_grant_name not in self.snapshot_copy_grants: if snapshot_copy_grant_name not in self.snapshot_copy_grants:
@ -936,14 +976,14 @@ class RedshiftBackend(BaseBackend):
return snapshot_copy_grant return snapshot_copy_grant
raise SnapshotCopyGrantAlreadyExistsFaultError(snapshot_copy_grant_name) raise SnapshotCopyGrantAlreadyExistsFaultError(snapshot_copy_grant_name)
def delete_snapshot_copy_grant(self, **kwargs): def delete_snapshot_copy_grant(self, **kwargs: Any) -> SnapshotCopyGrant:
snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
if snapshot_copy_grant_name in self.snapshot_copy_grants: if snapshot_copy_grant_name in self.snapshot_copy_grants:
return self.snapshot_copy_grants.pop(snapshot_copy_grant_name) return self.snapshot_copy_grants.pop(snapshot_copy_grant_name)
raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name) raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name)
def describe_snapshot_copy_grants(self, **kwargs): def describe_snapshot_copy_grants(self, **kwargs: Any) -> List[SnapshotCopyGrant]:
copy_grants = self.snapshot_copy_grants.values() copy_grants = list(self.snapshot_copy_grants.values())
snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
if snapshot_copy_grant_name: if snapshot_copy_grant_name:
if snapshot_copy_grant_name in self.snapshot_copy_grants: if snapshot_copy_grant_name in self.snapshot_copy_grants:
@ -952,7 +992,7 @@ class RedshiftBackend(BaseBackend):
raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name) raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name)
return copy_grants return copy_grants
def _get_resource_from_arn(self, arn): def _get_resource_from_arn(self, arn: str) -> TaggableResourceMixin:
try: try:
arn_breakdown = arn.split(":") arn_breakdown = arn.split(":")
resource_type = arn_breakdown[5] resource_type = arn_breakdown[5]
@ -977,7 +1017,7 @@ class RedshiftBackend(BaseBackend):
return resource return resource
@staticmethod @staticmethod
def _describe_tags_for_resources(resources): def _describe_tags_for_resources(resources: Iterable[Any]) -> List[Dict[str, Any]]: # type: ignore[misc]
tagged_resources = [] tagged_resources = []
for resource in resources: for resource in resources:
for tag in resource.tags: for tag in resource.tags:
@ -989,21 +1029,27 @@ class RedshiftBackend(BaseBackend):
tagged_resources.append(data) tagged_resources.append(data)
return tagged_resources return tagged_resources
def _describe_tags_for_resource_type(self, resource_type): def _describe_tags_for_resource_type(
self, resource_type: str
) -> List[Dict[str, Any]]:
resources = self.RESOURCE_TYPE_MAP.get(resource_type) resources = self.RESOURCE_TYPE_MAP.get(resource_type)
if not resources: if not resources:
raise ResourceNotFoundFaultError(resource_type=resource_type) raise ResourceNotFoundFaultError(resource_type=resource_type)
return self._describe_tags_for_resources(resources.values()) return self._describe_tags_for_resources(resources.values())
def _describe_tags_for_resource_name(self, resource_name): def _describe_tags_for_resource_name(
self, resource_name: str
) -> List[Dict[str, Any]]:
resource = self._get_resource_from_arn(resource_name) resource = self._get_resource_from_arn(resource_name)
return self._describe_tags_for_resources([resource]) return self._describe_tags_for_resources([resource])
def create_tags(self, resource_name, tags): def create_tags(self, resource_name: str, tags: List[Dict[str, str]]) -> None:
resource = self._get_resource_from_arn(resource_name) resource = self._get_resource_from_arn(resource_name)
resource.create_tags(tags) resource.create_tags(tags)
def describe_tags(self, resource_name, resource_type): def describe_tags(
self, resource_name: str, resource_type: str
) -> List[Dict[str, Any]]:
if resource_name and resource_type: if resource_name and resource_type:
raise InvalidParameterValueError( raise InvalidParameterValueError(
"You cannot filter a list of resources using an Amazon " "You cannot filter a list of resources using an Amazon "
@ -1025,13 +1071,17 @@ class RedshiftBackend(BaseBackend):
pass pass
return tagged_resources return tagged_resources
def delete_tags(self, resource_name, tag_keys): def delete_tags(self, resource_name: str, tag_keys: List[str]) -> None:
resource = self._get_resource_from_arn(resource_name) resource = self._get_resource_from_arn(resource_name)
resource.delete_tags(tag_keys) resource.delete_tags(tag_keys)
def get_cluster_credentials( def get_cluster_credentials(
self, cluster_identifier, db_user, auto_create, duration_seconds self,
): cluster_identifier: str,
db_user: str,
auto_create: bool,
duration_seconds: int,
) -> Dict[str, Any]:
if duration_seconds < 900 or duration_seconds > 3600: if duration_seconds < 900 or duration_seconds > 3600:
raise InvalidParameterValueError( raise InvalidParameterValueError(
"Token duration must be between 900 and 3600 seconds" "Token duration must be between 900 and 3600 seconds"

View File

@ -3,12 +3,14 @@ import json
import xmltodict import xmltodict
from jinja2 import Template from jinja2 import Template
from typing import Any, Dict, List
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import redshift_backends from .models import redshift_backends, RedshiftBackend
def convert_json_error_to_xml(json_error): def convert_json_error_to_xml(json_error: Any) -> str:
error = json.loads(json_error) error = json.loads(json_error)
code = error["Error"]["Code"] code = error["Error"]["Code"]
message = error["Error"]["Message"] message = error["Error"]["Message"]
@ -26,7 +28,7 @@ def convert_json_error_to_xml(json_error):
return template.render(code=code, message=message) return template.render(code=code, message=message)
def itemize(data): def itemize(data: Any) -> Dict[str, Any]:
""" """
The xmltodict.unparse requires we modify the shape of the input dictionary slightly. Instead of a dict of the form: The xmltodict.unparse requires we modify the shape of the input dictionary slightly. Instead of a dict of the form:
{'key': ['value1', 'value2']} {'key': ['value1', 'value2']}
@ -45,14 +47,14 @@ def itemize(data):
class RedshiftResponse(BaseResponse): class RedshiftResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="redshift") super().__init__(service_name="redshift")
@property @property
def redshift_backend(self): def redshift_backend(self) -> RedshiftBackend:
return redshift_backends[self.current_account][self.region] return redshift_backends[self.current_account][self.region]
def get_response(self, response): def get_response(self, response: Any) -> str:
if self.request_json: if self.request_json:
return json.dumps(response) return json.dumps(response)
else: else:
@ -61,17 +63,17 @@ class RedshiftResponse(BaseResponse):
xml = xml.decode("utf-8") xml = xml.decode("utf-8")
return xml return xml
def call_action(self): def call_action(self) -> TYPE_RESPONSE:
status, headers, body = super().call_action() status, headers, body = super().call_action()
if status >= 400 and not self.request_json: if status >= 400 and not self.request_json:
body = convert_json_error_to_xml(body) body = convert_json_error_to_xml(body)
return status, headers, body return status, headers, body
def unpack_list_params(self, label, child_label): def unpack_list_params(self, label: str, child_label: str) -> Any:
root = self._get_multi_param_dict(label) or {} root = self._get_multi_param_dict(label) or {}
return root.get(child_label, []) return root.get(child_label, [])
def _get_cluster_security_groups(self): def _get_cluster_security_groups(self) -> List[str]:
cluster_security_groups = self._get_multi_param("ClusterSecurityGroups.member") cluster_security_groups = self._get_multi_param("ClusterSecurityGroups.member")
if not cluster_security_groups: if not cluster_security_groups:
cluster_security_groups = self._get_multi_param( cluster_security_groups = self._get_multi_param(
@ -79,7 +81,7 @@ class RedshiftResponse(BaseResponse):
) )
return cluster_security_groups return cluster_security_groups
def _get_vpc_security_group_ids(self): def _get_vpc_security_group_ids(self) -> List[str]:
vpc_security_group_ids = self._get_multi_param("VpcSecurityGroupIds.member") vpc_security_group_ids = self._get_multi_param("VpcSecurityGroupIds.member")
if not vpc_security_group_ids: if not vpc_security_group_ids:
vpc_security_group_ids = self._get_multi_param( vpc_security_group_ids = self._get_multi_param(
@ -87,19 +89,19 @@ class RedshiftResponse(BaseResponse):
) )
return vpc_security_group_ids return vpc_security_group_ids
def _get_iam_roles(self): def _get_iam_roles(self) -> List[str]:
iam_roles = self._get_multi_param("IamRoles.member") iam_roles = self._get_multi_param("IamRoles.member")
if not iam_roles: if not iam_roles:
iam_roles = self._get_multi_param("IamRoles.IamRoleArn") iam_roles = self._get_multi_param("IamRoles.IamRoleArn")
return iam_roles return iam_roles
def _get_subnet_ids(self): def _get_subnet_ids(self) -> List[str]:
subnet_ids = self._get_multi_param("SubnetIds.member") subnet_ids = self._get_multi_param("SubnetIds.member")
if not subnet_ids: if not subnet_ids:
subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier") subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier")
return subnet_ids return subnet_ids
def create_cluster(self): def create_cluster(self) -> str:
cluster_kwargs = { cluster_kwargs = {
"cluster_identifier": self._get_param("ClusterIdentifier"), "cluster_identifier": self._get_param("ClusterIdentifier"),
"node_type": self._get_param("NodeType"), "node_type": self._get_param("NodeType"),
@ -145,7 +147,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def pause_cluster(self): def pause_cluster(self) -> str:
cluster_id = self._get_param("ClusterIdentifier") cluster_id = self._get_param("ClusterIdentifier")
cluster = self.redshift_backend.pause_cluster(cluster_id).to_json() cluster = self.redshift_backend.pause_cluster(cluster_id).to_json()
return self.get_response( return self.get_response(
@ -159,7 +161,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def resume_cluster(self): def resume_cluster(self) -> str:
cluster_id = self._get_param("ClusterIdentifier") cluster_id = self._get_param("ClusterIdentifier")
cluster = self.redshift_backend.resume_cluster(cluster_id).to_json() cluster = self.redshift_backend.resume_cluster(cluster_id).to_json()
return self.get_response( return self.get_response(
@ -173,7 +175,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def restore_from_cluster_snapshot(self): def restore_from_cluster_snapshot(self) -> str:
enhanced_vpc_routing = self._get_bool_param("EnhancedVpcRouting") enhanced_vpc_routing = self._get_bool_param("EnhancedVpcRouting")
node_type = self._get_param("NodeType") node_type = self._get_param("NodeType")
number_of_nodes = self._get_int_param("NumberOfNodes") number_of_nodes = self._get_int_param("NumberOfNodes")
@ -220,7 +222,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def describe_clusters(self): def describe_clusters(self) -> str:
cluster_identifier = self._get_param("ClusterIdentifier") cluster_identifier = self._get_param("ClusterIdentifier")
clusters = self.redshift_backend.describe_clusters(cluster_identifier) clusters = self.redshift_backend.describe_clusters(cluster_identifier)
@ -237,7 +239,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def modify_cluster(self): def modify_cluster(self) -> str:
request_kwargs = { request_kwargs = {
"cluster_identifier": self._get_param("ClusterIdentifier"), "cluster_identifier": self._get_param("ClusterIdentifier"),
"new_cluster_identifier": self._get_param("NewClusterIdentifier"), "new_cluster_identifier": self._get_param("NewClusterIdentifier"),
@ -284,7 +286,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def delete_cluster(self): def delete_cluster(self) -> str:
request_kwargs = { request_kwargs = {
"cluster_identifier": self._get_param("ClusterIdentifier"), "cluster_identifier": self._get_param("ClusterIdentifier"),
"final_cluster_snapshot_identifier": self._get_param( "final_cluster_snapshot_identifier": self._get_param(
@ -306,7 +308,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def create_cluster_subnet_group(self): def create_cluster_subnet_group(self) -> str:
cluster_subnet_group_name = self._get_param("ClusterSubnetGroupName") cluster_subnet_group_name = self._get_param("ClusterSubnetGroupName")
description = self._get_param("Description") description = self._get_param("Description")
subnet_ids = self._get_subnet_ids() subnet_ids = self._get_subnet_ids()
@ -333,7 +335,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def describe_cluster_subnet_groups(self): def describe_cluster_subnet_groups(self) -> str:
subnet_identifier = self._get_param("ClusterSubnetGroupName") subnet_identifier = self._get_param("ClusterSubnetGroupName")
subnet_groups = self.redshift_backend.describe_cluster_subnet_groups( subnet_groups = self.redshift_backend.describe_cluster_subnet_groups(
subnet_identifier subnet_identifier
@ -354,7 +356,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def delete_cluster_subnet_group(self): def delete_cluster_subnet_group(self) -> str:
subnet_identifier = self._get_param("ClusterSubnetGroupName") subnet_identifier = self._get_param("ClusterSubnetGroupName")
self.redshift_backend.delete_cluster_subnet_group(subnet_identifier) self.redshift_backend.delete_cluster_subnet_group(subnet_identifier)
@ -368,7 +370,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def create_cluster_security_group(self): def create_cluster_security_group(self) -> str:
cluster_security_group_name = self._get_param("ClusterSecurityGroupName") cluster_security_group_name = self._get_param("ClusterSecurityGroupName")
description = self._get_param("Description") description = self._get_param("Description")
tags = self.unpack_list_params("Tags", "Tag") tags = self.unpack_list_params("Tags", "Tag")
@ -392,7 +394,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def describe_cluster_security_groups(self): def describe_cluster_security_groups(self) -> str:
cluster_security_group_name = self._get_param("ClusterSecurityGroupName") cluster_security_group_name = self._get_param("ClusterSecurityGroupName")
security_groups = self.redshift_backend.describe_cluster_security_groups( security_groups = self.redshift_backend.describe_cluster_security_groups(
cluster_security_group_name cluster_security_group_name
@ -414,7 +416,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def delete_cluster_security_group(self): def delete_cluster_security_group(self) -> str:
security_group_identifier = self._get_param("ClusterSecurityGroupName") security_group_identifier = self._get_param("ClusterSecurityGroupName")
self.redshift_backend.delete_cluster_security_group(security_group_identifier) self.redshift_backend.delete_cluster_security_group(security_group_identifier)
@ -428,7 +430,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def authorize_cluster_security_group_ingress(self): def authorize_cluster_security_group_ingress(self) -> str:
cluster_security_group_name = self._get_param("ClusterSecurityGroupName") cluster_security_group_name = self._get_param("ClusterSecurityGroupName")
cidr_ip = self._get_param("CIDRIP") cidr_ip = self._get_param("CIDRIP")
@ -456,7 +458,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def create_cluster_parameter_group(self): def create_cluster_parameter_group(self) -> str:
cluster_parameter_group_name = self._get_param("ParameterGroupName") cluster_parameter_group_name = self._get_param("ParameterGroupName")
group_family = self._get_param("ParameterGroupFamily") group_family = self._get_param("ParameterGroupFamily")
description = self._get_param("Description") description = self._get_param("Description")
@ -479,7 +481,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def describe_cluster_parameter_groups(self): def describe_cluster_parameter_groups(self) -> str:
cluster_parameter_group_name = self._get_param("ParameterGroupName") cluster_parameter_group_name = self._get_param("ParameterGroupName")
parameter_groups = self.redshift_backend.describe_cluster_parameter_groups( parameter_groups = self.redshift_backend.describe_cluster_parameter_groups(
cluster_parameter_group_name cluster_parameter_group_name
@ -501,7 +503,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def delete_cluster_parameter_group(self): def delete_cluster_parameter_group(self) -> str:
cluster_parameter_group_name = self._get_param("ParameterGroupName") cluster_parameter_group_name = self._get_param("ParameterGroupName")
self.redshift_backend.delete_cluster_parameter_group( self.redshift_backend.delete_cluster_parameter_group(
cluster_parameter_group_name cluster_parameter_group_name
@ -517,7 +519,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def create_cluster_snapshot(self): def create_cluster_snapshot(self) -> str:
cluster_identifier = self._get_param("ClusterIdentifier") cluster_identifier = self._get_param("ClusterIdentifier")
snapshot_identifier = self._get_param("SnapshotIdentifier") snapshot_identifier = self._get_param("SnapshotIdentifier")
tags = self.unpack_list_params("Tags", "Tag") tags = self.unpack_list_params("Tags", "Tag")
@ -536,7 +538,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def describe_cluster_snapshots(self): def describe_cluster_snapshots(self) -> str:
cluster_identifier = self._get_param("ClusterIdentifier") cluster_identifier = self._get_param("ClusterIdentifier")
snapshot_identifier = self._get_param("SnapshotIdentifier") snapshot_identifier = self._get_param("SnapshotIdentifier")
snapshot_type = self._get_param("SnapshotType") snapshot_type = self._get_param("SnapshotType")
@ -556,7 +558,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def delete_cluster_snapshot(self): def delete_cluster_snapshot(self) -> str:
snapshot_identifier = self._get_param("SnapshotIdentifier") snapshot_identifier = self._get_param("SnapshotIdentifier")
snapshot = self.redshift_backend.delete_cluster_snapshot(snapshot_identifier) snapshot = self.redshift_backend.delete_cluster_snapshot(snapshot_identifier)
@ -571,7 +573,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def create_snapshot_copy_grant(self): def create_snapshot_copy_grant(self) -> str:
copy_grant_kwargs = { copy_grant_kwargs = {
"snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName"), "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName"),
"kms_key_id": self._get_param("KmsKeyId"), "kms_key_id": self._get_param("KmsKeyId"),
@ -594,7 +596,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def delete_snapshot_copy_grant(self): def delete_snapshot_copy_grant(self) -> str:
copy_grant_kwargs = { copy_grant_kwargs = {
"snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName") "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName")
} }
@ -609,7 +611,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def describe_snapshot_copy_grants(self): def describe_snapshot_copy_grants(self) -> str:
copy_grant_kwargs = { copy_grant_kwargs = {
"snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName") "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName")
} }
@ -632,7 +634,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def create_tags(self): def create_tags(self) -> str:
resource_name = self._get_param("ResourceName") resource_name = self._get_param("ResourceName")
tags = self.unpack_list_params("Tags", "Tag") tags = self.unpack_list_params("Tags", "Tag")
@ -648,7 +650,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def describe_tags(self): def describe_tags(self) -> str:
resource_name = self._get_param("ResourceName") resource_name = self._get_param("ResourceName")
resource_type = self._get_param("ResourceType") resource_type = self._get_param("ResourceType")
@ -666,7 +668,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def delete_tags(self): def delete_tags(self) -> str:
resource_name = self._get_param("ResourceName") resource_name = self._get_param("ResourceName")
tag_keys = self.unpack_list_params("TagKeys", "TagKey") tag_keys = self.unpack_list_params("TagKeys", "TagKey")
@ -682,7 +684,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def enable_snapshot_copy(self): def enable_snapshot_copy(self) -> str:
snapshot_copy_kwargs = { snapshot_copy_kwargs = {
"cluster_identifier": self._get_param("ClusterIdentifier"), "cluster_identifier": self._get_param("ClusterIdentifier"),
"destination_region": self._get_param("DestinationRegion"), "destination_region": self._get_param("DestinationRegion"),
@ -702,7 +704,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def disable_snapshot_copy(self): def disable_snapshot_copy(self) -> str:
snapshot_copy_kwargs = { snapshot_copy_kwargs = {
"cluster_identifier": self._get_param("ClusterIdentifier") "cluster_identifier": self._get_param("ClusterIdentifier")
} }
@ -719,7 +721,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def modify_snapshot_copy_retention_period(self): def modify_snapshot_copy_retention_period(self) -> str:
snapshot_copy_kwargs = { snapshot_copy_kwargs = {
"cluster_identifier": self._get_param("ClusterIdentifier"), "cluster_identifier": self._get_param("ClusterIdentifier"),
"retention_period": self._get_param("RetentionPeriod"), "retention_period": self._get_param("RetentionPeriod"),
@ -741,7 +743,7 @@ class RedshiftResponse(BaseResponse):
} }
) )
def get_cluster_credentials(self): def get_cluster_credentials(self) -> str:
cluster_identifier = self._get_param("ClusterIdentifier") cluster_identifier = self._get_param("ClusterIdentifier")
db_user = self._get_param("DbUser") db_user = self._get_param("DbUser")
auto_create = self._get_bool_param("AutoCreate", False) auto_create = self._get_bool_param("AutoCreate", False)

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/ram,moto/rds,moto/rdsdata,moto/scheduler files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/ram,moto/rds,moto/rdsdata,moto/redshift,moto/scheduler
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract