Techdebt: MyPy DAX (#5783)

This commit is contained in:
Bert Blommers 2022-12-18 15:08:32 -01:00 committed by GitHub
parent 7956812e66
commit fab91edd8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 56 deletions

View File

@ -1,13 +1,14 @@
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from typing import Optional
class InvalidParameterValueException(JsonRESTError): class InvalidParameterValueException(JsonRESTError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidParameterValueException", message) super().__init__("InvalidParameterValueException", message)
class ClusterNotFoundFault(JsonRESTError): class ClusterNotFoundFault(JsonRESTError):
def __init__(self, name=None): def __init__(self, name: Optional[str] = None):
# DescribeClusters and DeleteCluster use a different message for the same error # DescribeClusters and DeleteCluster use a different message for the same error
msg = f"Cluster {name} not found." if name else "Cluster not found." msg = f"Cluster {name} not found." if name else "Cluster not found."
super().__init__("ClusterNotFoundFault", msg) super().__init__("ClusterNotFoundFault", msg)

View File

@ -6,17 +6,18 @@ from moto.moto_api._internal import mock_random as random
from moto.moto_api._internal.managed_state_model import ManagedState from moto.moto_api._internal.managed_state_model import ManagedState
from moto.utilities.tagging_service import TaggingService from moto.utilities.tagging_service import TaggingService
from moto.utilities.paginator import paginate from moto.utilities.paginator import paginate
from typing import Any, Dict, List, Iterable
from .exceptions import ClusterNotFoundFault from .exceptions import ClusterNotFoundFault
from .utils import PAGINATION_MODEL from .utils import PAGINATION_MODEL
class DaxParameterGroup(BaseModel): class DaxParameterGroup(BaseModel):
def __init__(self): def __init__(self) -> None:
self.name = "default.dax1.0" self.name = "default.dax1.0"
self.status = "in-sync" self.status = "in-sync"
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"ParameterGroupName": self.name, "ParameterGroupName": self.name,
"ParameterApplyStatus": self.status, "ParameterApplyStatus": self.status,
@ -25,7 +26,7 @@ class DaxParameterGroup(BaseModel):
class DaxNode: class DaxNode:
def __init__(self, endpoint, name, index): def __init__(self, endpoint: "DaxEndpoint", name: str, index: int):
self.node_id = f"{name}-{chr(ord('a')+index)}" # name-a, name-b, etc self.node_id = f"{name}-{chr(ord('a')+index)}" # name-a, name-b, etc
self.node_endpoint = { self.node_endpoint = {
"Address": f"{self.node_id}.{endpoint.cluster_hex}.nodes.dax-clusters.{endpoint.region}.amazonaws.com", "Address": f"{self.node_id}.{endpoint.cluster_hex}.nodes.dax-clusters.{endpoint.region}.amazonaws.com",
@ -38,7 +39,7 @@ class DaxNode:
self.status = "available" self.status = "available"
self.parameter_status = "in-sync" self.parameter_status = "in-sync"
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"NodeId": self.node_id, "NodeId": self.node_id,
"Endpoint": self.node_endpoint, "Endpoint": self.node_endpoint,
@ -50,14 +51,14 @@ class DaxNode:
class DaxEndpoint: class DaxEndpoint:
def __init__(self, name, cluster_hex, region): def __init__(self, name: str, cluster_hex: str, region: str):
self.name = name self.name = name
self.cluster_hex = cluster_hex self.cluster_hex = cluster_hex
self.region = region self.region = region
self.port = 8111 self.port = 8111
def to_json(self, full=False): def to_json(self, full: bool = False) -> Dict[str, Any]:
dct = {"Port": self.port} dct: Dict[str, Any] = {"Port": self.port}
if full: if full:
dct[ dct[
"Address" "Address"
@ -69,15 +70,15 @@ class DaxEndpoint:
class DaxCluster(BaseModel, ManagedState): class DaxCluster(BaseModel, ManagedState):
def __init__( def __init__(
self, self,
account_id, account_id: str,
region, region: str,
name, name: str,
description, description: str,
node_type, node_type: str,
replication_factor, replication_factor: int,
iam_role_arn, iam_role_arn: str,
sse_specification, sse_specification: Dict[str, Any],
encryption_type, encryption_type: str,
): ):
# Configure ManagedState # Configure ManagedState
super().__init__( super().__init__(
@ -108,28 +109,30 @@ class DaxCluster(BaseModel, ManagedState):
self.sse_specification = sse_specification self.sse_specification = sse_specification
self.encryption_type = encryption_type self.encryption_type = encryption_type
def _create_new_node(self, idx): def _create_new_node(self, idx: int) -> DaxNode:
return DaxNode(endpoint=self.endpoint, name=self.name, index=idx) return DaxNode(endpoint=self.endpoint, name=self.name, index=idx)
def increase_replication_factor(self, new_replication_factor): def increase_replication_factor(self, new_replication_factor: int) -> None:
for idx in range(self.replication_factor, new_replication_factor): for idx in range(self.replication_factor, new_replication_factor):
self.nodes.append(self._create_new_node(idx)) self.nodes.append(self._create_new_node(idx))
self.replication_factor = new_replication_factor self.replication_factor = new_replication_factor
def decrease_replication_factor(self, new_replication_factor, node_ids_to_remove): def decrease_replication_factor(
self, new_replication_factor: int, node_ids_to_remove: List[str]
) -> None:
if node_ids_to_remove: if node_ids_to_remove:
self.nodes = [n for n in self.nodes if n.node_id not in node_ids_to_remove] self.nodes = [n for n in self.nodes if n.node_id not in node_ids_to_remove]
else: else:
self.nodes = self.nodes[0:new_replication_factor] self.nodes = self.nodes[0:new_replication_factor]
self.replication_factor = new_replication_factor self.replication_factor = new_replication_factor
def delete(self): def delete(self) -> None:
self.status = "deleting" self.status = "deleting"
def is_deleted(self): def is_deleted(self) -> bool:
return self.status == "deleted" return self.status == "deleted"
def to_json(self): def to_json(self) -> Dict[str, Any]:
use_full_repr = self.status == "available" use_full_repr = self.status == "available"
dct = { dct = {
"ClusterName": self.name, "ClusterName": self.name,
@ -158,9 +161,9 @@ class DaxCluster(BaseModel, ManagedState):
class DAXBackend(BaseBackend): class DAXBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self._clusters = dict() self._clusters: Dict[str, DaxCluster] = dict()
self._tagger = TaggingService() self._tagger = TaggingService()
state_manager.register_default_transition( state_manager.register_default_transition(
@ -168,7 +171,7 @@ class DAXBackend(BaseBackend):
) )
@property @property
def clusters(self): def clusters(self) -> Dict[str, DaxCluster]:
self._clusters = { self._clusters = {
name: cluster name: cluster
for name, cluster in self._clusters.items() for name, cluster in self._clusters.items()
@ -178,15 +181,15 @@ class DAXBackend(BaseBackend):
def create_cluster( def create_cluster(
self, self,
cluster_name, cluster_name: str,
node_type, node_type: str,
description, description: str,
replication_factor, replication_factor: int,
iam_role_arn, iam_role_arn: str,
tags, tags: List[Dict[str, str]],
sse_specification, sse_specification: Dict[str, Any],
encryption_type, encryption_type: str,
): ) -> DaxCluster:
""" """
The following parameters are not yet processed: The following parameters are not yet processed:
AvailabilityZones, SubnetGroupNames, SecurityGroups, PreferredMaintenanceWindow, NotificationTopicArn, ParameterGroupName AvailabilityZones, SubnetGroupNames, SecurityGroups, PreferredMaintenanceWindow, NotificationTopicArn, ParameterGroupName
@ -206,14 +209,14 @@ class DAXBackend(BaseBackend):
self._tagger.tag_resource(cluster.arn, tags) self._tagger.tag_resource(cluster.arn, tags)
return cluster return cluster
def delete_cluster(self, cluster_name): def delete_cluster(self, cluster_name: str) -> DaxCluster:
if cluster_name not in self.clusters: if cluster_name not in self.clusters:
raise ClusterNotFoundFault() raise ClusterNotFoundFault()
self.clusters[cluster_name].delete() self.clusters[cluster_name].delete()
return self.clusters[cluster_name] return self.clusters[cluster_name]
@paginate(PAGINATION_MODEL) @paginate(PAGINATION_MODEL)
def describe_clusters(self, cluster_names): def describe_clusters(self, cluster_names: Iterable[str]) -> List[DaxCluster]: # type: ignore[misc]
clusters = self.clusters clusters = self.clusters
if not cluster_names: if not cluster_names:
cluster_names = clusters.keys() cluster_names = clusters.keys()
@ -229,7 +232,7 @@ class DAXBackend(BaseBackend):
raise ClusterNotFoundFault(name) raise ClusterNotFoundFault(name)
return [cluster for name, cluster in clusters.items() if name in cluster_names] return [cluster for name, cluster in clusters.items() if name in cluster_names]
def list_tags(self, resource_name): def list_tags(self, resource_name: str) -> Dict[str, List[Dict[str, str]]]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
@ -239,7 +242,9 @@ class DAXBackend(BaseBackend):
raise ClusterNotFoundFault() raise ClusterNotFoundFault()
return self._tagger.list_tags_for_resource(self.clusters[name].arn) return self._tagger.list_tags_for_resource(self.clusters[name].arn)
def increase_replication_factor(self, cluster_name, new_replication_factor): def increase_replication_factor(
self, cluster_name: str, new_replication_factor: int
) -> DaxCluster:
""" """
The AvailabilityZones-parameter is not yet implemented The AvailabilityZones-parameter is not yet implemented
""" """
@ -250,10 +255,10 @@ class DAXBackend(BaseBackend):
def decrease_replication_factor( def decrease_replication_factor(
self, self,
cluster_name, cluster_name: str,
new_replication_factor, new_replication_factor: int,
node_ids_to_remove, node_ids_to_remove: List[str],
): ) -> DaxCluster:
""" """
The AvailabilityZones-parameter is not yet implemented The AvailabilityZones-parameter is not yet implemented
""" """

View File

@ -3,18 +3,18 @@ import re
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .exceptions import InvalidParameterValueException from .exceptions import InvalidParameterValueException
from .models import dax_backends from .models import dax_backends, DAXBackend
class DAXResponse(BaseResponse): class DAXResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="dax") super().__init__(service_name="dax")
@property @property
def dax_backend(self): def dax_backend(self) -> DAXBackend:
return dax_backends[self.current_account][self.region] return dax_backends[self.current_account][self.region]
def create_cluster(self): def create_cluster(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
cluster_name = params.get("ClusterName") cluster_name = params.get("ClusterName")
node_type = params.get("NodeType") node_type = params.get("NodeType")
@ -40,12 +40,12 @@ class DAXResponse(BaseResponse):
) )
return json.dumps(dict(Cluster=cluster.to_json())) return json.dumps(dict(Cluster=cluster.to_json()))
def delete_cluster(self): def delete_cluster(self) -> str:
cluster_name = json.loads(self.body).get("ClusterName") cluster_name = json.loads(self.body).get("ClusterName")
cluster = self.dax_backend.delete_cluster(cluster_name) cluster = self.dax_backend.delete_cluster(cluster_name)
return json.dumps(dict(Cluster=cluster.to_json())) return json.dumps(dict(Cluster=cluster.to_json()))
def describe_clusters(self): def describe_clusters(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
cluster_names = params.get("ClusterNames", []) cluster_names = params.get("ClusterNames", [])
max_results = params.get("MaxResults") max_results = params.get("MaxResults")
@ -61,7 +61,7 @@ class DAXResponse(BaseResponse):
{"Clusters": [c.to_json() for c in clusters], "NextToken": next_token} {"Clusters": [c.to_json() for c in clusters], "NextToken": next_token}
) )
def _validate_arn(self, arn): def _validate_arn(self, arn: str) -> None:
if not arn.startswith("arn:"): if not arn.startswith("arn:"):
raise InvalidParameterValueException(f"ARNs must start with 'arn:': {arn}") raise InvalidParameterValueException(f"ARNs must start with 'arn:': {arn}")
sections = arn.split(":") sections = arn.split(":")
@ -80,20 +80,20 @@ class DAXResponse(BaseResponse):
f"Fifth colon (namespace/relative-id delimiter) not found: {arn}" f"Fifth colon (namespace/relative-id delimiter) not found: {arn}"
) )
def _validate_name(self, name): def _validate_name(self, name: str) -> None:
msg = "Cluster ID specified is not a valid identifier. Identifiers must begin with a letter; must contain only ASCII letters, digits, and hyphens; and must not end with a hyphen or contain two consecutive hyphens." msg = "Cluster ID specified is not a valid identifier. Identifiers must begin with a letter; must contain only ASCII letters, digits, and hyphens; and must not end with a hyphen or contain two consecutive hyphens."
if not re.match("^[a-z][a-z0-9-]+[a-z0-9]$", name): if not re.match("^[a-z][a-z0-9-]+[a-z0-9]$", name):
raise InvalidParameterValueException(msg) raise InvalidParameterValueException(msg)
if "--" in name: if "--" in name:
raise InvalidParameterValueException(msg) raise InvalidParameterValueException(msg)
def list_tags(self): def list_tags(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
resource_name = params.get("ResourceName") resource_name = params.get("ResourceName")
tags = self.dax_backend.list_tags(resource_name=resource_name) tags = self.dax_backend.list_tags(resource_name=resource_name)
return json.dumps(tags) return json.dumps(tags)
def increase_replication_factor(self): def increase_replication_factor(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
cluster_name = params.get("ClusterName") cluster_name = params.get("ClusterName")
new_replication_factor = params.get("NewReplicationFactor") new_replication_factor = params.get("NewReplicationFactor")
@ -102,7 +102,7 @@ class DAXResponse(BaseResponse):
) )
return json.dumps({"Cluster": cluster.to_json()}) return json.dumps({"Cluster": cluster.to_json()})
def decrease_replication_factor(self): def decrease_replication_factor(self) -> str:
params = json.loads(self.body) params = json.loads(self.body)
cluster_name = params.get("ClusterName") cluster_name = params.get("ClusterName")
new_replication_factor = params.get("NewReplicationFactor") new_replication_factor = params.get("NewReplicationFactor")

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/moto_api files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/moto_api
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract