diff --git a/moto/redshift/responses.py b/moto/redshift/responses.py index eec440468..a9bf31f0a 100644 --- a/moto/redshift/responses.py +++ b/moto/redshift/responses.py @@ -161,6 +161,8 @@ class RedshiftResponse(BaseResponse): def restore_from_cluster_snapshot(self): enhanced_vpc_routing = self._get_bool_param("EnhancedVpcRouting") + node_type = self._get_param("NodeType") + number_of_nodes = self._get_int_param("NumberOfNodes") restore_kwargs = { "snapshot_identifier": self._get_param("SnapshotIdentifier"), "cluster_identifier": self._get_param("ClusterIdentifier"), @@ -185,6 +187,10 @@ class RedshiftResponse(BaseResponse): } if enhanced_vpc_routing is not None: restore_kwargs["enhanced_vpc_routing"] = enhanced_vpc_routing + if node_type is not None: + restore_kwargs["node_type"] = node_type + if number_of_nodes is not None: + restore_kwargs["number_of_nodes"] = number_of_nodes cluster = self.redshift_backend.restore_from_cluster_snapshot( **restore_kwargs ).to_json() diff --git a/tests/test_redshift/test_redshift.py b/tests/test_redshift/test_redshift.py index b952be122..465b686d5 100644 --- a/tests/test_redshift/test_redshift.py +++ b/tests/test_redshift/test_redshift.py @@ -1421,6 +1421,49 @@ def test_create_cluster_from_snapshot(): new_cluster["EnhancedVpcRouting"].should.equal(True) +@mock_redshift +def test_create_cluster_with_node_type_from_snapshot(): + client = boto3.client("redshift", region_name="us-east-1") + original_cluster_identifier = "original-cluster" + original_snapshot_identifier = "original-snapshot" + new_cluster_identifier = "new-cluster" + + client.create_cluster( + ClusterIdentifier=original_cluster_identifier, + ClusterType="multi-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + EnhancedVpcRouting=True, + NumberOfNodes=2, + ) + + client.create_cluster_snapshot( + SnapshotIdentifier=original_snapshot_identifier, + ClusterIdentifier=original_cluster_identifier, + ) + + client.restore_from_cluster_snapshot.when.called_with( + ClusterIdentifier=original_cluster_identifier, + SnapshotIdentifier=original_snapshot_identifier, + ).should.throw(ClientError, "ClusterAlreadyExists") + + response = client.restore_from_cluster_snapshot( + ClusterIdentifier=new_cluster_identifier, + SnapshotIdentifier=original_snapshot_identifier, + NodeType="ra3.xlplus", + NumberOfNodes=3, + ) + response["Cluster"]["ClusterStatus"].should.equal("creating") + + response = client.describe_clusters(ClusterIdentifier=new_cluster_identifier) + new_cluster = response["Clusters"][0] + new_cluster["NodeType"].should.equal("ra3.xlplus") + new_cluster["NumberOfNodes"].should.equal(3) + new_cluster["MasterUsername"].should.equal("username") + new_cluster["EnhancedVpcRouting"].should.equal(True) + + @mock_redshift def test_create_cluster_from_snapshot_with_waiter(): client = boto3.client("redshift", region_name="us-east-1")