Add support for EMR-Managed Security Groups (#3824)

* Add support for EMR-Managed Security Groups

This covers the base case for EMR Clusters provisioned in a private subnet.

Ref: https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-man-sec-groups.html

* Address PR comments

* Address PR comments
This commit is contained in:
Brian Pandola 2021-04-02 07:34:02 -07:00 committed by GitHub
parent e90858b2e8
commit ac4a26f289
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 505 additions and 1 deletions

View File

@ -2,6 +2,8 @@ from __future__ import unicode_literals
from datetime import datetime
from datetime import timedelta
import warnings
import pytz
from boto3 import Session
from dateutil.parser import parse as dtparse
@ -12,6 +14,7 @@ from .utils import (
random_cluster_id,
random_step_id,
CamelToUnderscoresWalker,
EmrSecurityGroupManager,
)
@ -363,6 +366,16 @@ class ElasticMapReduceBackend(BaseBackend):
self.__dict__ = {}
self.__init__(region_name)
@property
def ec2_backend(self):
"""
:return: EC2 Backend
:rtype: moto.ec2.models.EC2Backend
"""
from moto.ec2 import ec2_backends
return ec2_backends[self.region_name]
def add_applications(self, cluster_id, applications):
cluster = self.get_cluster(cluster_id)
cluster.add_applications(applications)
@ -501,7 +514,51 @@ class ElasticMapReduceBackend(BaseBackend):
cluster = self.get_cluster(cluster_id)
cluster.remove_tags(tag_keys)
def _manage_security_groups(
self,
ec2_subnet_id,
emr_managed_master_security_group,
emr_managed_slave_security_group,
service_access_security_group,
**_
):
default_return_value = (
emr_managed_master_security_group,
emr_managed_slave_security_group,
service_access_security_group,
)
if not ec2_subnet_id:
# TODO: Set up Security Groups in Default VPC.
return default_return_value
from moto.ec2.exceptions import InvalidSubnetIdError
try:
subnet = self.ec2_backend.get_subnet(ec2_subnet_id)
except InvalidSubnetIdError:
warnings.warn(
"Could not find Subnet with id: {0}\n"
"In the near future, this will raise an error.\n"
"Use ec2.describe_subnets() to find a suitable id "
"for your test.".format(ec2_subnet_id),
PendingDeprecationWarning,
)
return default_return_value
manager = EmrSecurityGroupManager(self.ec2_backend, subnet.vpc_id)
master, slave, service = manager.manage_security_groups(
emr_managed_master_security_group,
emr_managed_slave_security_group,
service_access_security_group,
)
return master.id, slave.id, service.id
def run_job_flow(self, **kwargs):
(
kwargs["instance_attrs"]["emr_managed_master_security_group"],
kwargs["instance_attrs"]["emr_managed_slave_security_group"],
kwargs["instance_attrs"]["service_access_security_group"],
) = self._manage_security_groups(**kwargs["instance_attrs"])
return FakeCluster(self, **kwargs)
def set_visible_to_all_users(self, job_flow_ids, visible_to_all_users):

View File

@ -1,8 +1,13 @@
from __future__ import unicode_literals
import copy
import datetime
import random
import re
import string
from moto.core.utils import camelcase_to_underscores
from moto.core.utils import (
camelcase_to_underscores,
iso_8601_datetime_with_milliseconds,
)
import six
@ -218,3 +223,238 @@ class ReleaseLabel(object):
if not isinstance(other, self.__class__):
return NotImplemented
return tuple(self) >= tuple(other)
class EmrManagedSecurityGroup(object):
class Kind:
MASTER = "Master"
SLAVE = "Slave"
SERVICE = "Service"
kind = None
group_name = ""
short_name = ""
desc_fmt = "{short_name} for Elastic MapReduce created on {created}"
@classmethod
def description(cls):
created = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
return cls.desc_fmt.format(short_name=cls.short_name, created=created)
class EmrManagedMasterSecurityGroup(EmrManagedSecurityGroup):
kind = EmrManagedSecurityGroup.Kind.MASTER
group_name = "ElasticMapReduce-Master-Private"
short_name = "Master"
class EmrManagedSlaveSecurityGroup(EmrManagedSecurityGroup):
kind = EmrManagedSecurityGroup.Kind.SLAVE
group_name = "ElasticMapReduce-Slave-Private"
short_name = "Slave"
class EmrManagedServiceAccessSecurityGroup(EmrManagedSecurityGroup):
kind = EmrManagedSecurityGroup.Kind.SERVICE
group_name = "ElasticMapReduce-ServiceAccess"
short_name = "Service access"
class EmrSecurityGroupManager(object):
MANAGED_RULES_EGRESS = [
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": None,
"ip_protocol": "-1",
"ip_ranges": [{"CidrIp": "0.0.0.0/0"}],
"to_port": None,
"source_group_ids": [],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": None,
"ip_protocol": "-1",
"ip_ranges": [{"CidrIp": "0.0.0.0/0"}],
"to_port": None,
"source_group_ids": [],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SERVICE,
"from_port": 8443,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 8443,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
]
MANAGED_RULES_INGRESS = [
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": 0,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 65535,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": 8443,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 8443,
"source_group_ids": [EmrManagedSecurityGroup.Kind.SERVICE],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": 0,
"ip_protocol": "udp",
"ip_ranges": [],
"to_port": 65535,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.MASTER,
"from_port": -1,
"ip_protocol": "icmp",
"ip_ranges": [],
"to_port": -1,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": 0,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 65535,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.SLAVE,
EmrManagedSecurityGroup.Kind.MASTER,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": 8443,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 8443,
"source_group_ids": [EmrManagedSecurityGroup.Kind.SERVICE],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": 0,
"ip_protocol": "udp",
"ip_ranges": [],
"to_port": 65535,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SLAVE,
"from_port": -1,
"ip_protocol": "icmp",
"ip_ranges": [],
"to_port": -1,
"source_group_ids": [
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedSecurityGroup.Kind.SLAVE,
],
},
{
"group_name_or_id": EmrManagedSecurityGroup.Kind.SERVICE,
"from_port": 9443,
"ip_protocol": "tcp",
"ip_ranges": [],
"to_port": 9443,
"source_group_ids": [EmrManagedSecurityGroup.Kind.MASTER],
},
]
def __init__(self, ec2_backend, vpc_id):
self.ec2 = ec2_backend
self.vpc_id = vpc_id
def manage_security_groups(
self, master_security_group, slave_security_group, service_access_security_group
):
group_metadata = [
(
master_security_group,
EmrManagedSecurityGroup.Kind.MASTER,
EmrManagedMasterSecurityGroup,
),
(
slave_security_group,
EmrManagedSecurityGroup.Kind.SLAVE,
EmrManagedSlaveSecurityGroup,
),
(
service_access_security_group,
EmrManagedSecurityGroup.Kind.SERVICE,
EmrManagedServiceAccessSecurityGroup,
),
]
managed_groups = {}
for name, kind, defaults in group_metadata:
managed_groups[kind] = self._get_or_create_sg(name, defaults)
self._add_rules_to(managed_groups)
return (
managed_groups[EmrManagedSecurityGroup.Kind.MASTER],
managed_groups[EmrManagedSecurityGroup.Kind.SLAVE],
managed_groups[EmrManagedSecurityGroup.Kind.SERVICE],
)
def _get_or_create_sg(self, sg_id, defaults):
find_sg = self.ec2.get_security_group_by_name_or_id
create_sg = self.ec2.create_security_group
group_id_or_name = sg_id or defaults.group_name
group = find_sg(group_id_or_name, self.vpc_id)
if group is None:
if group_id_or_name != defaults.group_name:
raise ValueError(
"The security group '{}' does not exist".format(group_id_or_name)
)
group = create_sg(defaults.group_name, defaults.description(), self.vpc_id)
return group
def _add_rules_to(self, managed_groups):
rules_metadata = [
(self.MANAGED_RULES_EGRESS, self.ec2.authorize_security_group_egress),
(self.MANAGED_RULES_INGRESS, self.ec2.authorize_security_group_ingress),
]
for rules, add_rule in rules_metadata:
rendered_rules = self._render_rules(rules, managed_groups)
for rule in rendered_rules:
from moto.ec2.exceptions import InvalidPermissionDuplicateError
try:
add_rule(vpc_id=self.vpc_id, **rule)
except InvalidPermissionDuplicateError:
# If the rule already exists, we can just move on.
pass
@staticmethod
def _render_rules(rules, managed_groups):
rendered_rules = copy.deepcopy(rules)
for rule in rendered_rules:
rule["group_name_or_id"] = managed_groups[rule["group_name_or_id"]].id
rule["source_group_ids"] = [
managed_groups[group].id for group in rule["source_group_ids"]
]
return rendered_rules

View File

@ -0,0 +1,207 @@
from __future__ import unicode_literals
import boto3
import pytest
import sure # noqa
from moto import settings
from moto.ec2 import mock_ec2, ec2_backend
from moto.emr import mock_emr
from moto.emr.utils import EmrSecurityGroupManager
@mock_emr
@mock_ec2
def test_default_emr_security_groups_get_created_on_first_job_flow():
ec2 = boto3.resource("ec2", region_name="us-east-1")
ec2_client = boto3.client("ec2", region_name="us-east-1")
vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
subnet = ec2.create_subnet(
VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-east-1a"
)
def _get_default_security_groups():
group_resp = ec2_client.describe_security_groups(
Filters=[
{"Name": "vpc-id", "Values": [vpc.id]},
{
"Name": "group-name",
"Values": [
"ElasticMapReduce-Master-Private",
"ElasticMapReduce-Slave-Private",
"ElasticMapReduce-ServiceAccess",
],
},
]
)
return group_resp.get("SecurityGroups", [])
assert len(_get_default_security_groups()) == 0
client = boto3.client("emr", region_name="us-east-1")
run_job_flow_params = dict(
ReleaseLabel="emr-5.29.0",
Instances={
"KeepJobFlowAliveWhenNoSteps": True,
"Ec2SubnetId": subnet.id,
"InstanceGroups": [
{
"Name": "Master",
"Market": "ON_DEMAND",
"InstanceRole": "MASTER",
"InstanceType": "m5.xlarge",
"InstanceCount": 3,
},
{
"Name": "Core",
"Market": "ON_DEMAND",
"InstanceRole": "CORE",
"InstanceType": "m5.xlarge",
"InstanceCount": 2,
},
],
},
JobFlowRole="EMR_EC2_DefaultRole",
Name="test-emr-cluster-security-groups",
ServiceRole="EMR_DefaultRole",
VisibleToAllUsers=True,
)
cluster_id = client.run_job_flow(**run_job_flow_params)["JobFlowId"]
# Default security groups should have been created.
default_security_groups = _get_default_security_groups()
default_security_group_ids = [sg["GroupId"] for sg in default_security_groups]
assert len(default_security_group_ids) == 3
resp = client.describe_cluster(ClusterId=cluster_id)
ec2_attrs = resp["Cluster"]["Ec2InstanceAttributes"]
assert ec2_attrs["Ec2SubnetId"] == subnet.id
cluster_security_group_ids = [
ec2_attrs["EmrManagedMasterSecurityGroup"],
ec2_attrs["EmrManagedSlaveSecurityGroup"],
ec2_attrs["ServiceAccessSecurityGroup"],
]
assert set(cluster_security_group_ids) == set(default_security_group_ids)
@pytest.mark.skipif(
settings.TEST_SERVER_MODE, reason="Can't modify backend directly in server mode."
)
class TestEmrSecurityGroupManager(object):
mocks = []
def setup(self):
self.mocks = [mock_ec2()]
for mock in self.mocks:
mock.start()
ec2_client = boto3.client("ec2", region_name="us-east-1")
ec2 = boto3.resource("ec2", region_name="us-east-1")
vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
self.vpc_id = vpc.id
self.ec2 = ec2
self.ec2_client = ec2_client
def teardown(self):
for mock in self.mocks:
mock.stop()
def _create_default_client_supplied_security_groups(self):
master = self.ec2.create_security_group(
GroupName="master", Description="master", VpcId=self.vpc_id
)
slave = self.ec2.create_security_group(
GroupName="slave", Description="slave", VpcId=self.vpc_id
)
service = self.ec2.create_security_group(
GroupName="service", Description="service", VpcId=self.vpc_id
)
return master, slave, service
def _describe_security_groups(self, group_names):
resp = self.ec2_client.describe_security_groups(
Filters=[
{"Name": "vpc-id", "Values": [self.vpc_id]},
{"Name": "group-name", "Values": group_names},
]
)
return resp.get("SecurityGroups", [])
def _default_emr_security_groups(self):
group_names = [
"ElasticMapReduce-Master-Private",
"ElasticMapReduce-Slave-Private",
"ElasticMapReduce-ServiceAccess",
]
return self._describe_security_groups(group_names)
def test_emr_security_groups_get_created_if_non_existent(self):
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
assert len(self._default_emr_security_groups()) == 0
manager.manage_security_groups(None, None, None)
assert len(self._default_emr_security_groups()) == 3
def test_emr_security_groups_do_not_get_created_if_already_exist(self):
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
assert len(self._default_emr_security_groups()) == 0
manager.manage_security_groups(None, None, None)
emr_security_groups = self._default_emr_security_groups()
assert len(emr_security_groups) == 3
# Run again. Group count should still be 3.
emr_sg_ids_expected = [sg["GroupId"] for sg in emr_security_groups]
manager.manage_security_groups(None, None, None)
emr_security_groups = self._default_emr_security_groups()
assert len(emr_security_groups) == 3
emr_sg_ids_actual = [sg["GroupId"] for sg in emr_security_groups]
assert emr_sg_ids_actual == emr_sg_ids_expected
def test_emr_security_groups_do_not_get_created_if_client_supplied(self):
(
client_master,
client_slave,
client_service,
) = self._create_default_client_supplied_security_groups()
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
manager.manage_security_groups(
client_master.id, client_slave.id, client_service.id
)
client_group_names = [
client_master.group_name,
client_slave.group_name,
client_service.group_name,
]
assert len(self._describe_security_groups(client_group_names)) == 3
assert len(self._default_emr_security_groups()) == 0
def test_client_supplied_invalid_security_group_identifier_raises_error(self):
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
args_bad = [
("sg-invalid", None, None),
(None, "sg-invalid", None),
(None, None, "sg-invalid"),
]
for args in args_bad:
with pytest.raises(ValueError) as exc:
manager.manage_security_groups(*args)
assert str(exc.value) == "The security group 'sg-invalid' does not exist"
def test_client_supplied_security_groups_have_rules_added(self):
(
client_master,
client_slave,
client_service,
) = self._create_default_client_supplied_security_groups()
manager = EmrSecurityGroupManager(ec2_backend, self.vpc_id)
manager.manage_security_groups(
client_master.id, client_slave.id, client_service.id
)
client_group_names = [
client_master.group_name,
client_slave.group_name,
client_service.group_name,
]
security_groups = self._describe_security_groups(client_group_names)
for security_group in security_groups:
assert len(security_group["IpPermissions"]) > 0
assert len(security_group["IpPermissionsEgress"]) > 0