moto/moto/dax/models.py

274 lines
9.8 KiB
Python
Raw Normal View History

2021-12-27 19:15:37 -01:00
"""DAXBackend class with methods for supported APIs."""
from moto.core import BaseBackend, BackendDict, BaseModel
from moto.core.utils import unix_time
from moto.moto_api import state_manager
from moto.moto_api._internal import mock_random as random
from moto.moto_api._internal.managed_state_model import ManagedState
2021-12-27 19:15:37 -01:00
from moto.utilities.tagging_service import TaggingService
from moto.utilities.paginator import paginate
2022-12-18 15:08:32 -01:00
from typing import Any, Dict, List, Iterable
2021-12-27 19:15:37 -01:00
from .exceptions import ClusterNotFoundFault
from .utils import PAGINATION_MODEL
class DaxParameterGroup(BaseModel):
    """The parameter group reported for every mocked DAX cluster.

    Only the default group is modelled; its status starts out as "in-sync".
    """

    def __init__(self) -> None:
        self.name, self.status = "default.dax1.0", "in-sync"

    def to_json(self) -> Dict[str, Any]:
        """Serialise as the ``ParameterGroup`` element of a cluster description."""
        payload: Dict[str, Any] = {"ParameterGroupName": self.name}
        payload["ParameterApplyStatus"] = self.status
        # No parameter changes are modelled, so no node ever needs a reboot
        payload["NodeIdsToReboot"] = []
        return payload
class DaxNode:
    """A single node belonging to a mocked DAX cluster."""

    def __init__(self, endpoint: "DaxEndpoint", name: str, index: int):
        # Node ids are the cluster name plus a letter suffix: name-a, name-b, etc
        self.node_id = f"{name}-{chr(ord('a')+index)}"
        address = f"{self.node_id}.{endpoint.cluster_hex}.nodes.dax-clusters.{endpoint.region}.amazonaws.com"
        self.node_endpoint = {"Address": address, "Port": endpoint.port}
        self.create_time = unix_time()
        # AWS spreads nodes across zones, i.e. three nodes will probably end up in us-east-1a, us-east-1b, us-east-1c
        # For simplicity, we'll 'deploy' everything to us-east-1a
        self.availability_zone = f"{endpoint.region}a"
        self.status, self.parameter_status = "available", "in-sync"

    def to_json(self) -> Dict[str, Any]:
        """Serialise as one entry of a cluster description's ``Nodes`` list."""
        return dict(
            [
                ("NodeId", self.node_id),
                ("Endpoint", self.node_endpoint),
                ("NodeCreateTime", self.create_time),
                ("AvailabilityZone", self.availability_zone),
                ("NodeStatus", self.status),
                ("ParameterGroupStatus", self.parameter_status),
            ]
        )
class DaxEndpoint:
    """Cluster discovery endpoint of a mocked DAX cluster (fixed port 8111)."""

    def __init__(self, name: str, cluster_hex: str, region: str):
        self.name, self.cluster_hex, self.region = name, cluster_hex, region
        self.port = 8111

    def to_json(self, full: bool = False) -> Dict[str, Any]:
        """Serialise the endpoint.

        :param full: when False only ``Port`` is returned; when True the
            ``Address`` and ``dax://`` ``URL`` are included as well.
        """
        result: Dict[str, Any] = {"Port": self.port}
        if not full:
            return result
        result["Address"] = f"{self.name}.{self.cluster_hex}.dax-clusters.{self.region}.amazonaws.com"
        result["URL"] = f"dax://{result['Address']}"
        return result
class DaxCluster(BaseModel, ManagedState):
    """In-memory model of a DAX cluster.

    The status moves creating -> available and deleting -> deleted through
    ManagedState. Endpoint address and node details are only reported once
    the cluster is "available".
    """

    def __init__(
        self,
        account_id: str,
        region: str,
        name: str,
        description: str,
        node_type: str,
        replication_factor: int,
        iam_role_arn: str,
        sse_specification: Dict[str, Any],
        encryption_type: str,
    ):
        # Configure ManagedState
        super().__init__(
            model_name="dax::cluster",
            transitions=[("creating", "available"), ("deleting", "deleted")],
        )
        # Set internal properties
        self.name = name
        self.description = description
        self.arn = f"arn:aws:dax:{region}:{account_id}:cache/{self.name}"
        self.node_type = node_type
        self.replication_factor = replication_factor
        # Random component shared by the discovery endpoint and all node addresses
        self.cluster_hex = random.get_random_hex(6)
        self.endpoint = DaxEndpoint(
            name=name, cluster_hex=self.cluster_hex, region=region
        )
        self.nodes = [self._create_new_node(i) for i in range(0, replication_factor)]
        self.preferred_maintenance_window = "thu:23:30-fri:00:30"
        self.subnet_group = "default"
        self.iam_role_arn = iam_role_arn
        self.parameter_group = DaxParameterGroup()
        self.security_groups = [
            {
                "SecurityGroupIdentifier": f"sg-{random.get_random_hex(10)}",
                "Status": "active",
            }
        ]
        self.sse_specification = sse_specification
        self.encryption_type = encryption_type

    def _create_new_node(self, idx: int) -> DaxNode:
        """Build (but do not attach) the node whose id suffix is derived from ``idx``."""
        return DaxNode(endpoint=self.endpoint, name=self.name, index=idx)

    def increase_replication_factor(self, new_replication_factor: int) -> None:
        """Add nodes until the cluster holds ``new_replication_factor`` nodes.

        New nodes take the lowest id suffixes not currently in use. Nodes may
        have been removed by id (see ``decrease_replication_factor``), so the
        surviving ids are not necessarily contiguous. Numbering from the old
        replication factor could therefore recreate an id that still exists.
        """
        used_ids = {node.node_id for node in self.nodes}
        idx = 0
        while len(self.nodes) < new_replication_factor:
            candidate = self._create_new_node(idx)
            if candidate.node_id not in used_ids:
                self.nodes.append(candidate)
                used_ids.add(candidate.node_id)
            idx += 1
        self.replication_factor = new_replication_factor

    def decrease_replication_factor(
        self, new_replication_factor: int, node_ids_to_remove: List[str]
    ) -> None:
        """Shrink the cluster.

        If ``node_ids_to_remove`` is non-empty, exactly those nodes are
        dropped. The remaining count is not checked against
        ``new_replication_factor``. Otherwise the trailing nodes beyond
        ``new_replication_factor`` are dropped.
        """
        if node_ids_to_remove:
            self.nodes = [n for n in self.nodes if n.node_id not in node_ids_to_remove]
        else:
            self.nodes = self.nodes[0:new_replication_factor]
        self.replication_factor = new_replication_factor

    def delete(self) -> None:
        """Start deletion; ManagedState later moves the status on to "deleted"."""
        self.status = "deleting"

    def is_deleted(self) -> bool:
        return self.status == "deleted"

    def to_json(self) -> Dict[str, Any]:
        """Serialise as one entry of a DescribeClusters response."""
        # Endpoint address and node list are only exposed once the cluster is up
        use_full_repr = self.status == "available"
        dct = {
            "ClusterName": self.name,
            "Description": self.description,
            "ClusterArn": self.arn,
            "TotalNodes": self.replication_factor,
            "ActiveNodes": 0,
            "NodeType": self.node_type,
            "Status": self.status,
            "ClusterDiscoveryEndpoint": self.endpoint.to_json(use_full_repr),
            "PreferredMaintenanceWindow": self.preferred_maintenance_window,
            "SubnetGroup": self.subnet_group,
            "IamRoleArn": self.iam_role_arn,
            "ParameterGroup": self.parameter_group.to_json(),
            "SSEDescription": {
                "Status": "ENABLED"
                if self.sse_specification.get("Enabled") is True
                else "DISABLED"
            },
            "ClusterEndpointEncryptionType": self.encryption_type,
            "SecurityGroups": self.security_groups,
        }
        if use_full_repr:
            dct["Nodes"] = [n.to_json() for n in self.nodes]
        return dct
class DAXBackend(BaseBackend):
    """In-memory DAX backend for one region/account pair: clusters plus their tags."""

    def __init__(self, region_name: str, account_id: str):
        super().__init__(region_name, account_id)
        self._clusters: Dict[str, DaxCluster] = dict()
        self._tagger = TaggingService()
        # Cluster states progress manually: they move when advance() is called
        # (see describe_clusters). NOTE(review): "times": 4 presumably means
        # four advances per transition; confirm against moto's state_manager.
        state_manager.register_default_transition(
            model_name="dax::cluster", transition={"progression": "manual", "times": 4}
        )

    @property
    def clusters(self) -> Dict[str, DaxCluster]:
        """Live clusters keyed by name; clusters whose status is "deleted" are purged.

        Every access rebuilds the underlying dict. Read it once per operation
        and reuse the result.
        """
        self._clusters = {
            name: cluster
            for name, cluster in self._clusters.items()
            if cluster.status != "deleted"
        }
        return self._clusters

    def _get_cluster(self, cluster_name: str) -> DaxCluster:
        """Return the live cluster called ``cluster_name``.

        :raises ClusterNotFoundFault: if no such (non-deleted) cluster exists.
        """
        cluster = self.clusters.get(cluster_name)
        if cluster is None:
            raise ClusterNotFoundFault()
        return cluster

    def create_cluster(
        self,
        cluster_name: str,
        node_type: str,
        description: str,
        replication_factor: int,
        iam_role_arn: str,
        tags: List[Dict[str, str]],
        sse_specification: Dict[str, Any],
        encryption_type: str,
    ) -> DaxCluster:
        """
        The following parameters are not yet processed:
        AvailabilityZones, SubnetGroupNames, SecurityGroups, PreferredMaintenanceWindow, NotificationTopicArn, ParameterGroupName
        """
        cluster = DaxCluster(
            account_id=self.account_id,
            region=self.region_name,
            name=cluster_name,
            description=description,
            node_type=node_type,
            replication_factor=replication_factor,
            iam_role_arn=iam_role_arn,
            sse_specification=sse_specification,
            encryption_type=encryption_type,
        )
        self.clusters[cluster_name] = cluster
        self._tagger.tag_resource(cluster.arn, tags)
        return cluster

    def delete_cluster(self, cluster_name: str) -> DaxCluster:
        """Mark the cluster as deleting and return it.

        :raises ClusterNotFoundFault: if the cluster does not exist.
        """
        cluster = self._get_cluster(cluster_name)
        cluster.delete()
        return cluster

    @paginate(PAGINATION_MODEL)
    def describe_clusters(self, cluster_names: Iterable[str]) -> List[DaxCluster]:  # type: ignore[misc]
        """Advance and return the requested clusters, or all clusters if none are named.

        Each described cluster's ManagedState is advanced once per mention.
        Clusters that reach "deleted" while advancing are left out of the
        result.

        :raises ClusterNotFoundFault: if an explicitly requested cluster does
            not exist (or was deleted by this call).
        """
        # Materialise once: the argument may be a one-shot iterable, and an
        # empty/None value means "all clusters".
        requested = list(cluster_names) if cluster_names else []
        live = self.clusters
        names = requested or list(live)
        for name in names:
            cluster = live.get(name)
            # Re-check before each advance so a cluster listed twice is not
            # advanced again after it has already become "deleted".
            if cluster is not None and cluster.status != "deleted":
                cluster.advance()
        # Clusters may have been deleted while advancing the states
        remaining = self.clusters
        # Only names the caller explicitly asked for are validated. A cluster
        # that vanished during an implicit "describe all" is simply omitted.
        # The old code validated a stale keys() view here, which made
        # describe-all fail.
        for name in requested:
            if name not in remaining:
                raise ClusterNotFoundFault(name)
        wanted = set(names)
        return [cluster for name, cluster in remaining.items() if name in wanted]

    def list_tags(self, resource_name: str) -> Dict[str, List[Dict[str, str]]]:
        """
        Pagination is not yet implemented
        """
        # resource_name can be the name, or the full ARN
        name = resource_name.split("/")[-1]
        cluster = self._get_cluster(name)
        return self._tagger.list_tags_for_resource(cluster.arn)

    def increase_replication_factor(
        self, cluster_name: str, new_replication_factor: int
    ) -> DaxCluster:
        """
        The AvailabilityZones-parameter is not yet implemented
        """
        cluster = self._get_cluster(cluster_name)
        cluster.increase_replication_factor(new_replication_factor)
        return cluster

    def decrease_replication_factor(
        self,
        cluster_name: str,
        new_replication_factor: int,
        node_ids_to_remove: List[str],
    ) -> DaxCluster:
        """
        The AvailabilityZones-parameter is not yet implemented
        """
        cluster = self._get_cluster(cluster_name)
        cluster.decrease_replication_factor(
            new_replication_factor, node_ids_to_remove
        )
        return cluster
# Module-level registry used by the DAX service: BackendDict wraps
# DAXBackend under the "dax" service name.
dax_backends = BackendDict(
    DAXBackend,
    "dax",
)