210 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			210 lines
		
	
	
		
			7.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import boto3
 | 
						|
import pytest
 | 
						|
import sure  # noqa # pylint: disable=unused-import
 | 
						|
 | 
						|
from moto import settings
 | 
						|
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
 | 
						|
from moto.ec2 import mock_ec2, ec2_backends
 | 
						|
from moto.emr import mock_emr
 | 
						|
from moto.emr.utils import EmrSecurityGroupManager
 | 
						|
 | 
						|
 | 
						|
# Module-level handle on the moto EC2 backend for the default account in
# us-east-1. The class-based tests below pass it directly to
# EmrSecurityGroupManager instead of going through the EMR API.
ec2_backend = ec2_backends[ACCOUNT_ID]["us-east-1"]
 | 
						|
 | 
						|
 | 
						|
@mock_emr
@mock_ec2
def test_default_emr_security_groups_get_created_on_first_job_flow():
    """Running a job flow in a subnet creates the three default EMR groups.

    Before the job flow starts, none of the default EMR security groups exist
    in the VPC; afterwards all three do, and the cluster description reports
    exactly those group ids.
    """
    region = "us-east-1"
    ec2 = boto3.resource("ec2", region_name=region)
    ec2_client = boto3.client("ec2", region_name=region)

    vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
    subnet = ec2.create_subnet(
        VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-east-1a"
    )

    emr_group_names = [
        "ElasticMapReduce-Master-Private",
        "ElasticMapReduce-Slave-Private",
        "ElasticMapReduce-ServiceAccess",
    ]

    def _get_default_security_groups():
        # Only look at groups inside the VPC created by this test.
        filters = [
            {"Name": "vpc-id", "Values": [vpc.id]},
            {"Name": "group-name", "Values": emr_group_names},
        ]
        response = ec2_client.describe_security_groups(Filters=filters)
        return response.get("SecurityGroups", [])

    assert len(_get_default_security_groups()) == 0

    client = boto3.client("emr", region_name=region)
    master_group = {
        "Name": "Master",
        "Market": "ON_DEMAND",
        "InstanceRole": "MASTER",
        "InstanceType": "m5.xlarge",
        "InstanceCount": 3,
    }
    core_group = {
        "Name": "Core",
        "Market": "ON_DEMAND",
        "InstanceRole": "CORE",
        "InstanceType": "m5.xlarge",
        "InstanceCount": 2,
    }
    run_job_flow_params = {
        "ReleaseLabel": "emr-5.29.0",
        "Instances": {
            "KeepJobFlowAliveWhenNoSteps": True,
            "Ec2SubnetId": subnet.id,
            "InstanceGroups": [master_group, core_group],
        },
        "JobFlowRole": "EMR_EC2_DefaultRole",
        "Name": "test-emr-cluster-security-groups",
        "ServiceRole": "EMR_DefaultRole",
        "VisibleToAllUsers": True,
    }
    job_flow = client.run_job_flow(**run_job_flow_params)
    cluster_id = job_flow["JobFlowId"]

    # Default security groups should have been created.
    created_group_ids = [
        group["GroupId"] for group in _get_default_security_groups()
    ]
    assert len(created_group_ids) == 3

    cluster = client.describe_cluster(ClusterId=cluster_id)["Cluster"]
    ec2_attrs = cluster["Ec2InstanceAttributes"]
    assert ec2_attrs["Ec2SubnetId"] == subnet.id

    # The cluster must reference exactly the groups that were just created.
    attribute_keys = (
        "EmrManagedMasterSecurityGroup",
        "EmrManagedSlaveSecurityGroup",
        "ServiceAccessSecurityGroup",
    )
    cluster_group_ids = {ec2_attrs[key] for key in attribute_keys}
    assert cluster_group_ids == set(created_group_ids)
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.skipif(
    settings.TEST_SERVER_MODE, reason="Can't modify backend directly in server mode."
)
class TestEmrSecurityGroupManager(object):
    """Tests that drive EmrSecurityGroupManager directly against the EC2 backend.

    Each test runs inside a fresh EC2 mock with a single VPC created in
    ``setup_method``.
    """

    mocks = []

    def setup_method(self):
        """Start the EC2 mock and create the VPC the tests operate in."""
        self.mocks = [mock_ec2()]
        for active_mock in self.mocks:
            active_mock.start()
        self.ec2_client = boto3.client("ec2", region_name="us-east-1")
        self.ec2 = boto3.resource("ec2", region_name="us-east-1")
        self.vpc_id = self.ec2.create_vpc(CidrBlock="10.0.0.0/16").id

    def teardown_method(self):
        """Stop every mock started by ``setup_method``."""
        for active_mock in self.mocks:
            active_mock.stop()

    def _create_default_client_supplied_security_groups(self):
        """Create master, slave and service groups in the test VPC.

        Returns the three security group resources in that order.
        """
        created = []
        for label in ("master", "slave", "service"):
            created.append(
                self.ec2.create_security_group(
                    GroupName=label, Description=label, VpcId=self.vpc_id
                )
            )
        master, slave, service = created
        return master, slave, service

    def _describe_security_groups(self, group_names):
        """Return the groups in the test VPC whose names are in ``group_names``."""
        filters = [
            {"Name": "vpc-id", "Values": [self.vpc_id]},
            {"Name": "group-name", "Values": group_names},
        ]
        response = self.ec2_client.describe_security_groups(Filters=filters)
        return response.get("SecurityGroups", [])

    def _default_emr_security_groups(self):
        """Return whichever default EMR security groups exist in the test VPC."""
        return self._describe_security_groups(
            [
                "ElasticMapReduce-Master-Private",
                "ElasticMapReduce-Slave-Private",
                "ElasticMapReduce-ServiceAccess",
            ]
        )

    def test_emr_security_groups_get_created_if_non_existent(self):
        """With no ids supplied, the manager creates all three default groups."""
        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
        assert len(self._default_emr_security_groups()) == 0
        manager.manage_security_groups(None, None, None)
        assert len(self._default_emr_security_groups()) == 3

    def test_emr_security_groups_do_not_get_created_if_already_exist(self):
        """A second manage call reuses the existing default groups."""
        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
        assert len(self._default_emr_security_groups()) == 0

        manager.manage_security_groups(None, None, None)
        groups_after_first_run = self._default_emr_security_groups()
        assert len(groups_after_first_run) == 3
        emr_sg_ids_expected = [sg["GroupId"] for sg in groups_after_first_run]

        # Run again.  Group count should still be 3.
        manager.manage_security_groups(None, None, None)
        groups_after_second_run = self._default_emr_security_groups()
        assert len(groups_after_second_run) == 3
        emr_sg_ids_actual = [sg["GroupId"] for sg in groups_after_second_run]
        assert emr_sg_ids_actual == emr_sg_ids_expected

    def test_emr_security_groups_do_not_get_created_if_client_supplied(self):
        """Supplying all three group ids means no default groups are created."""
        client_groups = self._create_default_client_supplied_security_groups()
        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
        manager.manage_security_groups(*[group.id for group in client_groups])

        client_group_names = [group.group_name for group in client_groups]
        assert len(self._describe_security_groups(client_group_names)) == 3
        assert len(self._default_emr_security_groups()) == 0

    def test_client_supplied_invalid_security_group_identifier_raises_error(self):
        """An unknown group id in any of the three positions raises ValueError."""
        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
        # One tuple per argument position, with the bad id in that slot only.
        args_bad = [
            tuple("sg-invalid" if slot == bad_slot else None for slot in range(3))
            for bad_slot in range(3)
        ]
        for args in args_bad:
            with pytest.raises(ValueError) as exc:
                manager.manage_security_groups(*args)
            assert str(exc.value) == "The security group 'sg-invalid' does not exist"

    def test_client_supplied_security_groups_have_rules_added(self):
        """Client-supplied groups end up with ingress and egress rules."""
        client_groups = self._create_default_client_supplied_security_groups()
        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
        manager.manage_security_groups(*[group.id for group in client_groups])

        client_group_names = [group.group_name for group in client_groups]
        for described in self._describe_security_groups(client_group_names):
            assert len(described["IpPermissions"]) > 0
            assert len(described["IpPermissionsEgress"]) > 0
 |