diff --git a/moto/dax/exceptions.py b/moto/dax/exceptions.py index 96a18b52d..7c50aa039 100644 --- a/moto/dax/exceptions.py +++ b/moto/dax/exceptions.py @@ -1,13 +1,14 @@ from moto.core.exceptions import JsonRESTError +from typing import Optional class InvalidParameterValueException(JsonRESTError): - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidParameterValueException", message) class ClusterNotFoundFault(JsonRESTError): - def __init__(self, name=None): + def __init__(self, name: Optional[str] = None): # DescribeClusters and DeleteCluster use a different message for the same error msg = f"Cluster {name} not found." if name else "Cluster not found." super().__init__("ClusterNotFoundFault", msg) diff --git a/moto/dax/models.py b/moto/dax/models.py index a2b47fc34..2a29194e5 100644 --- a/moto/dax/models.py +++ b/moto/dax/models.py @@ -6,17 +6,18 @@ from moto.moto_api._internal import mock_random as random from moto.moto_api._internal.managed_state_model import ManagedState from moto.utilities.tagging_service import TaggingService from moto.utilities.paginator import paginate +from typing import Any, Dict, List, Iterable from .exceptions import ClusterNotFoundFault from .utils import PAGINATION_MODEL class DaxParameterGroup(BaseModel): - def __init__(self): + def __init__(self) -> None: self.name = "default.dax1.0" self.status = "in-sync" - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "ParameterGroupName": self.name, "ParameterApplyStatus": self.status, @@ -25,7 +26,7 @@ class DaxParameterGroup(BaseModel): class DaxNode: - def __init__(self, endpoint, name, index): + def __init__(self, endpoint: "DaxEndpoint", name: str, index: int): self.node_id = f"{name}-{chr(ord('a')+index)}" # name-a, name-b, etc self.node_endpoint = { "Address": f"{self.node_id}.{endpoint.cluster_hex}.nodes.dax-clusters.{endpoint.region}.amazonaws.com", @@ -38,7 +39,7 @@ class DaxNode: self.status = "available" self.parameter_status = "in-sync" - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "NodeId": self.node_id, "Endpoint": self.node_endpoint, @@ -50,14 +51,14 @@ class DaxNode: class DaxEndpoint: - def __init__(self, name, cluster_hex, region): + def __init__(self, name: str, cluster_hex: str, region: str): self.name = name self.cluster_hex = cluster_hex self.region = region self.port = 8111 - def to_json(self, full=False): - dct = {"Port": self.port} + def to_json(self, full: bool = False) -> Dict[str, Any]: + dct: Dict[str, Any] = {"Port": self.port} if full: dct[ "Address" @@ -69,15 +70,15 @@ class DaxEndpoint: class DaxCluster(BaseModel, ManagedState): def __init__( self, - account_id, - region, - name, - description, - node_type, - replication_factor, - iam_role_arn, - sse_specification, - encryption_type, + account_id: str, + region: str, + name: str, + description: str, + node_type: str, + replication_factor: int, + iam_role_arn: str, + sse_specification: Dict[str, Any], + encryption_type: str, ): # Configure ManagedState super().__init__( @@ -108,28 +109,30 @@ class DaxCluster(BaseModel, ManagedState): self.sse_specification = sse_specification self.encryption_type = encryption_type - def _create_new_node(self, idx): + def _create_new_node(self, idx: int) -> DaxNode: return DaxNode(endpoint=self.endpoint, name=self.name, index=idx) - def increase_replication_factor(self, new_replication_factor): + def increase_replication_factor(self, new_replication_factor: int) -> None: for idx in range(self.replication_factor, new_replication_factor): self.nodes.append(self._create_new_node(idx)) self.replication_factor = new_replication_factor - def decrease_replication_factor(self, new_replication_factor, node_ids_to_remove): + def decrease_replication_factor( + self, new_replication_factor: int, node_ids_to_remove: List[str] + ) -> None: if node_ids_to_remove: self.nodes = [n for n in self.nodes if n.node_id not in node_ids_to_remove] else: self.nodes = self.nodes[0:new_replication_factor] self.replication_factor = new_replication_factor - def delete(self): + def delete(self) -> None: self.status = "deleting" - def is_deleted(self): + def is_deleted(self) -> bool: return self.status == "deleted" - def to_json(self): + def to_json(self) -> Dict[str, Any]: use_full_repr = self.status == "available" dct = { "ClusterName": self.name, @@ -158,9 +161,9 @@ class DaxCluster(BaseModel, ManagedState): class DAXBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._clusters = dict() + self._clusters: Dict[str, DaxCluster] = dict() self._tagger = TaggingService() state_manager.register_default_transition( @@ -168,7 +171,7 @@ class DAXBackend(BaseBackend): ) @property - def clusters(self): + def clusters(self) -> Dict[str, DaxCluster]: self._clusters = { name: cluster for name, cluster in self._clusters.items() @@ -178,15 +181,15 @@ class DAXBackend(BaseBackend): def create_cluster( self, - cluster_name, - node_type, - description, - replication_factor, - iam_role_arn, - tags, - sse_specification, - encryption_type, - ): + cluster_name: str, + node_type: str, + description: str, + replication_factor: int, + iam_role_arn: str, + tags: List[Dict[str, str]], + sse_specification: Dict[str, Any], + encryption_type: str, + ) -> DaxCluster: """ The following parameters are not yet processed: AvailabilityZones, SubnetGroupNames, SecurityGroups, PreferredMaintenanceWindow, NotificationTopicArn, ParameterGroupName @@ -206,14 +209,14 @@ class DAXBackend(BaseBackend): self._tagger.tag_resource(cluster.arn, tags) return cluster - def delete_cluster(self, cluster_name): + def delete_cluster(self, cluster_name: str) -> DaxCluster: if cluster_name not in self.clusters: raise ClusterNotFoundFault() self.clusters[cluster_name].delete() return self.clusters[cluster_name] @paginate(PAGINATION_MODEL) - def describe_clusters(self, cluster_names): + def describe_clusters(self, cluster_names: Iterable[str]) -> List[DaxCluster]: # type: ignore[misc] clusters = self.clusters if not cluster_names: cluster_names = clusters.keys() @@ -229,7 +232,7 @@ class DAXBackend(BaseBackend): raise ClusterNotFoundFault(name) return [cluster for name, cluster in clusters.items() if name in cluster_names] - def list_tags(self, resource_name): + def list_tags(self, resource_name: str) -> Dict[str, List[Dict[str, str]]]: """ Pagination is not yet implemented """ @@ -239,7 +242,9 @@ class DAXBackend(BaseBackend): raise ClusterNotFoundFault() return self._tagger.list_tags_for_resource(self.clusters[name].arn) - def increase_replication_factor(self, cluster_name, new_replication_factor): + def increase_replication_factor( + self, cluster_name: str, new_replication_factor: int + ) -> DaxCluster: """ The AvailabilityZones-parameter is not yet implemented """ @@ -250,10 +255,10 @@ class DAXBackend(BaseBackend): def decrease_replication_factor( self, - cluster_name, - new_replication_factor, - node_ids_to_remove, - ): + cluster_name: str, + new_replication_factor: int, + node_ids_to_remove: List[str], + ) -> DaxCluster: """ The AvailabilityZones-parameter is not yet implemented """ diff --git a/moto/dax/responses.py b/moto/dax/responses.py index b2016b273..86af7806d 100644 --- a/moto/dax/responses.py +++ b/moto/dax/responses.py @@ -3,18 +3,18 @@ import re from moto.core.responses import BaseResponse from .exceptions import InvalidParameterValueException -from .models import dax_backends +from .models import dax_backends, DAXBackend class DAXResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="dax") @property - def dax_backend(self): + def dax_backend(self) -> DAXBackend: return dax_backends[self.current_account][self.region] - def create_cluster(self): + def create_cluster(self) -> str: params = json.loads(self.body) cluster_name = params.get("ClusterName") node_type = params.get("NodeType") @@ -40,12 +40,12 @@ class DAXResponse(BaseResponse): ) return json.dumps(dict(Cluster=cluster.to_json())) - def delete_cluster(self): + def delete_cluster(self) -> str: cluster_name = json.loads(self.body).get("ClusterName") cluster = self.dax_backend.delete_cluster(cluster_name) return json.dumps(dict(Cluster=cluster.to_json())) - def describe_clusters(self): + def describe_clusters(self) -> str: params = json.loads(self.body) cluster_names = params.get("ClusterNames", []) max_results = params.get("MaxResults") @@ -61,7 +61,7 @@ class DAXResponse(BaseResponse): {"Clusters": [c.to_json() for c in clusters], "NextToken": next_token} ) - def _validate_arn(self, arn): + def _validate_arn(self, arn: str) -> None: if not arn.startswith("arn:"): raise InvalidParameterValueException(f"ARNs must start with 'arn:': {arn}") sections = arn.split(":") @@ -80,20 +80,20 @@ class DAXResponse(BaseResponse): f"Fifth colon (namespace/relative-id delimiter) not found: {arn}" ) - def _validate_name(self, name): + def _validate_name(self, name: str) -> None: msg = "Cluster ID specified is not a valid identifier. Identifiers must begin with a letter; must contain only ASCII letters, digits, and hyphens; and must not end with a hyphen or contain two consecutive hyphens." if not re.match("^[a-z][a-z0-9-]+[a-z0-9]$", name): raise InvalidParameterValueException(msg) if "--" in name: raise InvalidParameterValueException(msg) - def list_tags(self): + def list_tags(self) -> str: params = json.loads(self.body) resource_name = params.get("ResourceName") tags = self.dax_backend.list_tags(resource_name=resource_name) return json.dumps(tags) - def increase_replication_factor(self): + def increase_replication_factor(self) -> str: params = json.loads(self.body) cluster_name = params.get("ClusterName") new_replication_factor = params.get("NewReplicationFactor") @@ -102,7 +102,7 @@ class DAXResponse(BaseResponse): ) return json.dumps({"Cluster": cluster.to_json()}) - def decrease_replication_factor(self): + def decrease_replication_factor(self) -> str: params = json.loads(self.body) cluster_name = params.get("ClusterName") new_replication_factor = params.get("NewReplicationFactor") diff --git a/setup.cfg b/setup.cfg index a07c05467..8febb1afc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/moto_api +files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/moto_api show_column_numbers=True show_error_codes = True disable_error_code=abstract