Techdebt: MyPy DAX (#5783)

This commit is contained in:
Bert Blommers 2022-12-18 15:08:32 -01:00 committed by GitHub
parent 7956812e66
commit fab91edd8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 56 deletions

View File

@ -1,13 +1,14 @@
from moto.core.exceptions import JsonRESTError
from typing import Optional
class InvalidParameterValueException(JsonRESTError):
def __init__(self, message):
def __init__(self, message: str):
super().__init__("InvalidParameterValueException", message)
class ClusterNotFoundFault(JsonRESTError):
def __init__(self, name=None):
def __init__(self, name: Optional[str] = None):
# DescribeClusters and DeleteCluster use a different message for the same error
msg = f"Cluster {name} not found." if name else "Cluster not found."
super().__init__("ClusterNotFoundFault", msg)

View File

@ -6,17 +6,18 @@ from moto.moto_api._internal import mock_random as random
from moto.moto_api._internal.managed_state_model import ManagedState
from moto.utilities.tagging_service import TaggingService
from moto.utilities.paginator import paginate
from typing import Any, Dict, List, Iterable
from .exceptions import ClusterNotFoundFault
from .utils import PAGINATION_MODEL
class DaxParameterGroup(BaseModel):
def __init__(self):
def __init__(self) -> None:
self.name = "default.dax1.0"
self.status = "in-sync"
def to_json(self):
def to_json(self) -> Dict[str, Any]:
return {
"ParameterGroupName": self.name,
"ParameterApplyStatus": self.status,
@ -25,7 +26,7 @@ class DaxParameterGroup(BaseModel):
class DaxNode:
def __init__(self, endpoint, name, index):
def __init__(self, endpoint: "DaxEndpoint", name: str, index: int):
self.node_id = f"{name}-{chr(ord('a')+index)}" # name-a, name-b, etc
self.node_endpoint = {
"Address": f"{self.node_id}.{endpoint.cluster_hex}.nodes.dax-clusters.{endpoint.region}.amazonaws.com",
@ -38,7 +39,7 @@ class DaxNode:
self.status = "available"
self.parameter_status = "in-sync"
def to_json(self):
def to_json(self) -> Dict[str, Any]:
return {
"NodeId": self.node_id,
"Endpoint": self.node_endpoint,
@ -50,14 +51,14 @@ class DaxNode:
class DaxEndpoint:
def __init__(self, name, cluster_hex, region):
def __init__(self, name: str, cluster_hex: str, region: str):
self.name = name
self.cluster_hex = cluster_hex
self.region = region
self.port = 8111
def to_json(self, full=False):
dct = {"Port": self.port}
def to_json(self, full: bool = False) -> Dict[str, Any]:
dct: Dict[str, Any] = {"Port": self.port}
if full:
dct[
"Address"
@ -69,15 +70,15 @@ class DaxEndpoint:
class DaxCluster(BaseModel, ManagedState):
def __init__(
self,
account_id,
region,
name,
description,
node_type,
replication_factor,
iam_role_arn,
sse_specification,
encryption_type,
account_id: str,
region: str,
name: str,
description: str,
node_type: str,
replication_factor: int,
iam_role_arn: str,
sse_specification: Dict[str, Any],
encryption_type: str,
):
# Configure ManagedState
super().__init__(
@ -108,28 +109,30 @@ class DaxCluster(BaseModel, ManagedState):
self.sse_specification = sse_specification
self.encryption_type = encryption_type
def _create_new_node(self, idx):
def _create_new_node(self, idx: int) -> DaxNode:
return DaxNode(endpoint=self.endpoint, name=self.name, index=idx)
def increase_replication_factor(self, new_replication_factor):
def increase_replication_factor(self, new_replication_factor: int) -> None:
for idx in range(self.replication_factor, new_replication_factor):
self.nodes.append(self._create_new_node(idx))
self.replication_factor = new_replication_factor
def decrease_replication_factor(self, new_replication_factor, node_ids_to_remove):
def decrease_replication_factor(
self, new_replication_factor: int, node_ids_to_remove: List[str]
) -> None:
if node_ids_to_remove:
self.nodes = [n for n in self.nodes if n.node_id not in node_ids_to_remove]
else:
self.nodes = self.nodes[0:new_replication_factor]
self.replication_factor = new_replication_factor
def delete(self):
def delete(self) -> None:
self.status = "deleting"
def is_deleted(self):
def is_deleted(self) -> bool:
return self.status == "deleted"
def to_json(self):
def to_json(self) -> Dict[str, Any]:
use_full_repr = self.status == "available"
dct = {
"ClusterName": self.name,
@ -158,9 +161,9 @@ class DaxCluster(BaseModel, ManagedState):
class DAXBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self._clusters = dict()
self._clusters: Dict[str, DaxCluster] = dict()
self._tagger = TaggingService()
state_manager.register_default_transition(
@ -168,7 +171,7 @@ class DAXBackend(BaseBackend):
)
@property
def clusters(self):
def clusters(self) -> Dict[str, DaxCluster]:
self._clusters = {
name: cluster
for name, cluster in self._clusters.items()
@ -178,15 +181,15 @@ class DAXBackend(BaseBackend):
def create_cluster(
self,
cluster_name,
node_type,
description,
replication_factor,
iam_role_arn,
tags,
sse_specification,
encryption_type,
):
cluster_name: str,
node_type: str,
description: str,
replication_factor: int,
iam_role_arn: str,
tags: List[Dict[str, str]],
sse_specification: Dict[str, Any],
encryption_type: str,
) -> DaxCluster:
"""
The following parameters are not yet processed:
AvailabilityZones, SubnetGroupNames, SecurityGroups, PreferredMaintenanceWindow, NotificationTopicArn, ParameterGroupName
@ -206,14 +209,14 @@ class DAXBackend(BaseBackend):
self._tagger.tag_resource(cluster.arn, tags)
return cluster
def delete_cluster(self, cluster_name):
def delete_cluster(self, cluster_name: str) -> DaxCluster:
if cluster_name not in self.clusters:
raise ClusterNotFoundFault()
self.clusters[cluster_name].delete()
return self.clusters[cluster_name]
@paginate(PAGINATION_MODEL)
def describe_clusters(self, cluster_names):
def describe_clusters(self, cluster_names: Iterable[str]) -> List[DaxCluster]: # type: ignore[misc]
clusters = self.clusters
if not cluster_names:
cluster_names = clusters.keys()
@ -229,7 +232,7 @@ class DAXBackend(BaseBackend):
raise ClusterNotFoundFault(name)
return [cluster for name, cluster in clusters.items() if name in cluster_names]
def list_tags(self, resource_name):
def list_tags(self, resource_name: str) -> Dict[str, List[Dict[str, str]]]:
"""
Pagination is not yet implemented
"""
@ -239,7 +242,9 @@ class DAXBackend(BaseBackend):
raise ClusterNotFoundFault()
return self._tagger.list_tags_for_resource(self.clusters[name].arn)
def increase_replication_factor(self, cluster_name, new_replication_factor):
def increase_replication_factor(
self, cluster_name: str, new_replication_factor: int
) -> DaxCluster:
"""
The AvailabilityZones-parameter is not yet implemented
"""
@ -250,10 +255,10 @@ class DAXBackend(BaseBackend):
def decrease_replication_factor(
self,
cluster_name,
new_replication_factor,
node_ids_to_remove,
):
cluster_name: str,
new_replication_factor: int,
node_ids_to_remove: List[str],
) -> DaxCluster:
"""
The AvailabilityZones-parameter is not yet implemented
"""

View File

@ -3,18 +3,18 @@ import re
from moto.core.responses import BaseResponse
from .exceptions import InvalidParameterValueException
from .models import dax_backends
from .models import dax_backends, DAXBackend
class DAXResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="dax")
@property
def dax_backend(self):
def dax_backend(self) -> DAXBackend:
return dax_backends[self.current_account][self.region]
def create_cluster(self):
def create_cluster(self) -> str:
params = json.loads(self.body)
cluster_name = params.get("ClusterName")
node_type = params.get("NodeType")
@ -40,12 +40,12 @@ class DAXResponse(BaseResponse):
)
return json.dumps(dict(Cluster=cluster.to_json()))
def delete_cluster(self):
def delete_cluster(self) -> str:
cluster_name = json.loads(self.body).get("ClusterName")
cluster = self.dax_backend.delete_cluster(cluster_name)
return json.dumps(dict(Cluster=cluster.to_json()))
def describe_clusters(self):
def describe_clusters(self) -> str:
params = json.loads(self.body)
cluster_names = params.get("ClusterNames", [])
max_results = params.get("MaxResults")
@ -61,7 +61,7 @@ class DAXResponse(BaseResponse):
{"Clusters": [c.to_json() for c in clusters], "NextToken": next_token}
)
def _validate_arn(self, arn):
def _validate_arn(self, arn: str) -> None:
if not arn.startswith("arn:"):
raise InvalidParameterValueException(f"ARNs must start with 'arn:': {arn}")
sections = arn.split(":")
@ -80,20 +80,20 @@ class DAXResponse(BaseResponse):
f"Fifth colon (namespace/relative-id delimiter) not found: {arn}"
)
def _validate_name(self, name):
def _validate_name(self, name: str) -> None:
msg = "Cluster ID specified is not a valid identifier. Identifiers must begin with a letter; must contain only ASCII letters, digits, and hyphens; and must not end with a hyphen or contain two consecutive hyphens."
if not re.match("^[a-z][a-z0-9-]+[a-z0-9]$", name):
raise InvalidParameterValueException(msg)
if "--" in name:
raise InvalidParameterValueException(msg)
def list_tags(self):
def list_tags(self) -> str:
params = json.loads(self.body)
resource_name = params.get("ResourceName")
tags = self.dax_backend.list_tags(resource_name=resource_name)
return json.dumps(tags)
def increase_replication_factor(self):
def increase_replication_factor(self) -> str:
params = json.loads(self.body)
cluster_name = params.get("ClusterName")
new_replication_factor = params.get("NewReplicationFactor")
@ -102,7 +102,7 @@ class DAXResponse(BaseResponse):
)
return json.dumps({"Cluster": cluster.to_json()})
def decrease_replication_factor(self):
def decrease_replication_factor(self) -> str:
params = json.loads(self.body)
cluster_name = params.get("ClusterName")
new_replication_factor = params.get("NewReplicationFactor")

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/moto_api
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/moto_api
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract