diff --git a/moto/emr/models.py b/moto/emr/models.py
index b37ebf034..74050fed7 100644
--- a/moto/emr/models.py
+++ b/moto/emr/models.py
@@ -2,6 +2,8 @@ from __future__ import unicode_literals
 from datetime import datetime
 from datetime import timedelta
 
+import warnings
+
 import pytz
 from boto3 import Session
 from dateutil.parser import parse as dtparse
@@ -12,6 +14,7 @@ from .utils import (
     random_cluster_id,
     random_step_id,
     CamelToUnderscoresWalker,
+    EmrSecurityGroupManager,
 )
 
 
@@ -363,6 +366,16 @@ class ElasticMapReduceBackend(BaseBackend):
         self.__dict__ = {}
         self.__init__(region_name)
 
+    @property
+    def ec2_backend(self):
+        """
+        :return: EC2 Backend
+        :rtype: moto.ec2.models.EC2Backend
+        """
+        from moto.ec2 import ec2_backends
+
+        return ec2_backends[self.region_name]
+
     def add_applications(self, cluster_id, applications):
         cluster = self.get_cluster(cluster_id)
         cluster.add_applications(applications)
@@ -501,7 +514,62 @@ class ElasticMapReduceBackend(BaseBackend):
         cluster = self.get_cluster(cluster_id)
         cluster.remove_tags(tag_keys)
 
+    def _manage_security_groups(
+        self,
+        ec2_subnet_id,
+        emr_managed_master_security_group,
+        emr_managed_slave_security_group,
+        service_access_security_group,
+        **_  # any other instance attributes are accepted and ignored
+    ):
+        """
+        Resolve the EMR security group ids to use for a new cluster.
+
+        When no subnet is given, or the subnet cannot be found (a warning is
+        emitted), the supplied values are returned unchanged. Otherwise the
+        groups are looked up, or created, in the subnet's VPC.
+
+        :return: (master, slave, service access) security group ids
+        :rtype: tuple
+        """
+        default_return_value = (
+            emr_managed_master_security_group,
+            emr_managed_slave_security_group,
+            service_access_security_group,
+        )
+        if not ec2_subnet_id:
+            # TODO: Set up Security Groups in Default VPC.
+            return default_return_value
+
+        from moto.ec2.exceptions import InvalidSubnetIdError
+
+        try:
+            subnet = self.ec2_backend.get_subnet(ec2_subnet_id)
+        except InvalidSubnetIdError:
+            warnings.warn(
+                "Could not find Subnet with id: {0}\n"
+                "In the near future, this will raise an error.\n"
+                "Use ec2.describe_subnets() to find a suitable id "
+                "for your test.".format(ec2_subnet_id),
+                PendingDeprecationWarning,
+            )
+            return default_return_value
+
+        manager = EmrSecurityGroupManager(self.ec2_backend, subnet.vpc_id)
+        master, slave, service = manager.manage_security_groups(
+            emr_managed_master_security_group,
+            emr_managed_slave_security_group,
+            service_access_security_group,
+        )
+        return master.id, slave.id, service.id
+
     def run_job_flow(self, **kwargs):
+        # Record the security group ids actually in effect for this cluster.
+        (
+            kwargs["instance_attrs"]["emr_managed_master_security_group"],
+            kwargs["instance_attrs"]["emr_managed_slave_security_group"],
+            kwargs["instance_attrs"]["service_access_security_group"],
+        ) = self._manage_security_groups(**kwargs["instance_attrs"])
         return FakeCluster(self, **kwargs)
 
     def set_visible_to_all_users(self, job_flow_ids, visible_to_all_users):
diff --git a/moto/emr/utils.py b/moto/emr/utils.py
index 506201c1c..99cc2ad12 100644
--- a/moto/emr/utils.py
+++ b/moto/emr/utils.py
@@ -1,8 +1,13 @@
 from __future__ import unicode_literals
+import copy
+import datetime
 import random
 import re
 import string
 
-from moto.core.utils import camelcase_to_underscores
+from moto.core.utils import (
+    camelcase_to_underscores,
+    iso_8601_datetime_with_milliseconds,
+)
 
 import six
 
@@ -218,3 +223,263 @@ class ReleaseLabel(object):
         if not isinstance(other, self.__class__):
             return NotImplemented
         return tuple(self) >= tuple(other)
+
+
+class EmrManagedSecurityGroup(object):
+    """Defaults (name, description) for an EMR-managed security group."""
+
+    class Kind:
+        MASTER = "Master"
+        SLAVE = "Slave"
+        SERVICE = "Service"
+
+    kind = None
+
+    group_name = ""
+    short_name = ""
+    desc_fmt = "{short_name} for Elastic MapReduce created on {created}"
+
+    @classmethod
+    def description(cls):
+        """Return the group description, stamped with the current UTC time."""
+        # NOTE(review): the helper receives a naive datetime and is believed to
+        # append a "Z" (UTC) suffix, so hand it UTC, not local time - confirm.
+        created = iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow())
+        return cls.desc_fmt.format(short_name=cls.short_name, created=created)
+
+
+class EmrManagedMasterSecurityGroup(EmrManagedSecurityGroup):
+    kind = EmrManagedSecurityGroup.Kind.MASTER
+    group_name = "ElasticMapReduce-Master-Private"
+    short_name = "Master"
+
+
+class EmrManagedSlaveSecurityGroup(EmrManagedSecurityGroup):
+    kind = EmrManagedSecurityGroup.Kind.SLAVE
+    group_name = "ElasticMapReduce-Slave-Private"
+    short_name = "Slave"
+
+
+class EmrManagedServiceAccessSecurityGroup(EmrManagedSecurityGroup):
+    kind = EmrManagedSecurityGroup.Kind.SERVICE
+    group_name = "ElasticMapReduce-ServiceAccess"
+    short_name = "Service access"
+
+
+class EmrSecurityGroupManager(object):
+    """
+    Finds or creates the EMR-managed security groups of a VPC and authorizes
+    the managed ingress/egress rules on them.
+
+    In the rule templates below, "group_name_or_id" and "source_group_ids"
+    hold EmrManagedSecurityGroup.Kind values; _render_rules swaps them for
+    real security group ids.
+    """
+
+    MANAGED_RULES_EGRESS = [
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
+            "from_port": None,
+            "ip_protocol": "-1",
+            "ip_ranges": [{"CidrIp": "0.0.0.0/0"}],
+            "to_port": None,
+            "source_group_ids": [],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
+            "from_port": None,
+            "ip_protocol": "-1",
+            "ip_ranges": [{"CidrIp": "0.0.0.0/0"}],
+            "to_port": None,
+            "source_group_ids": [],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.SERVICE,
+            "from_port": 8443,
+            "ip_protocol": "tcp",
+            "ip_ranges": [],
+            "to_port": 8443,
+            "source_group_ids": [
+                EmrManagedSecurityGroup.Kind.MASTER,
+                EmrManagedSecurityGroup.Kind.SLAVE,
+            ],
+        },
+    ]
+
+    MANAGED_RULES_INGRESS = [
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
+            "from_port": 0,
+            "ip_protocol": "tcp",
+            "ip_ranges": [],
+            "to_port": 65535,
+            "source_group_ids": [
+                EmrManagedSecurityGroup.Kind.MASTER,
+                EmrManagedSecurityGroup.Kind.SLAVE,
+            ],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
+            "from_port": 8443,
+            "ip_protocol": "tcp",
+            "ip_ranges": [],
+            "to_port": 8443,
+            "source_group_ids": [EmrManagedSecurityGroup.Kind.SERVICE],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
+            "from_port": 0,
+            "ip_protocol": "udp",
+            "ip_ranges": [],
+            "to_port": 65535,
+            "source_group_ids": [
+                EmrManagedSecurityGroup.Kind.MASTER,
+                EmrManagedSecurityGroup.Kind.SLAVE,
+            ],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
+            "from_port": -1,
+            "ip_protocol": "icmp",
+            "ip_ranges": [],
+            "to_port": -1,
+            "source_group_ids": [
+                EmrManagedSecurityGroup.Kind.MASTER,
+                EmrManagedSecurityGroup.Kind.SLAVE,
+            ],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
+            "from_port": 0,
+            "ip_protocol": "tcp",
+            "ip_ranges": [],
+            "to_port": 65535,
+            "source_group_ids": [
+                EmrManagedSecurityGroup.Kind.SLAVE,
+                EmrManagedSecurityGroup.Kind.MASTER,
+            ],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
+            "from_port": 8443,
+            "ip_protocol": "tcp",
+            "ip_ranges": [],
+            "to_port": 8443,
+            "source_group_ids": [EmrManagedSecurityGroup.Kind.SERVICE],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
+            "from_port": 0,
+            "ip_protocol": "udp",
+            "ip_ranges": [],
+            "to_port": 65535,
+            "source_group_ids": [
+                EmrManagedSecurityGroup.Kind.MASTER,
+                EmrManagedSecurityGroup.Kind.SLAVE,
+            ],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
+            "from_port": -1,
+            "ip_protocol": "icmp",
+            "ip_ranges": [],
+            "to_port": -1,
+            "source_group_ids": [
+                EmrManagedSecurityGroup.Kind.MASTER,
+                EmrManagedSecurityGroup.Kind.SLAVE,
+            ],
+        },
+        {
+            "group_name_or_id": EmrManagedSecurityGroup.Kind.SERVICE,
+            "from_port": 9443,
+            "ip_protocol": "tcp",
+            "ip_ranges": [],
+            "to_port": 9443,
+            "source_group_ids": [EmrManagedSecurityGroup.Kind.MASTER],
+        },
+    ]
+
+    def __init__(self, ec2_backend, vpc_id):
+        self.ec2 = ec2_backend
+        self.vpc_id = vpc_id
+
+    def manage_security_groups(
+        self, master_security_group, slave_security_group, service_access_security_group
+    ):
+        """
+        Ensure the three groups exist and carry the managed rules.
+
+        A falsy argument selects the default EMR-managed group, which is
+        created in the VPC if it does not exist yet.
+
+        :return: (master, slave, service access) security groups
+        :raises ValueError: if a supplied security group does not exist
+        """
+        group_metadata = [
+            (
+                master_security_group,
+                EmrManagedSecurityGroup.Kind.MASTER,
+                EmrManagedMasterSecurityGroup,
+            ),
+            (
+                slave_security_group,
+                EmrManagedSecurityGroup.Kind.SLAVE,
+                EmrManagedSlaveSecurityGroup,
+            ),
+            (
+                service_access_security_group,
+                EmrManagedSecurityGroup.Kind.SERVICE,
+                EmrManagedServiceAccessSecurityGroup,
+            ),
+        ]
+        managed_groups = {}
+        for name, kind, defaults in group_metadata:
+            managed_groups[kind] = self._get_or_create_sg(name, defaults)
+        self._add_rules_to(managed_groups)
+        return (
+            managed_groups[EmrManagedSecurityGroup.Kind.MASTER],
+            managed_groups[EmrManagedSecurityGroup.Kind.SLAVE],
+            managed_groups[EmrManagedSecurityGroup.Kind.SERVICE],
+        )
+
+    def _get_or_create_sg(self, sg_id, defaults):
+        """Return the group; only a missing default group gets created."""
+        find_sg = self.ec2.get_security_group_by_name_or_id
+        create_sg = self.ec2.create_security_group
+        group_id_or_name = sg_id or defaults.group_name
+        group = find_sg(group_id_or_name, self.vpc_id)
+        if group is None:
+            if group_id_or_name != defaults.group_name:
+                raise ValueError(
+                    "The security group '{}' does not exist".format(group_id_or_name)
+                )
+            group = create_sg(defaults.group_name, defaults.description(), self.vpc_id)
+        return group
+
+    def _add_rules_to(self, managed_groups):
+        """Authorize every managed egress/ingress rule, ignoring duplicates."""
+        from moto.ec2.exceptions import InvalidPermissionDuplicateError
+
+        rules_metadata = [
+            (self.MANAGED_RULES_EGRESS, self.ec2.authorize_security_group_egress),
+            (self.MANAGED_RULES_INGRESS, self.ec2.authorize_security_group_ingress),
+        ]
+        for rules, add_rule in rules_metadata:
+            rendered_rules = self._render_rules(rules, managed_groups)
+            for rule in rendered_rules:
+                try:
+                    add_rule(vpc_id=self.vpc_id, **rule)
+                except InvalidPermissionDuplicateError:
+                    # If the rule already exists, we can just move on.
+                    pass
+
+    @staticmethod
+    def _render_rules(rules, managed_groups):
+        """Return a copy of rules with Kind placeholders replaced by ids."""
+        rendered_rules = copy.deepcopy(rules)
+        for rule in rendered_rules:
+            rule["group_name_or_id"] = managed_groups[rule["group_name_or_id"]].id
+            rule["source_group_ids"] = [
+                managed_groups[group].id for group in rule["source_group_ids"]
+            ]
+        return rendered_rules
diff --git a/tests/test_emr/test_emr_integration.py b/tests/test_emr/test_emr_integration.py
new file mode 100644
index 000000000..16b499476
--- /dev/null
+++ b/tests/test_emr/test_emr_integration.py
@@ -0,0 +1,207 @@
+from __future__ import unicode_literals
+
+import boto3
+import pytest
+import sure  # noqa
+
+from moto import settings
+from moto.ec2 import mock_ec2, ec2_backend
+from moto.emr import mock_emr
+from moto.emr.utils import EmrSecurityGroupManager
+
+
+@mock_emr
+@mock_ec2
+def test_default_emr_security_groups_get_created_on_first_job_flow():
+    ec2 = boto3.resource("ec2", region_name="us-east-1")
+    ec2_client = boto3.client("ec2", region_name="us-east-1")
+
+    vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
+    subnet = ec2.create_subnet(
+        VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-east-1a"
+    )
+
+    def _get_default_security_groups():
+        group_resp = ec2_client.describe_security_groups(
+            Filters=[
+                {"Name": "vpc-id", "Values": [vpc.id]},
+                {
+                    "Name": "group-name",
+                    "Values": [
+                        "ElasticMapReduce-Master-Private",
+                        "ElasticMapReduce-Slave-Private",
+                        "ElasticMapReduce-ServiceAccess",
+                    ],
+                },
+            ]
+        )
+        return group_resp.get("SecurityGroups", [])
+
+    assert len(_get_default_security_groups()) == 0
+
+    client = boto3.client("emr", region_name="us-east-1")
+    run_job_flow_params = dict(
+        ReleaseLabel="emr-5.29.0",
+        Instances={
+            "KeepJobFlowAliveWhenNoSteps": True,
+            "Ec2SubnetId": subnet.id,
+            "InstanceGroups": [
+                {
+                    "Name": "Master",
+                    "Market": "ON_DEMAND",
+                    "InstanceRole": "MASTER",
+                    "InstanceType": "m5.xlarge",
+                    "InstanceCount": 3,
+                },
+                {
+                    "Name": "Core",
+                    "Market": "ON_DEMAND",
+                    "InstanceRole": "CORE",
+                    "InstanceType": "m5.xlarge",
+                    "InstanceCount": 2,
+                },
+            ],
+        },
+        JobFlowRole="EMR_EC2_DefaultRole",
+        Name="test-emr-cluster-security-groups",
+        ServiceRole="EMR_DefaultRole",
+        VisibleToAllUsers=True,
+    )
+    cluster_id = client.run_job_flow(**run_job_flow_params)["JobFlowId"]
+
+    # Default security groups should have been created.
+    default_security_groups = _get_default_security_groups()
+    default_security_group_ids = [sg["GroupId"] for sg in default_security_groups]
+    assert len(default_security_group_ids) == 3
+
+    resp = client.describe_cluster(ClusterId=cluster_id)
+    ec2_attrs = resp["Cluster"]["Ec2InstanceAttributes"]
+    assert ec2_attrs["Ec2SubnetId"] == subnet.id
+    cluster_security_group_ids = [
+        ec2_attrs["EmrManagedMasterSecurityGroup"],
+        ec2_attrs["EmrManagedSlaveSecurityGroup"],
+        ec2_attrs["ServiceAccessSecurityGroup"],
+    ]
+    assert set(cluster_security_group_ids) == set(default_security_group_ids)
+
+
+@pytest.mark.skipif(
+    settings.TEST_SERVER_MODE, reason="Can't modify backend directly in server mode."
+)
+class TestEmrSecurityGroupManager(object):
+
+    mocks = []
+
+    def setup_method(self):
+        self.mocks = [mock_ec2()]
+        for mock in self.mocks:
+            mock.start()
+        ec2_client = boto3.client("ec2", region_name="us-east-1")
+        ec2 = boto3.resource("ec2", region_name="us-east-1")
+        vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
+        self.vpc_id = vpc.id
+        self.ec2 = ec2
+        self.ec2_client = ec2_client
+
+    def teardown_method(self):
+        for mock in self.mocks:
+            mock.stop()
+
+    def _create_default_client_supplied_security_groups(self):
+        master = self.ec2.create_security_group(
+            GroupName="master", Description="master", VpcId=self.vpc_id
+        )
+        slave = self.ec2.create_security_group(
+            GroupName="slave", Description="slave", VpcId=self.vpc_id
+        )
+        service = self.ec2.create_security_group(
+            GroupName="service", Description="service", VpcId=self.vpc_id
+        )
+        return master, slave, service
+
+    def _describe_security_groups(self, group_names):
+        resp = self.ec2_client.describe_security_groups(
+            Filters=[
+                {"Name": "vpc-id", "Values": [self.vpc_id]},
+                {"Name": "group-name", "Values": group_names},
+            ]
+        )
+        return resp.get("SecurityGroups", [])
+
+    def _default_emr_security_groups(self):
+        group_names = [
+            "ElasticMapReduce-Master-Private",
+            "ElasticMapReduce-Slave-Private",
+            "ElasticMapReduce-ServiceAccess",
+        ]
+        return self._describe_security_groups(group_names)
+
+    def test_emr_security_groups_get_created_if_non_existent(self):
+        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
+        assert len(self._default_emr_security_groups()) == 0
+        manager.manage_security_groups(None, None, None)
+        assert len(self._default_emr_security_groups()) == 3
+
+    def test_emr_security_groups_do_not_get_created_if_already_exist(self):
+        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
+        assert len(self._default_emr_security_groups()) == 0
+        manager.manage_security_groups(None, None, None)
+        emr_security_groups = self._default_emr_security_groups()
+        assert len(emr_security_groups) == 3
+        # Run again. Group count should still be 3.
+        emr_sg_ids_expected = [sg["GroupId"] for sg in emr_security_groups]
+        manager.manage_security_groups(None, None, None)
+        emr_security_groups = self._default_emr_security_groups()
+        assert len(emr_security_groups) == 3
+        emr_sg_ids_actual = [sg["GroupId"] for sg in emr_security_groups]
+        assert emr_sg_ids_actual == emr_sg_ids_expected
+
+    def test_emr_security_groups_do_not_get_created_if_client_supplied(self):
+        (
+            client_master,
+            client_slave,
+            client_service,
+        ) = self._create_default_client_supplied_security_groups()
+        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
+        manager.manage_security_groups(
+            client_master.id, client_slave.id, client_service.id
+        )
+        client_group_names = [
+            client_master.group_name,
+            client_slave.group_name,
+            client_service.group_name,
+        ]
+        assert len(self._describe_security_groups(client_group_names)) == 3
+        assert len(self._default_emr_security_groups()) == 0
+
+    def test_client_supplied_invalid_security_group_identifier_raises_error(self):
+        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
+        args_bad = [
+            ("sg-invalid", None, None),
+            (None, "sg-invalid", None),
+            (None, None, "sg-invalid"),
+        ]
+        for args in args_bad:
+            with pytest.raises(ValueError) as exc:
+                manager.manage_security_groups(*args)
+            assert str(exc.value) == "The security group 'sg-invalid' does not exist"
+
+    def test_client_supplied_security_groups_have_rules_added(self):
+        (
+            client_master,
+            client_slave,
+            client_service,
+        ) = self._create_default_client_supplied_security_groups()
+        manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
+        manager.manage_security_groups(
+            client_master.id, client_slave.id, client_service.id
+        )
+        client_group_names = [
+            client_master.group_name,
+            client_slave.group_name,
+            client_service.group_name,
+        ]
+        security_groups = self._describe_security_groups(client_group_names)
+        for security_group in security_groups:
+            assert len(security_group["IpPermissions"]) > 0
+            assert len(security_group["IpPermissionsEgress"]) > 0