diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md
index 6fabaae1b..6476033a2 100644
--- a/IMPLEMENTATION_COVERAGE.md
+++ b/IMPLEMENTATION_COVERAGE.md
@@ -987,6 +987,33 @@
- [ ] update_task_execution
+## dax
+
+28% implemented
+
+- [X] create_cluster
+- [ ] create_parameter_group
+- [ ] create_subnet_group
+- [X] decrease_replication_factor
+- [X] delete_cluster
+- [ ] delete_parameter_group
+- [ ] delete_subnet_group
+- [X] describe_clusters
+- [ ] describe_default_parameters
+- [ ] describe_events
+- [ ] describe_parameter_groups
+- [ ] describe_parameters
+- [ ] describe_subnet_groups
+- [X] increase_replication_factor
+- [X] list_tags
+- [ ] reboot_node
+- [ ] tag_resource
+- [ ] untag_resource
+- [ ] update_cluster
+- [ ] update_parameter_group
+- [ ] update_subnet_group
+
+
## dms
9% implemented
@@ -5088,7 +5115,6 @@
- customer-profiles
- databrew
- dataexchange
-- dax
- detective
- devicefarm
- devops-guru
diff --git a/docs/docs/services/dax.rst b/docs/docs/services/dax.rst
new file mode 100644
index 000000000..5e1f37cd2
--- /dev/null
+++ b/docs/docs/services/dax.rst
@@ -0,0 +1,62 @@
+.. _implementedservice_dax:
+
+.. |start-h3| raw:: html
+
+
+
+.. |end-h3| raw:: html
+
+
+
+===
+dax
+===
+
+|start-h3| Example usage |end-h3|
+
+.. sourcecode:: python
+
+ @mock_dax
+ def test_dax_behaviour:
+ boto3.client("dax")
+ ...
+
+
+
+|start-h3| Implemented features for this service |end-h3|
+
+- [X] create_cluster
+
+ The following parameters are not yet processed:
+ AvailabilityZones, SubnetGroupNames, SecurityGroups, PreferredMaintenanceWindow, NotificationTopicArn, ParameterGroupName, ClusterEndpointEncryptionType
+
+
+- [ ] create_parameter_group
+- [ ] create_subnet_group
+- [X] decrease_replication_factor
+
+ The AvailabilityZones-parameter is not yet implemented
+
+
+- [X] delete_cluster
+- [ ] delete_parameter_group
+- [ ] delete_subnet_group
+- [X] describe_clusters
+- [ ] describe_default_parameters
+- [ ] describe_events
+- [ ] describe_parameter_groups
+- [ ] describe_parameters
+- [ ] describe_subnet_groups
+- [X] increase_replication_factor
+- [X] list_tags
+
+ Pagination is not yet implemented
+
+
+- [ ] reboot_node
+- [ ] tag_resource
+- [ ] untag_resource
+- [ ] update_cluster
+- [ ] update_parameter_group
+- [ ] update_subnet_group
+
diff --git a/moto/__init__.py b/moto/__init__.py
index 0404c663f..d09b8949c 100644
--- a/moto/__init__.py
+++ b/moto/__init__.py
@@ -63,6 +63,7 @@ mock_datapipeline_deprecated = lazy_load(
".datapipeline", "mock_datapipeline_deprecated"
)
mock_datasync = lazy_load(".datasync", "mock_datasync")
+mock_dax = lazy_load(".dax", "mock_dax")
mock_dms = lazy_load(".dms", "mock_dms")
mock_ds = lazy_load(".ds", "mock_ds", boto3_name="ds")
mock_dynamodb = lazy_load(".dynamodb", "mock_dynamodb", warn_repurpose=True)
diff --git a/moto/backend_index.py b/moto/backend_index.py
index 71b294174..a37aee7cd 100644
--- a/moto/backend_index.py
+++ b/moto/backend_index.py
@@ -26,6 +26,7 @@ backend_url_patterns = [
("config", re.compile("https?://config\\.(.+)\\.amazonaws\\.com")),
("datapipeline", re.compile("https?://datapipeline\\.(.+)\\.amazonaws\\.com")),
("datasync", re.compile("https?://(.*\\.)?(datasync)\\.(.+)\\.amazonaws.com")),
+ ("dax", re.compile("https?://dax\\.(.+)\\.amazonaws\\.com")),
("dms", re.compile("https?://dms\\.(.+)\\.amazonaws\\.com")),
("ds", re.compile("https?://ds\\.(.+)\\.amazonaws\\.com")),
("dynamodb", re.compile("https?://dynamodb\\.(.+)\\.amazonaws\\.com")),
diff --git a/moto/dax/__init__.py b/moto/dax/__init__.py
new file mode 100644
index 000000000..21ac86e6e
--- /dev/null
+++ b/moto/dax/__init__.py
@@ -0,0 +1,5 @@
+"""dax module initialization; sets value for base decorator."""
+from .models import dax_backends
+from ..core.models import base_decorator
+
+mock_dax = base_decorator(dax_backends)
diff --git a/moto/dax/exceptions.py b/moto/dax/exceptions.py
new file mode 100644
index 000000000..ee2decbda
--- /dev/null
+++ b/moto/dax/exceptions.py
@@ -0,0 +1,13 @@
+from moto.core.exceptions import JsonRESTError
+
+
+class InvalidParameterValueException(JsonRESTError):
+ def __init__(self, message):
+ super().__init__("InvalidParameterValueException", message)
+
+
+class ClusterNotFoundFault(JsonRESTError):
+ def __init__(self, name=None):
+ # DescribeClusters and DeleteCluster use a different message for the same error
+ msg = f"Cluster {name} not found." if name else f"Cluster not found."
+ super().__init__("ClusterNotFoundFault", msg)
diff --git a/moto/dax/models.py b/moto/dax/models.py
new file mode 100644
index 000000000..9487ccfac
--- /dev/null
+++ b/moto/dax/models.py
@@ -0,0 +1,277 @@
+"""DAXBackend class with methods for supported APIs."""
+from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
+from moto.core.utils import BackendDict, get_random_hex, unix_time
+from moto.utilities.tagging_service import TaggingService
+from moto.utilities.paginator import paginate
+
+from .exceptions import ClusterNotFoundFault
+from .utils import PAGINATION_MODEL
+
+
+class DaxParameterGroup(BaseModel):
+ def __init__(self):
+ self.name = "default.dax1.0"
+ self.status = "in-sync"
+
+ def to_json(self):
+ return {
+ "ParameterGroupName": self.name,
+ "ParameterApplyStatus": self.status,
+ "NodeIdsToReboot": [],
+ }
+
+
+class DaxNode:
+ def __init__(self, endpoint, name, index):
+ self.node_id = f"{name}-{chr(ord('a')+index)}" # name-a, name-b, etc
+ self.node_endpoint = {
+ "Address": f"{self.node_id}.{endpoint.cluster_hex}.nodes.dax-clusters.{endpoint.region}.amazonaws.com",
+ "Port": endpoint.port,
+ }
+ self.create_time = unix_time()
+ # AWS spreads nodes across zones, i.e. three nodes will probably end up in us-east-1a, us-east-1b, us-east-1c
+ # For simplicity, we'll 'deploy' everything to us-east-1a
+ self.availability_zone = f"{endpoint.region}a"
+ self.status = "available"
+ self.parameter_status = "in-sync"
+
+ def to_json(self):
+ return {
+ "NodeId": self.node_id,
+ "Endpoint": self.node_endpoint,
+ "NodeCreateTime": self.create_time,
+ "AvailabilityZone": self.availability_zone,
+ "NodeStatus": self.status,
+ "ParameterGroupStatus": self.parameter_status,
+ }
+
+
+class DaxEndpoint:
+ def __init__(self, name, cluster_hex, region):
+ self.name = name
+ self.cluster_hex = cluster_hex
+ self.region = region
+ self.port = 8111
+
+ def to_json(self, full=False):
+ dct = {"Port": self.port}
+ if full:
+ dct[
+ "Address"
+ ] = f"{self.name}.{self.cluster_hex}.dax-clusters.{self.region}.amazonaws.com"
+ dct["URL"] = f"dax://{dct['Address']}"
+ return dct
+
+
+class DaxCluster(BaseModel):
+ def __init__(
+ self,
+ region,
+ name,
+ description,
+ node_type,
+ replication_factor,
+ iam_role_arn,
+ sse_specification,
+ ):
+ self.name = name
+ self.description = description
+ self.arn = f"arn:aws:dax:{region}:{ACCOUNT_ID}:cache/{self.name}"
+ self.node_type = node_type
+ self.replication_factor = replication_factor
+ self.status = "creating"
+ self.cluster_hex = get_random_hex(6)
+ self.endpoint = DaxEndpoint(
+ name=name, cluster_hex=self.cluster_hex, region=region
+ )
+ self.nodes = [self._create_new_node(i) for i in range(0, replication_factor)]
+ self.preferred_maintenance_window = "thu:23:30-fri:00:30"
+ self.subnet_group = "default"
+ self.iam_role_arn = iam_role_arn
+ self.parameter_group = DaxParameterGroup()
+ self.security_groups = [
+ {"SecurityGroupIdentifier": f"sg-{get_random_hex(10)}", "Status": "active"}
+ ]
+ self.sse_specification = sse_specification
+ print(sse_specification)
+
+ # Internal counter to keep track of when this cluster is available/deleted
+ # Used in conjunction with `advance()`
+ self._tick = 0
+
+ def _create_new_node(self, idx):
+ return DaxNode(endpoint=self.endpoint, name=self.name, index=idx)
+
+ def increase_replication_factor(self, new_replication_factor):
+ for idx in range(self.replication_factor, new_replication_factor):
+ self.nodes.append(self._create_new_node(idx))
+ self.replication_factor = new_replication_factor
+
+ def decrease_replication_factor(self, new_replication_factor, node_ids_to_remove):
+ if node_ids_to_remove:
+ self.nodes = [n for n in self.nodes if n.node_id not in node_ids_to_remove]
+ else:
+ self.nodes = self.nodes[0:new_replication_factor]
+ self.replication_factor = new_replication_factor
+
+ def delete(self):
+ self.status = "deleting"
+
+ def is_deleted(self):
+ return self.status == "deleted"
+
+ def advance(self):
+ if self.status == "creating":
+ if self._tick < 3:
+ self._tick += 1
+ else:
+ self.status = "available"
+ self._tick = 0
+ if self.status == "deleting":
+ if self._tick < 3:
+ self._tick += 1
+ else:
+ self.status = "deleted"
+
+ def to_json(self):
+ use_full_repr = self.status == "available"
+ dct = {
+ "ClusterName": self.name,
+ "Description": self.description,
+ "ClusterArn": self.arn,
+ "TotalNodes": self.replication_factor,
+ "ActiveNodes": 0,
+ "NodeType": self.node_type,
+ "Status": self.status,
+ "ClusterDiscoveryEndpoint": self.endpoint.to_json(use_full_repr),
+ "PreferredMaintenanceWindow": self.preferred_maintenance_window,
+ "SubnetGroup": self.subnet_group,
+ "IamRoleArn": self.iam_role_arn,
+ "ParameterGroup": self.parameter_group.to_json(),
+ "SSEDescription": {
+ "Status": "ENABLED"
+ if self.sse_specification.get("Enabled") is True
+ else "DISABLED"
+ },
+ "ClusterEndpointEncryptionType": "NONE",
+ "SecurityGroups": self.security_groups,
+ }
+ if use_full_repr:
+ dct["Nodes"] = [n.to_json() for n in self.nodes]
+ return dct
+
+
+class DAXBackend(BaseBackend):
+ def __init__(self, region_name):
+ self.region_name = region_name
+ self._clusters = dict()
+ self._tagger = TaggingService()
+
+ @property
+ def clusters(self):
+ self._clusters = {
+ name: cluster
+ for name, cluster in self._clusters.items()
+ if cluster.status != "deleted"
+ }
+ return self._clusters
+
+ def reset(self):
+ region_name = self.region_name
+ self.__dict__ = {}
+ self.__init__(region_name)
+
+ def create_cluster(
+ self,
+ cluster_name,
+ node_type,
+ description,
+ replication_factor,
+ availability_zones,
+ subnet_group_name,
+ security_group_ids,
+ preferred_maintenance_window,
+ notification_topic_arn,
+ iam_role_arn,
+ parameter_group_name,
+ tags,
+ sse_specification,
+ cluster_endpoint_encryption_type,
+ ):
+ """
+ The following parameters are not yet processed:
+ AvailabilityZones, SubnetGroupNames, SecurityGroups, PreferredMaintenanceWindow, NotificationTopicArn, ParameterGroupName, ClusterEndpointEncryptionType
+ """
+ cluster = DaxCluster(
+ region=self.region_name,
+ name=cluster_name,
+ description=description,
+ node_type=node_type,
+ replication_factor=replication_factor,
+ iam_role_arn=iam_role_arn,
+ sse_specification=sse_specification,
+ )
+ self.clusters[cluster_name] = cluster
+ self._tagger.tag_resource(cluster.arn, tags)
+ return cluster
+
+ def delete_cluster(self, cluster_name):
+ if cluster_name not in self.clusters:
+ raise ClusterNotFoundFault()
+ self.clusters[cluster_name].delete()
+ return self.clusters[cluster_name]
+
+ @paginate(PAGINATION_MODEL)
+ def describe_clusters(self, cluster_names):
+ clusters = self.clusters
+ if not cluster_names:
+ cluster_names = clusters.keys()
+
+ for name in cluster_names:
+ if name in self.clusters:
+ self.clusters[name].advance()
+
+ # Clusters may have been deleted while advancing the states
+ clusters = self.clusters
+ for name in cluster_names:
+ if name not in self.clusters:
+ raise ClusterNotFoundFault(name)
+ return [cluster for name, cluster in clusters.items() if name in cluster_names]
+
+ def list_tags(self, resource_name):
+ """
+ Pagination is not yet implemented
+ """
+ # resource_name can be the name, or the full ARN
+ name = resource_name.split("/")[-1]
+ if name not in self.clusters:
+ raise ClusterNotFoundFault()
+ return self._tagger.list_tags_for_resource(self.clusters[name].arn)
+
+ def increase_replication_factor(
+ self, cluster_name, new_replication_factor, availability_zones
+ ):
+ if cluster_name not in self.clusters:
+ raise ClusterNotFoundFault()
+ self.clusters[cluster_name].increase_replication_factor(new_replication_factor)
+ return self.clusters[cluster_name]
+
+ def decrease_replication_factor(
+ self,
+ cluster_name,
+ new_replication_factor,
+ availability_zones,
+ node_ids_to_remove,
+ ):
+ """
+ The AvailabilityZones-parameter is not yet implemented
+ """
+ if cluster_name not in self.clusters:
+ raise ClusterNotFoundFault()
+ self.clusters[cluster_name].decrease_replication_factor(
+ new_replication_factor, node_ids_to_remove
+ )
+ return self.clusters[cluster_name]
+
+
+dax_backends = BackendDict(DAXBackend, "dax")
diff --git a/moto/dax/responses.py b/moto/dax/responses.py
new file mode 100644
index 000000000..5004df67d
--- /dev/null
+++ b/moto/dax/responses.py
@@ -0,0 +1,129 @@
+import json
+import re
+
+from moto.core.responses import BaseResponse
+from .exceptions import InvalidParameterValueException
+from .models import dax_backends
+
+
+class DAXResponse(BaseResponse):
+ @property
+ def dax_backend(self):
+ return dax_backends[self.region]
+
+ def create_cluster(self):
+ params = json.loads(self.body)
+ cluster_name = params.get("ClusterName")
+ node_type = params.get("NodeType")
+ description = params.get("Description")
+ replication_factor = params.get("ReplicationFactor")
+ availability_zones = params.get("AvailabilityZones")
+ subnet_group_name = params.get("SubnetGroupName")
+ security_group_ids = params.get("SecurityGroupIds")
+ preferred_maintenance_window = params.get("PreferredMaintenanceWindow")
+ notification_topic_arn = params.get("NotificationTopicArn")
+ iam_role_arn = params.get("IamRoleArn")
+ parameter_group_name = params.get("ParameterGroupName")
+ tags = params.get("Tags", [])
+ sse_specification = params.get("SSESpecification", {})
+ cluster_endpoint_encryption_type = params.get("ClusterEndpointEncryptionType")
+
+ self._validate_arn(iam_role_arn)
+ self._validate_name(cluster_name)
+
+ cluster = self.dax_backend.create_cluster(
+ cluster_name=cluster_name,
+ node_type=node_type,
+ description=description,
+ replication_factor=replication_factor,
+ availability_zones=availability_zones,
+ subnet_group_name=subnet_group_name,
+ security_group_ids=security_group_ids,
+ preferred_maintenance_window=preferred_maintenance_window,
+ notification_topic_arn=notification_topic_arn,
+ iam_role_arn=iam_role_arn,
+ parameter_group_name=parameter_group_name,
+ tags=tags,
+ sse_specification=sse_specification,
+ cluster_endpoint_encryption_type=cluster_endpoint_encryption_type,
+ )
+ return json.dumps(dict(Cluster=cluster.to_json()))
+
+ def delete_cluster(self):
+ cluster_name = json.loads(self.body).get("ClusterName")
+ cluster = self.dax_backend.delete_cluster(cluster_name=cluster_name,)
+ return json.dumps(dict(Cluster=cluster.to_json()))
+
+ def describe_clusters(self):
+ params = json.loads(self.body)
+ cluster_names = params.get("ClusterNames", [])
+ max_results = params.get("MaxResults")
+ next_token = params.get("NextToken")
+
+ for name in cluster_names:
+ self._validate_name(name)
+
+ clusters, next_token = self.dax_backend.describe_clusters(
+ cluster_names=cluster_names, max_results=max_results, next_token=next_token
+ )
+ return json.dumps(
+ {"Clusters": [c.to_json() for c in clusters], "NextToken": next_token}
+ )
+
+ def _validate_arn(self, arn):
+ if not arn.startswith("arn:"):
+ raise InvalidParameterValueException(f"ARNs must start with 'arn:': {arn}")
+ sections = arn.split(":")
+ if len(sections) < 3:
+ raise InvalidParameterValueException(
+ f"Second colon partition not found: {arn}"
+ )
+ if len(sections) < 4:
+ raise InvalidParameterValueException(f"Third colon vendor not found: {arn}")
+ if len(sections) < 5:
+ raise InvalidParameterValueException(
+ f"Fourth colon (region/namespace delimiter) not found: {arn}"
+ )
+ if len(sections) < 6:
+ raise InvalidParameterValueException(
+ f"Fifth colon (namespace/relative-id delimiter) not found: {arn}"
+ )
+
+ def _validate_name(self, name):
+ msg = "Cluster ID specified is not a valid identifier. Identifiers must begin with a letter; must contain only ASCII letters, digits, and hyphens; and must not end with a hyphen or contain two consecutive hyphens."
+ if not re.match("^[a-z][a-z0-9-]+[a-z0-9]$", name):
+ raise InvalidParameterValueException(msg)
+ if "--" in name:
+ raise InvalidParameterValueException(msg)
+
+ def list_tags(self):
+ params = json.loads(self.body)
+ resource_name = params.get("ResourceName")
+ tags = self.dax_backend.list_tags(resource_name=resource_name)
+ return json.dumps(tags)
+
+ def increase_replication_factor(self):
+ params = json.loads(self.body)
+ cluster_name = params.get("ClusterName")
+ new_replication_factor = params.get("NewReplicationFactor")
+ availability_zones = params.get("AvailabilityZones")
+ cluster = self.dax_backend.increase_replication_factor(
+ cluster_name=cluster_name,
+ new_replication_factor=new_replication_factor,
+ availability_zones=availability_zones,
+ )
+ return json.dumps({"Cluster": cluster.to_json()})
+
+ def decrease_replication_factor(self):
+ params = json.loads(self.body)
+ cluster_name = params.get("ClusterName")
+ new_replication_factor = params.get("NewReplicationFactor")
+ availability_zones = params.get("AvailabilityZones")
+ node_ids_to_remove = params.get("NodeIdsToRemove")
+ cluster = self.dax_backend.decrease_replication_factor(
+ cluster_name=cluster_name,
+ new_replication_factor=new_replication_factor,
+ availability_zones=availability_zones,
+ node_ids_to_remove=node_ids_to_remove,
+ )
+ return json.dumps({"Cluster": cluster.to_json()})
diff --git a/moto/dax/urls.py b/moto/dax/urls.py
new file mode 100644
index 000000000..ed4fe11bc
--- /dev/null
+++ b/moto/dax/urls.py
@@ -0,0 +1,11 @@
+"""dax base URL and path."""
+from .responses import DAXResponse
+
+url_bases = [
+ r"https?://dax\.(.+)\.amazonaws\.com",
+]
+
+
+url_paths = {
+ "{0}/$": DAXResponse.dispatch,
+}
diff --git a/moto/dax/utils.py b/moto/dax/utils.py
new file mode 100644
index 000000000..f4a27d230
--- /dev/null
+++ b/moto/dax/utils.py
@@ -0,0 +1,8 @@
+PAGINATION_MODEL = {
+ "describe_clusters": {
+ "input_token": "next_token",
+ "limit_key": "max_results",
+ "limit_default": 100,
+ "unique_attribute": "arn",
+ },
+}
diff --git a/tests/terraform-tests.success.txt b/tests/terraform-tests.success.txt
index c9139b413..bf3b3b636 100644
--- a/tests/terraform-tests.success.txt
+++ b/tests/terraform-tests.success.txt
@@ -14,6 +14,7 @@ TestAccAWSCloudWatchEventRule
TestAccAWSCloudWatchEventTarget_ssmDocument
TestAccAWSCloudwatchLogGroupDataSource
TestAccAWSCloudWatchMetricAlarm
+TestAccAWSDAX
TestAccAWSDataSourceCloudwatch
TestAccAWSDataSourceElasticBeanstalkHostedZone
TestAccAWSDataSourceIAMGroup
diff --git a/tests/test_dax/__init__.py b/tests/test_dax/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_dax/test_dax.py b/tests/test_dax/test_dax.py
new file mode 100644
index 000000000..0203837cc
--- /dev/null
+++ b/tests/test_dax/test_dax.py
@@ -0,0 +1,537 @@
+"""Unit tests for dax-supported APIs."""
+import boto3
+import pytest
+import sure # noqa # pylint: disable=unused-import
+
+from botocore.exceptions import ClientError
+from moto import mock_dax
+from moto.core import ACCOUNT_ID
+
+# See our Development Tips on writing tests for hints on how to write good tests:
+# http://docs.getmoto.org/en/latest/docs/contributing/development_tips/tests.html
+
+
+@mock_dax
+def test_create_cluster_minimal():
+ client = boto3.client("dax", region_name="us-east-2")
+ iam_role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX"
+ created_cluster = client.create_cluster(
+ ClusterName="daxcluster",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn=iam_role_arn,
+ )["Cluster"]
+
+ described_cluster = client.describe_clusters(ClusterNames=["daxcluster"])[
+ "Clusters"
+ ][0]
+
+ for cluster in [created_cluster, described_cluster]:
+ cluster["ClusterName"].should.equal("daxcluster")
+ cluster["ClusterArn"].should.equal(
+ f"arn:aws:dax:us-east-2:{ACCOUNT_ID}:cache/daxcluster"
+ )
+ cluster["TotalNodes"].should.equal(3)
+ cluster["ActiveNodes"].should.equal(0)
+ cluster["NodeType"].should.equal("dax.t3.small")
+ cluster["Status"].should.equal("creating")
+ cluster["ClusterDiscoveryEndpoint"].should.equal({"Port": 8111})
+ cluster["PreferredMaintenanceWindow"].should.equal("thu:23:30-fri:00:30")
+ cluster["SubnetGroup"].should.equal("default")
+ cluster["SecurityGroups"].should.have.length_of(1)
+ cluster["IamRoleArn"].should.equal(iam_role_arn)
+ cluster.should.have.key("ParameterGroup")
+ cluster["ParameterGroup"].should.have.key("ParameterGroupName").equals(
+ "default.dax1.0"
+ )
+ cluster["SSEDescription"].should.equal({"Status": "DISABLED"})
+ cluster.should.have.key("ClusterEndpointEncryptionType").equals("NONE")
+
+
+@mock_dax
+def test_create_cluster_description():
+ client = boto3.client("dax", region_name="us-east-2")
+ iam_role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX"
+ created_cluster = client.create_cluster(
+ ClusterName="daxcluster",
+ Description="my cluster",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn=iam_role_arn,
+ )["Cluster"]
+
+ described_cluster = client.describe_clusters(ClusterNames=["daxcluster"])[
+ "Clusters"
+ ][0]
+
+ for cluster in [created_cluster, described_cluster]:
+ cluster["ClusterName"].should.equal("daxcluster")
+ cluster["Description"].should.equal("my cluster")
+
+
+@mock_dax
+def test_create_cluster_with_sse_enabled():
+ client = boto3.client("dax", region_name="us-east-2")
+ iam_role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX"
+ created_cluster = client.create_cluster(
+ ClusterName="daxcluster",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn=iam_role_arn,
+ SSESpecification={"Enabled": True},
+ )["Cluster"]
+
+ described_cluster = client.describe_clusters(ClusterNames=["daxcluster"])[
+ "Clusters"
+ ][0]
+
+ for cluster in [created_cluster, described_cluster]:
+ cluster["ClusterName"].should.equal("daxcluster")
+ cluster["SSEDescription"].should.equal({"Status": "ENABLED"})
+
+
+@mock_dax
+def test_create_cluster_invalid_arn():
+ client = boto3.client("dax", region_name="eu-west-1")
+ with pytest.raises(ClientError) as exc:
+ client.create_cluster(
+ ClusterName="1invalid",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn="n/a",
+ )
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("InvalidParameterValueException")
+ err["Message"].should.equal("ARNs must start with 'arn:': n/a")
+
+
+@mock_dax
+def test_create_cluster_invalid_arn_no_partition():
+ client = boto3.client("dax", region_name="eu-west-1")
+ with pytest.raises(ClientError) as exc:
+ client.create_cluster(
+ ClusterName="1invalid",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn="arn:sth",
+ )
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("InvalidParameterValueException")
+ err["Message"].should.equal("Second colon partition not found: arn:sth")
+
+
+@mock_dax
+def test_create_cluster_invalid_arn_no_vendor():
+ client = boto3.client("dax", region_name="eu-west-1")
+ with pytest.raises(ClientError) as exc:
+ client.create_cluster(
+ ClusterName="1invalid",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn="arn:sth:aws",
+ )
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("InvalidParameterValueException")
+ err["Message"].should.equal("Third colon vendor not found: arn:sth:aws")
+
+
+@mock_dax
+def test_create_cluster_invalid_arn_no_region():
+ client = boto3.client("dax", region_name="eu-west-1")
+ with pytest.raises(ClientError) as exc:
+ client.create_cluster(
+ ClusterName="1invalid",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn="arn:sth:aws:else",
+ )
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("InvalidParameterValueException")
+ err["Message"].should.equal(
+ "Fourth colon (region/namespace delimiter) not found: arn:sth:aws:else"
+ )
+
+
+@mock_dax
+def test_create_cluster_invalid_arn_no_namespace():
+ client = boto3.client("dax", region_name="eu-west-1")
+ with pytest.raises(ClientError) as exc:
+ client.create_cluster(
+ ClusterName="1invalid",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn="arn:sth:aws:else:eu-west-1",
+ )
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("InvalidParameterValueException")
+ err["Message"].should.equal(
+ "Fifth colon (namespace/relative-id delimiter) not found: arn:sth:aws:else:eu-west-1"
+ )
+
+
+@mock_dax
+@pytest.mark.parametrize(
+ "name", ["1invalid", "iИvalid", "in_valid", "invalid-", "in--valid"]
+)
+def test_create_cluster_invalid_name(name):
+ client = boto3.client("dax", region_name="eu-west-1")
+
+ with pytest.raises(ClientError) as exc:
+ client.create_cluster(
+ ClusterName=name,
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn="arn:aws:iam::486285699788:role/apigatewayrole",
+ )
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("InvalidParameterValueException")
+ err["Message"].should.equal(
+ "Cluster ID specified is not a valid identifier. Identifiers must begin with a letter; must contain only ASCII letters, digits, and hyphens; and must not end with a hyphen or contain two consecutive hyphens."
+ )
+
+
+@mock_dax
+@pytest.mark.parametrize(
+ "name", ["1invalid", "iИvalid", "in_valid", "invalid-", "in--valid"]
+)
+def test_describe_clusters_invalid_name(name):
+ client = boto3.client("dax", region_name="eu-west-1")
+
+ with pytest.raises(ClientError) as exc:
+ client.describe_clusters(ClusterNames=[name])
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("InvalidParameterValueException")
+ err["Message"].should.equal(
+ "Cluster ID specified is not a valid identifier. Identifiers must begin with a letter; must contain only ASCII letters, digits, and hyphens; and must not end with a hyphen or contain two consecutive hyphens."
+ )
+
+
+@mock_dax
+def test_delete_cluster_unknown():
+ client = boto3.client("dax", region_name="eu-west-1")
+
+ with pytest.raises(ClientError) as exc:
+ client.delete_cluster(ClusterName="unknown")
+
+ err = exc.value.response["Error"]
+ err["Code"].should.equals("ClusterNotFoundFault")
+ err["Message"].should.equal("Cluster not found.")
+
+
+@mock_dax
+def test_delete_cluster():
+ client = boto3.client("dax", region_name="eu-west-1")
+
+ iam_role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX"
+ client.create_cluster(
+ ClusterName="daxcluster",
+ NodeType="dax.t3.small",
+ ReplicationFactor=2,
+ IamRoleArn=iam_role_arn,
+ )
+
+ client.delete_cluster(ClusterName="daxcluster")
+
+ for _ in range(0, 3):
+ # Cluster takes a while to delete...
+ cluster = client.describe_clusters(ClusterNames=["daxcluster"])["Clusters"][0]
+ cluster["Status"].should.equal("deleting")
+ cluster["TotalNodes"].should.equal(2)
+ cluster["ActiveNodes"].should.equal(0)
+ cluster.shouldnt.have.key("Nodes")
+
+ with pytest.raises(ClientError) as exc:
+ client.describe_clusters(ClusterNames=["daxcluster"])
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("ClusterNotFoundFault")
+
+
+@mock_dax
+def test_describe_cluster_unknown():
+ client = boto3.client("dax", region_name="eu-west-1")
+
+ with pytest.raises(ClientError) as exc:
+ client.describe_clusters(ClusterNames=["unknown"])
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("ClusterNotFoundFault")
+ err["Message"].should.equal("Cluster unknown not found.")
+
+
+@mock_dax
+def test_describe_clusters_returns_all():
+ client = boto3.client("dax", region_name="us-east-1")
+ iam_role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX"
+ for i in range(0, 50):
+ client.create_cluster(
+ ClusterName=f"daxcluster{i}",
+ NodeType="dax.t3.small",
+ ReplicationFactor=1,
+ IamRoleArn=iam_role_arn,
+ )
+
+ clusters = client.describe_clusters()["Clusters"]
+ clusters.should.have.length_of(50)
+
+
+@mock_dax
+def test_describe_clusters_paginates():
+ client = boto3.client("dax", region_name="us-east-1")
+ iam_role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX"
+ for i in range(0, 50):
+ client.create_cluster(
+ ClusterName=f"daxcluster{i}",
+ NodeType="dax.t3.small",
+ ReplicationFactor=1,
+ IamRoleArn=iam_role_arn,
+ )
+
+ resp = client.describe_clusters(MaxResults=10)
+ resp["Clusters"].should.have.length_of(10)
+ resp.should.have.key("NextToken")
+
+ resp = client.describe_clusters(MaxResults=10, NextToken=resp["NextToken"])
+ resp["Clusters"].should.have.length_of(10)
+ resp.should.have.key("NextToken")
+
+ resp = client.describe_clusters(NextToken=resp["NextToken"])
+ resp["Clusters"].should.have.length_of(30)
+ resp.shouldnt.have.key("NextToken")
+
+
+@mock_dax
+def test_describe_clusters_returns_nodes_after_some_time():
+ client = boto3.client("dax", region_name="us-east-2")
+ client.create_cluster(
+ ClusterName="daxcluster",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn=f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX",
+ )["Cluster"]
+
+ for _ in range(0, 3):
+ # Cluster takes a while to load...
+ cluster = client.describe_clusters(ClusterNames=["daxcluster"])["Clusters"][0]
+ cluster["Status"].should.equal("creating")
+ cluster.shouldnt.have.key("Nodes")
+
+ # Finished loading by now
+ cluster = client.describe_clusters(ClusterNames=["daxcluster"])["Clusters"][0]
+
+ cluster["ClusterName"].should.equal("daxcluster")
+ cluster["ClusterArn"].should.equal(
+ f"arn:aws:dax:us-east-2:{ACCOUNT_ID}:cache/daxcluster"
+ )
+ cluster["TotalNodes"].should.equal(3)
+ cluster["ActiveNodes"].should.equal(0)
+ cluster["NodeType"].should.equal("dax.t3.small")
+ cluster["Status"].should.equal("available")
+
+ # Address Info is only available when the cluster is ready
+ cluster.should.have.key("ClusterDiscoveryEndpoint")
+ endpoint = cluster["ClusterDiscoveryEndpoint"]
+ endpoint.should.have.key("Address")
+ address = endpoint["Address"]
+ cluster_id = address.split(".")[1]
+ address.should.equal(
+ f"daxcluster.{cluster_id}.dax-clusters.us-east-2.amazonaws.com"
+ )
+ endpoint.should.have.key("Port").equal(8111)
+ endpoint.should.have.key("URL").equal(f"dax://{address}")
+
+ # Nodes are only shown when the cluster is ready
+ cluster.should.have.key("Nodes").length_of(3)
+ for idx, a in enumerate(["a", "b", "c"]):
+ node = cluster["Nodes"][idx]
+ node.should.have.key("NodeId").equals(f"daxcluster-{a}")
+ node.should.have.key("Endpoint")
+ node_address = (
+ f"daxcluster-{a}.{cluster_id}.nodes.dax-clusters.us-east-2.amazonaws.com"
+ )
+ node["Endpoint"].should.have.key("Address").equals(node_address)
+ node["Endpoint"].should.have.key("Port").equals(8111)
+ node.should.contain("AvailabilityZone")
+ node.should.have.key("NodeStatus").equals("available")
+ node.should.have.key("ParameterGroupStatus").equals("in-sync")
+
+ cluster["PreferredMaintenanceWindow"].should.equal("thu:23:30-fri:00:30")
+ cluster["SubnetGroup"].should.equal("default")
+ cluster["SecurityGroups"].should.have.length_of(1)
+ cluster.should.have.key("ParameterGroup")
+ cluster["ParameterGroup"].should.have.key("ParameterGroupName").equals(
+ "default.dax1.0"
+ )
+ cluster["SSEDescription"].should.equal({"Status": "DISABLED"})
+ cluster.should.have.key("ClusterEndpointEncryptionType").equals("NONE")
+
+
+@mock_dax
+def test_list_tags_unknown():
+ client = boto3.client("dax", region_name="ap-southeast-1")
+
+ with pytest.raises(ClientError) as exc:
+ client.list_tags(ResourceName="unknown")
+
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("ClusterNotFoundFault")
+
+
+@mock_dax
+def test_list_tags():
+ client = boto3.client("dax", region_name="ap-southeast-1")
+
+ cluster = client.create_cluster(
+ ClusterName="daxcluster",
+ NodeType="dax.t3.small",
+ ReplicationFactor=3,
+ IamRoleArn=f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX",
+ Tags=[
+ {"Key": "tag1", "Value": "value1"},
+ {"Key": "tag2", "Value": "value2"},
+ {"Key": "tag3", "Value": "value3"},
+ ],
+ )["Cluster"]
+
+ for name in ["daxcluster", cluster["ClusterArn"]]:
+ resp = client.list_tags(ResourceName=name)
+
+ resp.shouldnt.have.key("NextToken")
+ resp.should.have.key("Tags").equals(
+ [
+ {"Key": "tag1", "Value": "value1"},
+ {"Key": "tag2", "Value": "value2"},
+ {"Key": "tag3", "Value": "value3"},
+ ]
+ )
+
+
+@mock_dax
+def test_increase_replication_factor_unknown():
+ client = boto3.client("dax", region_name="ap-southeast-1")
+
+ with pytest.raises(ClientError) as exc:
+ client.increase_replication_factor(
+ ClusterName="unknown", NewReplicationFactor=2
+ )
+
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("ClusterNotFoundFault")
+
+
+@mock_dax
+def test_increase_replication_factor():
+ client = boto3.client("dax", region_name="ap-southeast-1")
+
+ name = "daxcluster"
+ cluster = client.create_cluster(
+ ClusterName=name,
+ NodeType="dax.t3.small",
+ ReplicationFactor=2,
+ IamRoleArn=f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX",
+ Tags=[
+ {"Key": "tag1", "Value": "value1"},
+ {"Key": "tag2", "Value": "value2"},
+ {"Key": "tag3", "Value": "value3"},
+ ],
+ )["Cluster"]
+ cluster["TotalNodes"].should.equal(2)
+
+ new_cluster = client.increase_replication_factor(
+ ClusterName=name, NewReplicationFactor=5
+ )["Cluster"]
+ new_cluster["TotalNodes"].should.equal(5)
+
+ new_cluster = client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+ new_cluster["TotalNodes"].should.equal(5)
+
+ # Progress cluster until it's available
+ client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+ client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+
+ cluster = client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+ node_ids = set([n["NodeId"] for n in cluster["Nodes"]])
+ node_ids.should.equal(
+ {f"{name}-a", f"{name}-b", f"{name}-c", f"{name}-d", f"{name}-e"}
+ )
+
+
+@mock_dax
+def test_decrease_replication_factor_unknown():
+ client = boto3.client("dax", region_name="ap-southeast-1")
+
+ with pytest.raises(ClientError) as exc:
+ client.decrease_replication_factor(
+ ClusterName="unknown", NewReplicationFactor=2
+ )
+
+ err = exc.value.response["Error"]
+ err["Code"].should.equal("ClusterNotFoundFault")
+
+
+@mock_dax
+def test_decrease_replication_factor():
+ client = boto3.client("dax", region_name="eu-west-1")
+
+ name = "daxcluster"
+ client.create_cluster(
+ ClusterName=name,
+ NodeType="dax.t3.small",
+ ReplicationFactor=5,
+ IamRoleArn=f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX",
+ Tags=[
+ {"Key": "tag1", "Value": "value1"},
+ {"Key": "tag2", "Value": "value2"},
+ {"Key": "tag3", "Value": "value3"},
+ ],
+ )
+
+ new_cluster = client.decrease_replication_factor(
+ ClusterName=name, NewReplicationFactor=3
+ )["Cluster"]
+ new_cluster["TotalNodes"].should.equal(3)
+
+ new_cluster = client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+ new_cluster["TotalNodes"].should.equal(3)
+
+ # Progress cluster until it's available
+ client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+ client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+
+ cluster = client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+ node_ids = set([n["NodeId"] for n in cluster["Nodes"]])
+ node_ids.should.equal({f"{name}-a", f"{name}-b", f"{name}-c"})
+
+
+@mock_dax
+def test_decrease_replication_factor_specific_nodeids():
+ client = boto3.client("dax", region_name="ap-southeast-1")
+
+ name = "daxcluster"
+ client.create_cluster(
+ ClusterName=name,
+ NodeType="dax.t3.small",
+ ReplicationFactor=5,
+ IamRoleArn=f"arn:aws:iam::{ACCOUNT_ID}:role/aws-service-role/dax.amazonaws.com/AWSServiceRoleForDAX",
+ Tags=[
+ {"Key": "tag1", "Value": "value1"},
+ {"Key": "tag2", "Value": "value2"},
+ {"Key": "tag3", "Value": "value3"},
+ ],
+ )
+
+ new_cluster = client.decrease_replication_factor(
+ ClusterName=name,
+ NewReplicationFactor=3,
+ NodeIdsToRemove=["daxcluster-b", "daxcluster-c"],
+ )["Cluster"]
+ new_cluster["TotalNodes"].should.equal(3)
+
+ new_cluster = client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+ new_cluster["TotalNodes"].should.equal(3)
+
+ # Progress cluster until it's available
+ client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+ client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+
+ cluster = client.describe_clusters(ClusterNames=[name])["Clusters"][0]
+ node_ids = set([n["NodeId"] for n in cluster["Nodes"]])
+ node_ids.should.equal({f"{name}-a", f"{name}-d", f"{name}-e"})
diff --git a/tests/test_dax/test_server.py b/tests/test_dax/test_server.py
new file mode 100644
index 000000000..f0fc9aefa
--- /dev/null
+++ b/tests/test_dax/test_server.py
@@ -0,0 +1,15 @@
+import json
+import sure # noqa # pylint: disable=unused-import
+
+import moto.server as server
+
+
+def test_dax_list():
+ backend = server.create_backend_app("dax")
+ test_client = backend.test_client()
+
+ resp = test_client.post(
+ "/", headers={"X-Amz-Target": "AmazonDAXV3.DescribeClusters"}, data="{}"
+ )
+ resp.status_code.should.equal(200)
+ json.loads(resp.data).should.equal({"Clusters": [], "NextToken": None})