Add support for EMR-Managed Security Groups (#3824)

* Add support for EMR-Managed Security Groups

This covers the base case for EMR Clusters provisioned in a private subnet.

Ref: https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-man-sec-groups.html

* Address PR comments

* Address PR comments
This commit is contained in:
Brian Pandola 2021-04-02 07:34:02 -07:00 committed by GitHub
parent e90858b2e8
commit ac4a26f289
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 505 additions and 1 deletions

View File

@ -2,6 +2,8 @@ from __future__ import unicode_literals
from datetime import datetime from datetime import datetime
from datetime import timedelta from datetime import timedelta
import warnings
import pytz import pytz
from boto3 import Session from boto3 import Session
from dateutil.parser import parse as dtparse from dateutil.parser import parse as dtparse
@ -12,6 +14,7 @@ from .utils import (
random_cluster_id, random_cluster_id,
random_step_id, random_step_id,
CamelToUnderscoresWalker, CamelToUnderscoresWalker,
EmrSecurityGroupManager,
) )
@ -363,6 +366,16 @@ class ElasticMapReduceBackend(BaseBackend):
self.__dict__ = {} self.__dict__ = {}
self.__init__(region_name) self.__init__(region_name)
@property
def ec2_backend(self):
"""
:return: EC2 Backend
:rtype: moto.ec2.models.EC2Backend
"""
from moto.ec2 import ec2_backends
return ec2_backends[self.region_name]
def add_applications(self, cluster_id, applications): def add_applications(self, cluster_id, applications):
cluster = self.get_cluster(cluster_id) cluster = self.get_cluster(cluster_id)
cluster.add_applications(applications) cluster.add_applications(applications)
@ -501,7 +514,51 @@ class ElasticMapReduceBackend(BaseBackend):
cluster = self.get_cluster(cluster_id) cluster = self.get_cluster(cluster_id)
cluster.remove_tags(tag_keys) cluster.remove_tags(tag_keys)
def _manage_security_groups(
self,
ec2_subnet_id,
emr_managed_master_security_group,
emr_managed_slave_security_group,
service_access_security_group,
**_
):
default_return_value = (
emr_managed_master_security_group,
emr_managed_slave_security_group,
service_access_security_group,
)
if not ec2_subnet_id:
# TODO: Set up Security Groups in Default VPC.
return default_return_value
from moto.ec2.exceptions import InvalidSubnetIdError
try:
subnet = self.ec2_backend.get_subnet(ec2_subnet_id)
except InvalidSubnetIdError:
warnings.warn(
"Could not find Subnet with id: {0}\n"
"In the near future, this will raise an error.\n"
"Use ec2.describe_subnets() to find a suitable id "
"for your test.".format(ec2_subnet_id),
PendingDeprecationWarning,
)
return default_return_value
manager = EmrSecurityGroupManager(self.ec2_backend, subnet.vpc_id)
master, slave, service = manager.manage_security_groups(
emr_managed_master_security_group,
emr_managed_slave_security_group,
service_access_security_group,
)
return master.id, slave.id, service.id
def run_job_flow(self, **kwargs): def run_job_flow(self, **kwargs):
(
kwargs["instance_attrs"]["emr_managed_master_security_group"],
kwargs["instance_attrs"]["emr_managed_slave_security_group"],
kwargs["instance_attrs"]["service_access_security_group"],
) = self._manage_security_groups(**kwargs["instance_attrs"])
return FakeCluster(self, **kwargs) return FakeCluster(self, **kwargs)
def set_visible_to_all_users(self, job_flow_ids, visible_to_all_users): def set_visible_to_all_users(self, job_flow_ids, visible_to_all_users):

View File

@ -1,8 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import copy
import datetime
import random import random
import re import re
import string import string
from moto.core.utils import camelcase_to_underscores from moto.core.utils import (
camelcase_to_underscores,
iso_8601_datetime_with_milliseconds,
)
import six import six
@ -218,3 +223,238 @@ class ReleaseLabel(object):
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return tuple(self) >= tuple(other) return tuple(self) >= tuple(other)
class EmrManagedSecurityGroup(object):
class Kind:
MASTER = "Master"
SLAVE = "Slave"
SERVICE = "Service"
kind = None
group_name = ""
short_name = ""
desc_fmt = "{short_name} for Elastic MapReduce created on {created}"
@classmethod
def description(cls):
created = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
return cls.desc_fmt.format(short_name=cls.short_name, created=created)
class EmrManagedMasterSecurityGroup(EmrManagedSecurityGroup):
kind = EmrManagedSecurityGroup.Kind.MASTER
group_name = "ElasticMapReduce-Master-Private"
short_name = "Master"
class EmrManagedSlaveSecurityGroup(EmrManagedSecurityGroup):
kind = EmrManagedSecurityGroup.Kind.SLAVE
group_name = "ElasticMapReduce-Slave-Private"
short_name = "Slave"
class EmrManagedServiceAccessSecurityGroup(EmrManagedSecurityGroup):
kind = EmrManagedSecurityGroup.Kind.SERVICE
group_name = "ElasticMapReduce-ServiceAccess"
short_name = "Service access"
class EmrSecurityGroupManager(object):
MANAGED_RULES_EGRESS = [
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": None,
"ip_protocol": "-1",
"ip_ranges": [{"CidrIp": "0.0.0.0/0"}],
"to_port": None,
"source_group_ids": [],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": None,
"ip_protocol": "-1",
"ip_ranges": [{"CidrIp": "0.0.0.0/0"}],
"to_port": None,
"source_group_ids": [],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SERVICE,
"from_port": 8443,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 8443,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
]
MANAGED_RULES_INGRESS = [
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": 0,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 65535,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": 8443,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 8443,
"source_group_ids": [EmrManagedSecurityGroup.Kind.SERVICE],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": 0,
"ip_protocol": "udp",
"ip_ranges": [],
"to_port": 65535,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": -1,
"ip_protocol": "icmp",
"ip_ranges": [],
"to_port": -1,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": 0,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 65535,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.SLAVE,
EmrManagedSecurityGroup.Kind.MASTER,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": 8443,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 8443,
"source_group_ids": [EmrManagedSecurityGroup.Kind.SERVICE],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": 0,
"ip_protocol": "udp",
"ip_ranges": [],
"to_port": 65535,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": -1,
"ip_protocol": "icmp",
"ip_ranges": [],
"to_port": -1,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SERVICE,
"from_port": 9443,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 9443,
"source_group_ids": [EmrManagedSecurityGroup.Kind.MASTER],
},
]
def __init__(self, ec2_backend, vpc_id):
self.ec2 = ec2_backend
self.vpc_id = vpc_id
def manage_security_groups(
self, master_security_group, slave_security_group, service_access_security_group
):
group_metadata = [
(
master_security_group,
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedMasterSecurityGroup,
),
(
slave_security_group,
EmrManagedSecurityGroup.Kind.SLAVE,
EmrManagedSlaveSecurityGroup,
),
(
service_access_security_group,
EmrManagedSecurityGroup.Kind.SERVICE,
EmrManagedServiceAccessSecurityGroup,
),
]
managed_groups = {}
for name, kind, defaults in group_metadata:
managed_groups[kind] = self._get_or_create_sg(name, defaults)
self._add_rules_to(managed_groups)
return (
managed_groups[EmrManagedSecurityGroup.Kind.MASTER],
managed_groups[EmrManagedSecurityGroup.Kind.SLAVE],
managed_groups[EmrManagedSecurityGroup.Kind.SERVICE],
)
def _get_or_create_sg(self, sg_id, defaults):
find_sg = self.ec2.get_security_group_by_name_or_id
create_sg = self.ec2.create_security_group
group_id_or_name = sg_id or defaults.group_name
group = find_sg(group_id_or_name, self.vpc_id)
if group is None:
if group_id_or_name != defaults.group_name:
raise ValueError(
"The security group '{}' does not exist".format(group_id_or_name)
)
group = create_sg(defaults.group_name, defaults.description(), self.vpc_id)
return group
def _add_rules_to(self, managed_groups):
rules_metadata = [
(self.MANAGED_RULES_EGRESS, self.ec2.authorize_security_group_egress),
(self.MANAGED_RULES_INGRESS, self.ec2.authorize_security_group_ingress),
]
for rules, add_rule in rules_metadata:
rendered_rules = self._render_rules(rules, managed_groups)
for rule in rendered_rules:
from moto.ec2.exceptions import InvalidPermissionDuplicateError
try:
add_rule(vpc_id=self.vpc_id, **rule)
except InvalidPermissionDuplicateError:
# If the rule already exists, we can just move on.
pass
@staticmethod
def _render_rules(rules, managed_groups):
rendered_rules = copy.deepcopy(rules)
for rule in rendered_rules:
rule["group_name_or_id"] = managed_groups[rule["group_name_or_id"]].id
rule["source_group_ids"] = [
managed_groups[group].id for group in rule["source_group_ids"]
]
return rendered_rules

View File

@ -0,0 +1,207 @@
from __future__ import unicode_literals
import boto3
import pytest
import sure # noqa
from moto import settings
from moto.ec2 import mock_ec2, ec2_backend
from moto.emr import mock_emr
from moto.emr.utils import EmrSecurityGroupManager
@mock_emr
@mock_ec2
def test_default_emr_security_groups_get_created_on_first_job_flow():
ec2 = boto3.resource("ec2", region_name="us-east-1")
ec2_client = boto3.client("ec2", region_name="us-east-1")
vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
subnet = ec2.create_subnet(
VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-east-1a"
)
def _get_default_security_groups():
group_resp = ec2_client.describe_security_groups(
Filters=[
{"Name": "vpc-id", "Values": [vpc.id]},
{
"Name": "group-name",
"Values": [
"ElasticMapReduce-Master-Private",
"ElasticMapReduce-Slave-Private",
"ElasticMapReduce-ServiceAccess",
],
},
]
)
return group_resp.get("SecurityGroups", [])
assert len(_get_default_security_groups()) == 0
client = boto3.client("emr", region_name="us-east-1")
run_job_flow_params = dict(
ReleaseLabel="emr-5.29.0",
Instances={
"KeepJobFlowAliveWhenNoSteps": True,
"Ec2SubnetId": subnet.id,
"InstanceGroups": [
{
"Name": "Master",
"Market": "ON_DEMAND",
"InstanceRole": "MASTER",
"InstanceType": "m5.xlarge",
"InstanceCount": 3,
},
{
"Name": "Core",
"Market": "ON_DEMAND",
"InstanceRole": "CORE",
"InstanceType": "m5.xlarge",
"InstanceCount": 2,
},
],
},
JobFlowRole="EMR_EC2_DefaultRole",
Name="test-emr-cluster-security-groups",
ServiceRole="EMR_DefaultRole",
VisibleToAllUsers=True,
)
cluster_id = client.run_job_flow(**run_job_flow_params)["JobFlowId"]
# Default security groups should have been created.
default_security_groups = _get_default_security_groups()
default_security_group_ids = [sg["GroupId"] for sg in default_security_groups]
assert len(default_security_group_ids) == 3
resp = client.describe_cluster(ClusterId=cluster_id)
ec2_attrs = resp["Cluster"]["Ec2InstanceAttributes"]
assert ec2_attrs["Ec2SubnetId"] == subnet.id
cluster_security_group_ids = [
ec2_attrs["EmrManagedMasterSecurityGroup"],
ec2_attrs["EmrManagedSlaveSecurityGroup"],
ec2_attrs["ServiceAccessSecurityGroup"],
]
assert set(cluster_security_group_ids) == set(default_security_group_ids)
@pytest.mark.skipif(
settings.TEST_SERVER_MODE, reason="Can't modify backend directly in server mode."
)
class TestEmrSecurityGroupManager(object):
mocks = []
def setup(self):
self.mocks = [mock_ec2()]
for mock in self.mocks:
mock.start()
ec2_client = boto3.client("ec2", region_name="us-east-1")
ec2 = boto3.resource("ec2", region_name="us-east-1")
vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
self.vpc_id = vpc.id
self.ec2 = ec2
self.ec2_client = ec2_client
def teardown(self):
for mock in self.mocks:
mock.stop()
def _create_default_client_supplied_security_groups(self):
master = self.ec2.create_security_group(
GroupName="master", Description="master", VpcId=self.vpc_id
)
slave = self.ec2.create_security_group(
GroupName="slave", Description="slave", VpcId=self.vpc_id
)
service = self.ec2.create_security_group(
GroupName="service", Description="service", VpcId=self.vpc_id
)
return master, slave, service
def _describe_security_groups(self, group_names):
resp = self.ec2_client.describe_security_groups(
Filters=[
{"Name": "vpc-id", "Values": [self.vpc_id]},
{"Name": "group-name", "Values": group_names},
]
)
return resp.get("SecurityGroups", [])
def _default_emr_security_groups(self):
group_names = [
"ElasticMapReduce-Master-Private",
"ElasticMapReduce-Slave-Private",
"ElasticMapReduce-ServiceAccess",
]
return self._describe_security_groups(group_names)
def test_emr_security_groups_get_created_if_non_existent(self):
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
assert len(self._default_emr_security_groups()) == 0
manager.manage_security_groups(None, None, None)
assert len(self._default_emr_security_groups()) == 3
def test_emr_security_groups_do_not_get_created_if_already_exist(self):
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
assert len(self._default_emr_security_groups()) == 0
manager.manage_security_groups(None, None, None)
emr_security_groups = self._default_emr_security_groups()
assert len(emr_security_groups) == 3
# Run again. Group count should still be 3.
emr_sg_ids_expected = [sg["GroupId"] for sg in emr_security_groups]
manager.manage_security_groups(None, None, None)
emr_security_groups = self._default_emr_security_groups()
assert len(emr_security_groups) == 3
emr_sg_ids_actual = [sg["GroupId"] for sg in emr_security_groups]
assert emr_sg_ids_actual == emr_sg_ids_expected
def test_emr_security_groups_do_not_get_created_if_client_supplied(self):
(
client_master,
client_slave,
client_service,
) = self._create_default_client_supplied_security_groups()
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
manager.manage_security_groups(
client_master.id, client_slave.id, client_service.id
)
client_group_names = [
client_master.group_name,
client_slave.group_name,
client_service.group_name,
]
assert len(self._describe_security_groups(client_group_names)) == 3
assert len(self._default_emr_security_groups()) == 0
def test_client_supplied_invalid_security_group_identifier_raises_error(self):
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
args_bad = [
("sg-invalid", None, None),
(None, "sg-invalid", None),
(None, None, "sg-invalid"),
]
for args in args_bad:
with pytest.raises(ValueError) as exc:
manager.manage_security_groups(*args)
assert str(exc.value) == "The security group 'sg-invalid' does not exist"
def test_client_supplied_security_groups_have_rules_added(self):
(
client_master,
client_slave,
client_service,
) = self._create_default_client_supplied_security_groups()
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
manager.manage_security_groups(
client_master.id, client_slave.id, client_service.id
)
client_group_names = [
client_master.group_name,
client_slave.group_name,
client_service.group_name,
]
security_groups = self._describe_security_groups(client_group_names)
for security_group in security_groups:
assert len(security_group["IpPermissions"]) > 0
assert len(security_group["IpPermissionsEgress"]) > 0