diff --git a/moto/redshift/responses.py b/moto/redshift/responses.py index 0dbf35cb2..52ca908e8 100644 --- a/moto/redshift/responses.py +++ b/moto/redshift/responses.py @@ -66,6 +66,24 @@ class RedshiftResponse(BaseResponse): count += 1 return unpacked_list + def _get_cluster_security_groups(self): + cluster_security_groups = self._get_multi_param('ClusterSecurityGroups.member') + if not cluster_security_groups: + cluster_security_groups = self._get_multi_param('ClusterSecurityGroups.ClusterSecurityGroupName') + return cluster_security_groups + + def _get_vpc_security_group_ids(self): + vpc_security_group_ids = self._get_multi_param('VpcSecurityGroupIds.member') + if not vpc_security_group_ids: + vpc_security_group_ids = self._get_multi_param('VpcSecurityGroupIds.VpcSecurityGroupId') + return vpc_security_group_ids + + def _get_subnet_ids(self): + subnet_ids = self._get_multi_param('SubnetIds.member') + if not subnet_ids: + subnet_ids = self._get_multi_param('SubnetIds.SubnetIdentifier') + return subnet_ids + def create_cluster(self): cluster_kwargs = { "cluster_identifier": self._get_param('ClusterIdentifier'), @@ -74,8 +92,8 @@ class RedshiftResponse(BaseResponse): "master_user_password": self._get_param('MasterUserPassword'), "db_name": self._get_param('DBName'), "cluster_type": self._get_param('ClusterType'), - "cluster_security_groups": self._get_multi_param('ClusterSecurityGroups.member'), - "vpc_security_group_ids": self._get_multi_param('VpcSecurityGroupIds.member'), + "cluster_security_groups": self._get_cluster_security_groups(), + "vpc_security_group_ids": self._get_vpc_security_group_ids(), "cluster_subnet_group_name": self._get_param('ClusterSubnetGroupName'), "availability_zone": self._get_param('AvailabilityZone'), "preferred_maintenance_window": self._get_param('PreferredMaintenanceWindow'), @@ -116,10 +134,8 @@ class RedshiftResponse(BaseResponse): "publicly_accessible": self._get_param("PubliclyAccessible"), "cluster_parameter_group_name": self._get_param( 'ClusterParameterGroupName'), - "cluster_security_groups": self._get_multi_param( - 'ClusterSecurityGroups.member'), - "vpc_security_group_ids": self._get_multi_param( - 'VpcSecurityGroupIds.member'), + "cluster_security_groups": self._get_cluster_security_groups(), + "vpc_security_group_ids": self._get_vpc_security_group_ids(), "preferred_maintenance_window": self._get_param( 'PreferredMaintenanceWindow'), "automated_snapshot_retention_period": self._get_int_param( @@ -161,8 +177,8 @@ class RedshiftResponse(BaseResponse): "node_type": self._get_param('NodeType'), "master_user_password": self._get_param('MasterUserPassword'), "cluster_type": self._get_param('ClusterType'), - "cluster_security_groups": self._get_multi_param('ClusterSecurityGroups.member'), - "vpc_security_group_ids": self._get_multi_param('VpcSecurityGroupIds.member'), + "cluster_security_groups": self._get_cluster_security_groups(), + "vpc_security_group_ids": self._get_vpc_security_group_ids(), "cluster_subnet_group_name": self._get_param('ClusterSubnetGroupName'), "preferred_maintenance_window": self._get_param('PreferredMaintenanceWindow'), "cluster_parameter_group_name": self._get_param('ClusterParameterGroupName'), @@ -173,12 +189,6 @@ class RedshiftResponse(BaseResponse): "publicly_accessible": self._get_param("PubliclyAccessible"), "encrypted": self._get_param("Encrypted"), } - # There's a bug in boto3 where the security group ids are not passed - # according to the AWS documentation - if not request_kwargs['vpc_security_group_ids']: - request_kwargs['vpc_security_group_ids'] = self._get_multi_param( - 'VpcSecurityGroupIds.VpcSecurityGroupId') - cluster_kwargs = {} # We only want parameters that were actually passed in, otherwise # we'll stomp all over our cluster metadata with None values. @@ -217,11 +227,7 @@ class RedshiftResponse(BaseResponse): def create_cluster_subnet_group(self): cluster_subnet_group_name = self._get_param('ClusterSubnetGroupName') description = self._get_param('Description') - subnet_ids = self._get_multi_param('SubnetIds.member') - # There's a bug in boto3 where the subnet ids are not passed - # according to the AWS documentation - if not subnet_ids: - subnet_ids = self._get_multi_param('SubnetIds.SubnetIdentifier') + subnet_ids = self._get_subnet_ids() tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) subnet_group = self.redshift_backend.create_cluster_subnet_group( diff --git a/tests/test_redshift/test_redshift.py b/tests/test_redshift/test_redshift.py index dca475374..cebaa3ec7 100644 --- a/tests/test_redshift/test_redshift.py +++ b/tests/test_redshift/test_redshift.py @@ -216,6 +216,33 @@ def test_create_cluster_with_security_group(): set(group_names).should.equal(set(["security_group1", "security_group2"])) +@mock_redshift +def test_create_cluster_with_security_group_boto3(): + client = boto3.client('redshift', region_name='us-east-1') + client.create_cluster_security_group( + ClusterSecurityGroupName="security_group1", + Description="This is my security group", + ) + client.create_cluster_security_group( + ClusterSecurityGroupName="security_group2", + Description="This is my security group", + ) + + cluster_identifier = 'my_cluster' + client.create_cluster( + ClusterIdentifier=cluster_identifier, + NodeType="dw.hs1.xlarge", + MasterUsername="username", + MasterUserPassword="password", + ClusterSecurityGroups=["security_group1", "security_group2"] + ) + response = client.describe_clusters(ClusterIdentifier=cluster_identifier) + cluster = response['Clusters'][0] + group_names = [group['ClusterSecurityGroupName'] + for group in cluster['ClusterSecurityGroups']] + set(group_names).should.equal({"security_group1", "security_group2"}) + + @mock_redshift_deprecated @mock_ec2_deprecated def test_create_cluster_with_vpc_security_groups(): @@ -242,6 +269,31 @@ def test_create_cluster_with_vpc_security_groups(): list(group_ids).should.equal([security_group.id]) +@mock_redshift +@mock_ec2 +def test_create_cluster_with_vpc_security_groups_boto3(): + ec2 = boto3.resource('ec2', region_name='us-east-1') + vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + client = boto3.client('redshift', region_name='us-east-1') + cluster_id = 'my_cluster' + security_group = ec2.create_security_group( + Description="vpc_security_group", + GroupName="a group", + VpcId=vpc.id) + client.create_cluster( + ClusterIdentifier=cluster_id, + NodeType="dw.hs1.xlarge", + MasterUsername="username", + MasterUserPassword="password", + VpcSecurityGroupIds=[security_group.id], + ) + response = client.describe_clusters(ClusterIdentifier=cluster_id) + cluster = response['Clusters'][0] + group_ids = [group['VpcSecurityGroupId'] + for group in cluster['VpcSecurityGroups']] + list(group_ids).should.equal([security_group.id]) + + @mock_redshift_deprecated def test_create_cluster_with_parameter_group(): conn = boto.connect_redshift()