Techdebt: MyPy EMR (#6005)

This commit is contained in:
Bert Blommers 2023-03-03 18:42:11 -01:00 committed by GitHub
parent 0e24b281eb
commit a1a43e3f74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 293 additions and 238 deletions

View File

@ -218,7 +218,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
)
aws_service_spec = None
aws_service_spec: Optional["AWSServiceSpec"] = None
def __init__(self, service_name: Optional[str] = None):
super().__init__()

View File

@ -2,15 +2,15 @@ from moto.core.exceptions import JsonRESTError
class InvalidRequestException(JsonRESTError):
def __init__(self, message, **kwargs):
super().__init__("InvalidRequestException", message, **kwargs)
def __init__(self, message: str):
super().__init__("InvalidRequestException", message)
class ValidationException(JsonRESTError):
def __init__(self, message, **kwargs):
super().__init__("ValidationException", message, **kwargs)
def __init__(self, message: str):
super().__init__("ValidationException", message)
class ResourceNotFoundException(JsonRESTError):
def __init__(self, message, **kwargs):
super().__init__("ResourceNotFoundException", message, **kwargs)
def __init__(self, message: str):
super().__init__("ResourceNotFoundException", message)

View File

@ -1,5 +1,5 @@
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple
import warnings
from dateutil.parser import parse as dtparse
@ -21,7 +21,9 @@ EXAMPLE_AMI_ID = "ami-12c6146b"
class FakeApplication(BaseModel):
def __init__(self, name, version, args=None, additional_info=None):
def __init__(
self, name: str, version: str, args: List[str], additional_info: Dict[str, str]
):
self.additional_info = additional_info or {}
self.args = args or []
self.name = name
@ -29,7 +31,7 @@ class FakeApplication(BaseModel):
class FakeBootstrapAction(BaseModel):
def __init__(self, args, name, script_path):
def __init__(self, args: List[str], name: str, script_path: str):
self.args = args or []
self.name = name
self.script_path = script_path
@ -37,7 +39,11 @@ class FakeBootstrapAction(BaseModel):
class FakeInstance(BaseModel):
def __init__(
self, ec2_instance_id, instance_group, instance_fleet_id=None, instance_id=None
self,
ec2_instance_id: str,
instance_group: "FakeInstanceGroup",
instance_fleet_id: Optional[str] = None,
instance_id: Optional[str] = None,
):
self.id = instance_id or random_instance_group_id()
self.ec2_instance_id = ec2_instance_id
@ -48,16 +54,16 @@ class FakeInstance(BaseModel):
class FakeInstanceGroup(BaseModel):
def __init__(
self,
cluster_id,
instance_count,
instance_role,
instance_type,
market="ON_DEMAND",
name=None,
instance_group_id=None,
bid_price=None,
ebs_configuration=None,
auto_scaling_policy=None,
cluster_id: str,
instance_count: int,
instance_role: str,
instance_type: str,
market: str = "ON_DEMAND",
name: Optional[str] = None,
instance_group_id: Optional[str] = None,
bid_price: Optional[str] = None,
ebs_configuration: Optional[Dict[str, Any]] = None,
auto_scaling_policy: Optional[Dict[str, Any]] = None,
):
self.id = instance_group_id or random_instance_group_id()
self.cluster_id = cluster_id
@ -83,15 +89,15 @@ class FakeInstanceGroup(BaseModel):
self.end_datetime = None
self.state = "RUNNING"
def set_instance_count(self, instance_count):
def set_instance_count(self, instance_count: int) -> None:
self.num_instances = instance_count
@property
def auto_scaling_policy(self):
def auto_scaling_policy(self) -> Any: # type: ignore[misc]
return self._auto_scaling_policy
@auto_scaling_policy.setter
def auto_scaling_policy(self, value):
def auto_scaling_policy(self, value: Any) -> None:
if value is None:
self._auto_scaling_policy = value
return
@ -118,12 +124,12 @@ class FakeInstanceGroup(BaseModel):
class FakeStep(BaseModel):
def __init__(
self,
state,
name="",
jar="",
args=None,
properties=None,
action_on_failure="TERMINATE_CLUSTER",
state: str,
name: str = "",
jar: str = "",
args: Optional[List[str]] = None,
properties: Optional[Dict[str, str]] = None,
action_on_failure: str = "TERMINATE_CLUSTER",
):
self.id = random_step_id()
@ -136,63 +142,63 @@ class FakeStep(BaseModel):
self.creation_datetime = datetime.now(timezone.utc)
self.end_datetime = None
self.ready_datetime = None
self.start_datetime = None
self.start_datetime: Optional[datetime] = None
self.state = state
def start(self):
def start(self) -> None:
self.start_datetime = datetime.now(timezone.utc)
class FakeCluster(BaseModel):
def __init__(
self,
emr_backend,
name,
log_uri,
job_flow_role,
service_role,
steps,
instance_attrs,
bootstrap_actions=None,
configurations=None,
cluster_id=None,
visible_to_all_users="false",
release_label=None,
requested_ami_version=None,
running_ami_version=None,
custom_ami_id=None,
step_concurrency_level=1,
security_configuration=None,
kerberos_attributes=None,
auto_scaling_role=None,
emr_backend: "ElasticMapReduceBackend",
name: str,
log_uri: str,
job_flow_role: str,
service_role: str,
steps: List[Dict[str, Any]],
instance_attrs: Dict[str, Any],
bootstrap_actions: Optional[List[Dict[str, Any]]] = None,
configurations: Optional[List[Dict[str, Any]]] = None,
cluster_id: Optional[str] = None,
visible_to_all_users: str = "false",
release_label: Optional[str] = None,
requested_ami_version: Optional[str] = None,
running_ami_version: Optional[str] = None,
custom_ami_id: Optional[str] = None,
step_concurrency_level: int = 1,
security_configuration: Optional[str] = None,
kerberos_attributes: Optional[Dict[str, str]] = None,
auto_scaling_role: Optional[str] = None,
):
self.id = cluster_id or random_cluster_id()
emr_backend.clusters[self.id] = self
self.emr_backend = emr_backend
self.applications = []
self.applications: List[FakeApplication] = []
self.bootstrap_actions = []
self.bootstrap_actions: List[FakeBootstrapAction] = []
for bootstrap_action in bootstrap_actions or []:
self.add_bootstrap_action(bootstrap_action)
self.configurations = configurations or []
self.tags = {}
self.tags: Dict[str, str] = {}
self.log_uri = log_uri
self.name = name
self.normalized_instance_hours = 0
self.steps = []
self.steps: List[FakeStep] = []
self.add_steps(steps)
self.set_visibility(visible_to_all_users)
self.instance_group_ids = []
self.instances = []
self.master_instance_group_id = None
self.core_instance_group_id = None
self.instance_group_ids: List[str] = []
self.instances: List[FakeInstance] = []
self.master_instance_group_id: Optional[str] = None
self.core_instance_group_id: Optional[str] = None
if (
"master_instance_type" in instance_attrs
and instance_attrs["master_instance_type"]
@ -259,10 +265,10 @@ class FakeCluster(BaseModel):
self.step_concurrency_level = step_concurrency_level
self.creation_datetime = datetime.now(timezone.utc)
self.start_datetime = None
self.ready_datetime = None
self.end_datetime = None
self.state = None
self.start_datetime: Optional[datetime] = None
self.ready_datetime: Optional[datetime] = None
self.end_datetime: Optional[datetime] = None
self.state: Optional[str] = None
self.start_cluster()
self.run_bootstrap_actions()
@ -275,30 +281,30 @@ class FakeCluster(BaseModel):
self.auto_scaling_role = auto_scaling_role
@property
def arn(self):
def arn(self) -> str:
return f"arn:aws:elasticmapreduce:{self.emr_backend.region_name}:{self.emr_backend.account_id}:cluster/{self.id}"
@property
def instance_groups(self):
def instance_groups(self) -> List[FakeInstanceGroup]:
return self.emr_backend.get_instance_groups(self.instance_group_ids)
@property
def master_instance_type(self):
return self.emr_backend.instance_groups[self.master_instance_group_id].type
def master_instance_type(self) -> str:
return self.emr_backend.instance_groups[self.master_instance_group_id].type # type: ignore
@property
def slave_instance_type(self):
return self.emr_backend.instance_groups[self.core_instance_group_id].type
def slave_instance_type(self) -> str:
return self.emr_backend.instance_groups[self.core_instance_group_id].type # type: ignore
@property
def instance_count(self):
def instance_count(self) -> int:
return sum(group.num_instances for group in self.instance_groups)
def start_cluster(self):
def start_cluster(self) -> None:
self.state = "STARTING"
self.start_datetime = datetime.now(timezone.utc)
def run_bootstrap_actions(self):
def run_bootstrap_actions(self) -> None:
self.state = "BOOTSTRAPPING"
self.ready_datetime = datetime.now(timezone.utc)
self.state = "WAITING"
@ -306,28 +312,28 @@ class FakeCluster(BaseModel):
if not self.keep_job_flow_alive_when_no_steps:
self.terminate()
def terminate(self):
def terminate(self) -> None:
self.state = "TERMINATING"
self.end_datetime = datetime.now(timezone.utc)
self.state = "TERMINATED"
def add_applications(self, applications):
def add_applications(self, applications: List[Dict[str, Any]]) -> None:
self.applications.extend(
[
FakeApplication(
name=app.get("name", ""),
version=app.get("version", ""),
args=app.get("args", []),
additional_info=app.get("additiona_info", {}),
additional_info=app.get("additional_info", {}),
)
for app in applications
]
)
def add_bootstrap_action(self, bootstrap_action):
def add_bootstrap_action(self, bootstrap_action: Dict[str, Any]) -> None:
self.bootstrap_actions.append(FakeBootstrapAction(**bootstrap_action))
def add_instance_group(self, instance_group):
def add_instance_group(self, instance_group: FakeInstanceGroup) -> None:
if instance_group.role == "MASTER":
if self.master_instance_group_id:
raise Exception("Cannot add another master instance group")
@ -347,10 +353,10 @@ class FakeCluster(BaseModel):
self.core_instance_group_id = instance_group.id
self.instance_group_ids.append(instance_group.id)
def add_instance(self, instance):
def add_instance(self, instance: FakeInstance) -> None:
self.instances.append(instance)
def add_steps(self, steps):
def add_steps(self, steps: List[Dict[str, Any]]) -> List[FakeStep]:
added_steps = []
for step in steps:
if self.steps:
@ -363,43 +369,45 @@ class FakeCluster(BaseModel):
self.state = "RUNNING"
return added_steps
def add_tags(self, tags):
def add_tags(self, tags: Dict[str, str]) -> None:
self.tags.update(tags)
def remove_tags(self, tag_keys):
def remove_tags(self, tag_keys: List[str]) -> None:
for key in tag_keys:
self.tags.pop(key, None)
def set_termination_protection(self, value):
def set_termination_protection(self, value: bool) -> None:
self.termination_protected = value
def set_visibility(self, visibility):
def set_visibility(self, visibility: str) -> None:
self.visible_to_all_users = visibility
class FakeSecurityConfiguration(BaseModel):
def __init__(self, name, security_configuration):
def __init__(self, name: str, security_configuration: str):
self.name = name
self.security_configuration = security_configuration
self.creation_date_time = datetime.now(timezone.utc)
class ElasticMapReduceBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.clusters = {}
self.instance_groups = {}
self.security_configurations = {}
self.clusters: Dict[str, FakeCluster] = {}
self.instance_groups: Dict[str, FakeInstanceGroup] = {}
self.security_configurations: Dict[str, FakeSecurityConfiguration] = {}
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "elasticmapreduce"
)
@property
def ec2_backend(self):
def ec2_backend(self) -> Any: # type: ignore[misc]
"""
:return: EC2 Backend
:rtype: moto.ec2.models.EC2Backend
@ -408,11 +416,15 @@ class ElasticMapReduceBackend(BaseBackend):
return ec2_backends[self.account_id][self.region_name]
def add_applications(self, cluster_id, applications):
def add_applications(
self, cluster_id: str, applications: List[Dict[str, Any]]
) -> None:
cluster = self.describe_cluster(cluster_id)
cluster.add_applications(applications)
def add_instance_groups(self, cluster_id, instance_groups):
def add_instance_groups(
self, cluster_id: str, instance_groups: List[Dict[str, Any]]
) -> List[FakeInstanceGroup]:
cluster = self.clusters[cluster_id]
result_groups = []
for instance_group in instance_groups:
@ -422,7 +434,12 @@ class ElasticMapReduceBackend(BaseBackend):
result_groups.append(group)
return result_groups
def add_instances(self, cluster_id, instances, instance_group):
def add_instances(
self,
cluster_id: str,
instances: Dict[str, Any],
instance_group: FakeInstanceGroup,
) -> None:
cluster = self.clusters[cluster_id]
instances["is_instance_type_default"] = not instances.get("instance_type")
response = self.ec2_backend.add_instances(
@ -434,23 +451,24 @@ class ElasticMapReduceBackend(BaseBackend):
)
cluster.add_instance(instance)
def add_job_flow_steps(self, job_flow_id, steps):
def add_job_flow_steps(
self, job_flow_id: str, steps: List[Dict[str, Any]]
) -> List[FakeStep]:
cluster = self.clusters[job_flow_id]
steps = cluster.add_steps(steps)
return steps
return cluster.add_steps(steps)
def add_tags(self, cluster_id, tags):
def add_tags(self, cluster_id: str, tags: Dict[str, str]) -> None:
cluster = self.describe_cluster(cluster_id)
cluster.add_tags(tags)
def describe_job_flows(
self,
job_flow_ids=None,
job_flow_states=None,
created_after=None,
created_before=None,
):
clusters = self.clusters.values()
job_flow_ids: Optional[List[str]] = None,
job_flow_states: Optional[List[str]] = None,
created_after: Optional[str] = None,
created_before: Optional[str] = None,
) -> List[FakeCluster]:
clusters = list(self.clusters.values())
within_two_month = datetime.now(timezone.utc) - timedelta(days=60)
clusters = [c for c in clusters if c.creation_datetime >= within_two_month]
@ -460,34 +478,41 @@ class ElasticMapReduceBackend(BaseBackend):
if job_flow_states:
clusters = [c for c in clusters if c.state in job_flow_states]
if created_after:
created_after = dtparse(created_after)
clusters = [c for c in clusters if c.creation_datetime > created_after]
clusters = [
c for c in clusters if c.creation_datetime > dtparse(created_after)
]
if created_before:
created_before = dtparse(created_before)
clusters = [c for c in clusters if c.creation_datetime < created_before]
clusters = [
c for c in clusters if c.creation_datetime < dtparse(created_before)
]
# Amazon EMR can return a maximum of 512 job flow descriptions
return sorted(clusters, key=lambda x: x.id)[:512]
def describe_step(self, cluster_id, step_id):
def describe_step(self, cluster_id: str, step_id: str) -> Optional[FakeStep]:
cluster = self.clusters[cluster_id]
for step in cluster.steps:
if step.id == step_id:
return step
return None
def describe_cluster(self, cluster_id):
def describe_cluster(self, cluster_id: str) -> FakeCluster:
if cluster_id in self.clusters:
return self.clusters[cluster_id]
raise ResourceNotFoundException("")
def get_instance_groups(self, instance_group_ids):
def get_instance_groups(
self, instance_group_ids: List[str]
) -> List[FakeInstanceGroup]:
return [
group
for group_id, group in self.instance_groups.items()
if group_id in instance_group_ids
]
def list_bootstrap_actions(self, cluster_id, marker=None):
def list_bootstrap_actions(
self, cluster_id: str, marker: Optional[str] = None
) -> Tuple[List[FakeBootstrapAction], Optional[str]]:
max_items = 50
actions = self.clusters[cluster_id].bootstrap_actions
start_idx = 0 if marker is None else int(marker)
@ -499,18 +524,24 @@ class ElasticMapReduceBackend(BaseBackend):
return actions[start_idx : start_idx + max_items], marker
def list_clusters(
self, cluster_states=None, created_after=None, created_before=None, marker=None
):
self,
cluster_states: Optional[List[str]] = None,
created_after: Optional[str] = None,
created_before: Optional[str] = None,
marker: Optional[str] = None,
) -> Tuple[List[FakeCluster], Optional[str]]:
max_items = 50
clusters = self.clusters.values()
clusters = list(self.clusters.values())
if cluster_states:
clusters = [c for c in clusters if c.state in cluster_states]
if created_after:
created_after = dtparse(created_after)
clusters = [c for c in clusters if c.creation_datetime > created_after]
clusters = [
c for c in clusters if c.creation_datetime > dtparse(created_after)
]
if created_before:
created_before = dtparse(created_before)
clusters = [c for c in clusters if c.creation_datetime < created_before]
clusters = [
c for c in clusters if c.creation_datetime < dtparse(created_before)
]
clusters = sorted(clusters, key=lambda x: x.id)
start_idx = 0 if marker is None else int(marker)
marker = (
@ -520,7 +551,9 @@ class ElasticMapReduceBackend(BaseBackend):
)
return clusters[start_idx : start_idx + max_items], marker
def list_instance_groups(self, cluster_id, marker=None):
def list_instance_groups(
self, cluster_id: str, marker: Optional[str] = None
) -> Tuple[List[FakeInstanceGroup], Optional[str]]:
max_items = 50
groups = sorted(self.clusters[cluster_id].instance_groups, key=lambda x: x.id)
start_idx = 0 if marker is None else int(marker)
@ -530,8 +563,12 @@ class ElasticMapReduceBackend(BaseBackend):
return groups[start_idx : start_idx + max_items], marker
def list_instances(
self, cluster_id, marker=None, instance_group_id=None, instance_group_types=None
):
self,
cluster_id: str,
marker: Optional[str] = None,
instance_group_id: Optional[str] = None,
instance_group_types: Optional[List[str]] = None,
) -> Tuple[List[FakeInstance], Optional[str]]:
max_items = 50
groups = sorted(self.clusters[cluster_id].instances, key=lambda x: x.id)
start_idx = 0 if marker is None else int(marker)
@ -545,10 +582,16 @@ class ElasticMapReduceBackend(BaseBackend):
g for g in groups if g.instance_group.role in instance_group_types
]
for g in groups:
g.details = self.ec2_backend.get_instance(g.ec2_instance_id)
g.details = self.ec2_backend.get_instance(g.ec2_instance_id) # type: ignore
return groups[start_idx : start_idx + max_items], marker
def list_steps(self, cluster_id, marker=None, step_ids=None, step_states=None):
def list_steps(
self,
cluster_id: str,
marker: Optional[str] = None,
step_ids: Optional[List[str]] = None,
step_states: Optional[List[str]] = None,
) -> Tuple[List[FakeStep], Optional[str]]:
max_items = 50
steps = sorted(
self.clusters[cluster_id].steps,
@ -565,30 +608,30 @@ class ElasticMapReduceBackend(BaseBackend):
)
return steps[start_idx : start_idx + max_items], marker
def modify_cluster(self, cluster_id, step_concurrency_level):
def modify_cluster(
self, cluster_id: str, step_concurrency_level: int
) -> FakeCluster:
cluster = self.clusters[cluster_id]
cluster.step_concurrency_level = step_concurrency_level
return cluster
def modify_instance_groups(self, instance_groups):
result_groups = []
def modify_instance_groups(self, instance_groups: List[Dict[str, Any]]) -> None:
for instance_group in instance_groups:
group = self.instance_groups[instance_group["instance_group_id"]]
group.set_instance_count(int(instance_group["instance_count"]))
return result_groups
def remove_tags(self, cluster_id, tag_keys):
def remove_tags(self, cluster_id: str, tag_keys: List[str]) -> None:
cluster = self.describe_cluster(cluster_id)
cluster.remove_tags(tag_keys)
def _manage_security_groups(
self,
ec2_subnet_id,
emr_managed_master_security_group,
emr_managed_slave_security_group,
service_access_security_group,
**_,
):
ec2_subnet_id: str,
emr_managed_master_security_group: str,
emr_managed_slave_security_group: str,
service_access_security_group: str,
**_: Any,
) -> Tuple[str, str, str]:
default_return_value = (
emr_managed_master_security_group,
emr_managed_slave_security_group,
@ -601,7 +644,7 @@ class ElasticMapReduceBackend(BaseBackend):
from moto.ec2.exceptions import InvalidSubnetIdError
try:
subnet = self.ec2_backend.get_subnet(ec2_subnet_id)
subnet = self.ec2_backend.get_subnet(ec2_subnet_id) # type: ignore
except InvalidSubnetIdError:
warnings.warn(
f"Could not find Subnet with id: {ec2_subnet_id}\n"
@ -620,7 +663,7 @@ class ElasticMapReduceBackend(BaseBackend):
)
return master.id, slave.id, service.id
def run_job_flow(self, **kwargs):
def run_job_flow(self, **kwargs: Any) -> FakeCluster:
(
kwargs["instance_attrs"]["emr_managed_master_security_group"],
kwargs["instance_attrs"]["emr_managed_slave_security_group"],
@ -628,17 +671,19 @@ class ElasticMapReduceBackend(BaseBackend):
) = self._manage_security_groups(**kwargs["instance_attrs"])
return FakeCluster(self, **kwargs)
def set_visible_to_all_users(self, job_flow_ids, visible_to_all_users):
def set_visible_to_all_users(
self, job_flow_ids: List[str], visible_to_all_users: str
) -> None:
for job_flow_id in job_flow_ids:
cluster = self.clusters[job_flow_id]
cluster.set_visibility(visible_to_all_users)
def set_termination_protection(self, job_flow_ids, value):
def set_termination_protection(self, job_flow_ids: List[str], value: bool) -> None:
for job_flow_id in job_flow_ids:
cluster = self.clusters[job_flow_id]
cluster.set_termination_protection(value)
def terminate_job_flows(self, job_flow_ids):
def terminate_job_flows(self, job_flow_ids: List[str]) -> List[FakeCluster]:
clusters_terminated = []
clusters_protected = []
for job_flow_id in job_flow_ids:
@ -654,7 +699,9 @@ class ElasticMapReduceBackend(BaseBackend):
)
return clusters_terminated
def put_auto_scaling_policy(self, instance_group_id, auto_scaling_policy):
def put_auto_scaling_policy(
self, instance_group_id: str, auto_scaling_policy: Optional[Dict[str, Any]]
) -> Optional[FakeInstanceGroup]:
instance_groups = self.get_instance_groups(
instance_group_ids=[instance_group_id]
)
@ -664,28 +711,30 @@ class ElasticMapReduceBackend(BaseBackend):
instance_group.auto_scaling_policy = auto_scaling_policy
return instance_group
def remove_auto_scaling_policy(self, instance_group_id):
def remove_auto_scaling_policy(self, instance_group_id: str) -> None:
self.put_auto_scaling_policy(instance_group_id, auto_scaling_policy=None)
def create_security_configuration(self, name, security_configuration):
def create_security_configuration(
self, name: str, security_configuration: str
) -> FakeSecurityConfiguration:
if name in self.security_configurations:
raise InvalidRequestException(
message=f"SecurityConfiguration with name '{name}' already exists."
)
security_configuration = FakeSecurityConfiguration(
config = FakeSecurityConfiguration(
name=name, security_configuration=security_configuration
)
self.security_configurations[name] = security_configuration
return security_configuration
self.security_configurations[name] = config
return config
def get_security_configuration(self, name):
def get_security_configuration(self, name: str) -> FakeSecurityConfiguration:
if name not in self.security_configurations:
raise InvalidRequestException(
message=f"Security configuration with name '{name}' does not exist."
)
return self.security_configurations[name]
def delete_security_configuration(self, name):
def delete_security_configuration(self, name: str) -> None:
if name not in self.security_configurations:
raise InvalidRequestException(
message=f"Security configuration with name '{name}' does not exist."

View File

@ -2,6 +2,7 @@ import json
import re
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Callable, Dict, List, Pattern
from urllib.parse import urlparse
from moto.core.responses import AWSServiceSpec
@ -9,20 +10,27 @@ from moto.core.responses import BaseResponse
from moto.core.responses import xml_to_json_response
from moto.core.utils import tags_from_query_string
from .exceptions import ValidationException
from .models import emr_backends
from .models import emr_backends, ElasticMapReduceBackend
from .utils import steps_from_query_string, Unflattener, ReleaseLabel
def generate_boto3_response(operation):
def generate_boto3_response(
operation: str,
) -> Callable[
[Callable[["ElasticMapReduceResponse"], str]],
Callable[["ElasticMapReduceResponse"], str],
]:
"""The decorator to convert an XML response to JSON, if the request is
determined to be from boto3. Pass the API action as a parameter.
"""
def _boto3_request(method):
def _boto3_request(
method: Callable[["ElasticMapReduceResponse"], str]
) -> Callable[["ElasticMapReduceResponse"], str]:
@wraps(method)
def f(self, *args, **kwargs):
rendered = method(self, *args, **kwargs)
def f(self: "ElasticMapReduceResponse") -> str:
rendered = method(self)
if "json" in self.headers.get("Content-Type", []):
self.response_headers.update(
{
@ -46,30 +54,30 @@ class ElasticMapReduceResponse(BaseResponse):
# EMR end points are inconsistent in the placement of region name
# in the URL, so parsing it out needs to be handled differently
region_regex = [
emr_region_regex: List[Pattern[str]] = [
re.compile(r"elasticmapreduce\.(.+?)\.amazonaws\.com"),
re.compile(r"(.+?)\.elasticmapreduce\.amazonaws\.com"),
]
aws_service_spec = AWSServiceSpec("data/emr/2009-03-31/service-2.json")
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="emr")
def get_region_from_url(self, request, full_url):
def get_region_from_url(self, request: Any, full_url: str) -> str:
parsed = urlparse(full_url)
for regex in self.region_regex:
for regex in ElasticMapReduceResponse.emr_region_regex:
match = regex.search(parsed.netloc)
if match:
return match.group(1)
return self.default_region
@property
def backend(self):
def backend(self) -> ElasticMapReduceBackend:
return emr_backends[self.current_account][self.region]
@generate_boto3_response("AddInstanceGroups")
def add_instance_groups(self):
def add_instance_groups(self) -> str:
jobflow_id = self._get_param("JobFlowId")
instance_groups = self._get_list_prefix("InstanceGroups.member")
for item in instance_groups:
@ -78,12 +86,12 @@ class ElasticMapReduceResponse(BaseResponse):
self._parse_ebs_configuration(item)
# Adding support for auto_scaling_policy
Unflattener.unflatten_complex_params(item, "auto_scaling_policy")
instance_groups = self.backend.add_instance_groups(jobflow_id, instance_groups)
fake_groups = self.backend.add_instance_groups(jobflow_id, instance_groups)
template = self.response_template(ADD_INSTANCE_GROUPS_TEMPLATE)
return template.render(instance_groups=instance_groups)
return template.render(instance_groups=fake_groups)
@generate_boto3_response("AddJobFlowSteps")
def add_job_flow_steps(self):
def add_job_flow_steps(self) -> str:
job_flow_id = self._get_param("JobFlowId")
steps = self.backend.add_job_flow_steps(
job_flow_id, steps_from_query_string(self._get_list_prefix("Steps.member"))
@ -92,18 +100,15 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(steps=steps)
@generate_boto3_response("AddTags")
def add_tags(self):
def add_tags(self) -> str:
cluster_id = self._get_param("ResourceId")
tags = tags_from_query_string(self.querystring, prefix="Tags")
self.backend.add_tags(cluster_id, tags)
template = self.response_template(ADD_TAGS_TEMPLATE)
return template.render()
def cancel_steps(self):
raise NotImplementedError
@generate_boto3_response("CreateSecurityConfiguration")
def create_security_configuration(self):
def create_security_configuration(self) -> str:
name = self._get_param("Name")
security_configuration = self._get_param("SecurityConfiguration")
resp = self.backend.create_security_configuration(
@ -113,28 +118,28 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(name=name, creation_date_time=resp.creation_date_time)
@generate_boto3_response("DescribeSecurityConfiguration")
def describe_security_configuration(self):
def describe_security_configuration(self) -> str:
name = self._get_param("Name")
security_configuration = self.backend.get_security_configuration(name=name)
template = self.response_template(DESCRIBE_SECURITY_CONFIGURATION_TEMPLATE)
return template.render(security_configuration=security_configuration)
@generate_boto3_response("DeleteSecurityConfiguration")
def delete_security_configuration(self):
def delete_security_configuration(self) -> str:
name = self._get_param("Name")
self.backend.delete_security_configuration(name=name)
template = self.response_template(DELETE_SECURITY_CONFIGURATION_TEMPLATE)
return template.render()
@generate_boto3_response("DescribeCluster")
def describe_cluster(self):
def describe_cluster(self) -> str:
cluster_id = self._get_param("ClusterId")
cluster = self.backend.describe_cluster(cluster_id)
template = self.response_template(DESCRIBE_CLUSTER_TEMPLATE)
return template.render(cluster=cluster)
@generate_boto3_response("DescribeJobFlows")
def describe_job_flows(self):
def describe_job_flows(self) -> str:
created_after = self._get_param("CreatedAfter")
created_before = self._get_param("CreatedBefore")
job_flow_ids = self._get_multi_param("JobFlowIds.member")
@ -146,7 +151,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(clusters=clusters)
@generate_boto3_response("DescribeStep")
def describe_step(self):
def describe_step(self) -> str:
cluster_id = self._get_param("ClusterId")
step_id = self._get_param("StepId")
step = self.backend.describe_step(cluster_id, step_id)
@ -154,7 +159,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(step=step)
@generate_boto3_response("ListBootstrapActions")
def list_bootstrap_actions(self):
def list_bootstrap_actions(self) -> str:
cluster_id = self._get_param("ClusterId")
marker = self._get_param("Marker")
bootstrap_actions, marker = self.backend.list_bootstrap_actions(
@ -164,7 +169,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(bootstrap_actions=bootstrap_actions, marker=marker)
@generate_boto3_response("ListClusters")
def list_clusters(self):
def list_clusters(self) -> str:
cluster_states = self._get_multi_param("ClusterStates.member")
created_after = self._get_param("CreatedAfter")
created_before = self._get_param("CreatedBefore")
@ -176,7 +181,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(clusters=clusters, marker=marker)
@generate_boto3_response("ListInstanceGroups")
def list_instance_groups(self):
def list_instance_groups(self) -> str:
cluster_id = self._get_param("ClusterId")
marker = self._get_param("Marker")
instance_groups, marker = self.backend.list_instance_groups(
@ -186,7 +191,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(instance_groups=instance_groups, marker=marker)
@generate_boto3_response("ListInstances")
def list_instances(self):
def list_instances(self) -> str:
cluster_id = self._get_param("ClusterId")
marker = self._get_param("Marker")
instance_group_id = self._get_param("InstanceGroupId")
@ -201,7 +206,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(instances=instances, marker=marker)
@generate_boto3_response("ListSteps")
def list_steps(self):
def list_steps(self) -> str:
cluster_id = self._get_param("ClusterId")
marker = self._get_param("Marker")
step_ids = self._get_multi_param("StepIds.member")
@ -213,7 +218,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(steps=steps, marker=marker)
@generate_boto3_response("ModifyCluster")
def modify_cluster(self):
def modify_cluster(self) -> str:
cluster_id = self._get_param("ClusterId")
step_concurrency_level = self._get_param("StepConcurrencyLevel")
cluster = self.backend.modify_cluster(cluster_id, step_concurrency_level)
@ -221,16 +226,16 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(cluster=cluster)
@generate_boto3_response("ModifyInstanceGroups")
def modify_instance_groups(self):
def modify_instance_groups(self) -> str:
instance_groups = self._get_list_prefix("InstanceGroups.member")
for item in instance_groups:
item["instance_count"] = int(item["instance_count"])
instance_groups = self.backend.modify_instance_groups(instance_groups)
self.backend.modify_instance_groups(instance_groups)
template = self.response_template(MODIFY_INSTANCE_GROUPS_TEMPLATE)
return template.render(instance_groups=instance_groups)
return template.render()
@generate_boto3_response("RemoveTags")
def remove_tags(self):
def remove_tags(self) -> str:
cluster_id = self._get_param("ResourceId")
tag_keys = self._get_multi_param("TagKeys.member")
self.backend.remove_tags(cluster_id, tag_keys)
@ -238,7 +243,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render()
@generate_boto3_response("RunJobFlow")
def run_job_flow(self):
def run_job_flow(self) -> str:
instance_attrs = dict(
master_instance_type=self._get_param("Instances.MasterInstanceType"),
slave_instance_type=self._get_param("Instances.SlaveInstanceType"),
@ -349,7 +354,7 @@ class ElasticMapReduceResponse(BaseResponse):
if security_configuration:
kwargs["security_configuration"] = security_configuration
kerberos_attributes = {}
kerberos_attributes: Dict[str, Any] = {}
kwargs["kerberos_attributes"] = kerberos_attributes
realm = self._get_param("KerberosAttributes.Realm")
@ -413,13 +418,13 @@ class ElasticMapReduceResponse(BaseResponse):
template = self.response_template(RUN_JOB_FLOW_TEMPLATE)
return template.render(cluster=cluster)
def _has_key_prefix(self, key_prefix, value):
def _has_key_prefix(self, key_prefix: str, value: Dict[str, Any]) -> bool:
for key in value: # iter on both keys and values
if key.startswith(key_prefix):
return True
return False
def _parse_ebs_configuration(self, instance_group):
def _parse_ebs_configuration(self, instance_group: Dict[str, Any]) -> None:
key_ebs_config = "ebs_configuration"
ebs_configuration = dict()
# Filter only EBS config keys
@ -456,7 +461,7 @@ class ElasticMapReduceResponse(BaseResponse):
vol_iops = vlespc_keyfmt.format(iops)
vol_type = vlespc_keyfmt.format(volume_type)
ebs_block = dict()
ebs_block: Dict[str, Any] = dict()
ebs_block[volume_specification] = dict()
if vol_size in ebs_configuration:
instance_group.pop(vol_size)
@ -491,7 +496,7 @@ class ElasticMapReduceResponse(BaseResponse):
instance_group[key_ebs_config] = ebs_configuration
@generate_boto3_response("SetTerminationProtection")
def set_termination_protection(self):
def set_termination_protection(self) -> str:
termination_protection = self._get_bool_param("TerminationProtected")
job_ids = self._get_multi_param("JobFlowIds.member")
self.backend.set_termination_protection(job_ids, termination_protection)
@ -499,7 +504,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render()
@generate_boto3_response("SetVisibleToAllUsers")
def set_visible_to_all_users(self):
def set_visible_to_all_users(self) -> str:
visible_to_all_users = self._get_param("VisibleToAllUsers")
job_ids = self._get_multi_param("JobFlowIds.member")
self.backend.set_visible_to_all_users(job_ids, visible_to_all_users)
@ -507,14 +512,14 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render()
@generate_boto3_response("TerminateJobFlows")
def terminate_job_flows(self):
def terminate_job_flows(self) -> str:
job_ids = self._get_multi_param("JobFlowIds.member.")
self.backend.terminate_job_flows(job_ids)
template = self.response_template(TERMINATE_JOB_FLOWS_TEMPLATE)
return template.render()
@generate_boto3_response("PutAutoScalingPolicy")
def put_auto_scaling_policy(self):
def put_auto_scaling_policy(self) -> str:
cluster_id = self._get_param("ClusterId")
cluster = self.backend.describe_cluster(cluster_id)
instance_group_id = self._get_param("InstanceGroupId")
@ -528,12 +533,11 @@ class ElasticMapReduceResponse(BaseResponse):
)
@generate_boto3_response("RemoveAutoScalingPolicy")
def remove_auto_scaling_policy(self):
cluster_id = self._get_param("ClusterId")
def remove_auto_scaling_policy(self) -> str:
instance_group_id = self._get_param("InstanceGroupId")
instance_group = self.backend.remove_auto_scaling_policy(instance_group_id)
self.backend.remove_auto_scaling_policy(instance_group_id)
template = self.response_template(REMOVE_AUTO_SCALING_POLICY)
return template.render(cluster_id=cluster_id, instance_group=instance_group)
return template.render()
ADD_INSTANCE_GROUPS_TEMPLATE = """<AddInstanceGroupsResponse xmlns="http://elasticmapreduce.amazonaws.com/doc/2009-03-31">

View File

@ -2,6 +2,7 @@ import copy
import datetime
import re
import string
from typing import Any, List, Dict, Tuple, Iterator
from moto.core.utils import (
camelcase_to_underscores,
iso_8601_datetime_with_milliseconds,
@ -9,24 +10,26 @@ from moto.core.utils import (
from moto.moto_api._internal import mock_random as random
def random_id(size=13):
def random_id(size: int = 13) -> str:
chars = list(range(10)) + list(string.ascii_uppercase)
return "".join(str(random.choice(chars)) for x in range(size))
def random_cluster_id():
def random_cluster_id() -> str:
return f"j-{random_id()}"
def random_step_id():
def random_step_id() -> str:
return f"s-{random_id()}"
def random_instance_group_id():
def random_instance_group_id() -> str:
return f"i-{random_id()}"
def steps_from_query_string(querystring_dict):
def steps_from_query_string(
querystring_dict: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
steps = []
for step in querystring_dict:
step["jar"] = step.pop("hadoop_jar_step._jar")
@ -45,7 +48,7 @@ def steps_from_query_string(querystring_dict):
class Unflattener:
@staticmethod
def unflatten_complex_params(input_dict, param_name):
def unflatten_complex_params(input_dict: Dict[str, Any], param_name: str) -> None: # type: ignore[misc]
"""Function to unflatten (portions of) dicts with complex keys. The moto request parser flattens the incoming
request bodies, which is generally helpful, but for nested dicts/lists can result in a hard-to-manage
parameter exposion. This function allows one to selectively unflatten a set of dict keys, replacing them
@ -68,7 +71,7 @@ class Unflattener:
Unflattener._set_deep(k, input_dict, items_to_process[k])
@staticmethod
def _set_deep(complex_key, container, value):
def _set_deep(complex_key: Any, container: Any, value: Any) -> None: # type: ignore[misc]
keys = complex_key.split(".")
keys.reverse()
@ -91,7 +94,7 @@ class Unflattener:
container = Unflattener._get_child(container, key)
@staticmethod
def _add_to_container(container, key, value):
def _add_to_container(container: Any, key: Any, value: Any) -> Any: # type: ignore[misc]
if type(container) is dict:
container[key] = value
elif type(container) is list:
@ -102,7 +105,7 @@ class Unflattener:
return value
@staticmethod
def _get_child(container, key):
def _get_child(container: Any, key: Any) -> Any: # type: ignore[misc]
if type(container) is dict:
return container[key]
elif type(container) is list:
@ -110,7 +113,7 @@ class Unflattener:
return container[i - 1]
@staticmethod
def _key_in_container(container, key):
def _key_in_container(container: Any, key: Any) -> bool: # type: ignore
if type(container) is dict:
return key in container
elif type(container) is list:
@ -122,7 +125,7 @@ class CamelToUnderscoresWalker:
"""A class to convert the keys in dict/list hierarchical data structures from CamelCase to snake_case (underscores)"""
@staticmethod
def parse(x):
def parse(x: Any) -> Any: # type: ignore[misc]
if isinstance(x, dict):
return CamelToUnderscoresWalker.parse_dict(x)
elif isinstance(x, list):
@ -131,29 +134,29 @@ class CamelToUnderscoresWalker:
return CamelToUnderscoresWalker.parse_scalar(x)
@staticmethod
def parse_dict(x):
def parse_dict(x: Dict[str, Any]) -> Dict[str, Any]: # type: ignore[misc]
temp = {}
for key in x.keys():
temp[camelcase_to_underscores(key)] = CamelToUnderscoresWalker.parse(x[key])
return temp
@staticmethod
def parse_list(x):
def parse_list(x: Any) -> Any: # type: ignore[misc]
temp = []
for i in x:
temp.append(CamelToUnderscoresWalker.parse(i))
return temp
@staticmethod
def parse_scalar(x):
def parse_scalar(x: Any) -> Any: # type: ignore[misc]
return x
class ReleaseLabel(object):
class ReleaseLabel:
version_re = re.compile(r"^emr-(\d+)\.(\d+)\.(\d+)$")
def __init__(self, release_label):
def __init__(self, release_label: str):
major, minor, patch = self.parse(release_label)
self.major = major
@ -161,7 +164,7 @@ class ReleaseLabel(object):
self.patch = patch
@classmethod
def parse(cls, release_label):
def parse(cls, release_label: str) -> Tuple[int, int, int]:
if not release_label:
raise ValueError(f"Invalid empty ReleaseLabel: {release_label}")
@ -171,23 +174,19 @@ class ReleaseLabel(object):
major, minor, patch = match.groups()
major = int(major)
minor = int(minor)
patch = int(patch)
return int(major), int(minor), int(patch)
return major, minor, patch
def __str__(self):
def __str__(self) -> str:
version = f"emr-{self.major}.{self.minor}.{self.patch}"
return version
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({str(self)})"
def __iter__(self):
return iter((self.major, self.minor, self.patch))
def __iter__(self) -> Iterator[Tuple[int, int, int]]:
return iter((self.major, self.minor, self.patch)) # type: ignore
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return (
@ -196,46 +195,46 @@ class ReleaseLabel(object):
and self.patch == other.patch
)
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return tuple(self) != tuple(other)
def __lt__(self, other):
def __lt__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return tuple(self) < tuple(other)
def __le__(self, other):
def __le__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return tuple(self) <= tuple(other)
def __gt__(self, other):
def __gt__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return tuple(self) > tuple(other)
def __ge__(self, other):
def __ge__(self, other: Any) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return tuple(self) >= tuple(other)
class EmrManagedSecurityGroup(object):
class EmrManagedSecurityGroup:
class Kind:
MASTER = "Master"
SLAVE = "Slave"
SERVICE = "Service"
kind = None
kind = ""
group_name = ""
short_name = ""
desc_fmt = "{short_name} for Elastic MapReduce created on {created}"
@classmethod
def description(cls):
def description(cls) -> str:
created = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
return cls.desc_fmt.format(short_name=cls.short_name, created=created)
@ -258,7 +257,7 @@ class EmrManagedServiceAccessSecurityGroup(EmrManagedSecurityGroup):
short_name = "Service access"
class EmrSecurityGroupManager(object):
class EmrSecurityGroupManager:
MANAGED_RULES_EGRESS = [
{
@ -383,13 +382,16 @@ class EmrSecurityGroupManager(object):
},
]
def __init__(self, ec2_backend, vpc_id):
def __init__(self, ec2_backend: Any, vpc_id: str):
self.ec2 = ec2_backend
self.vpc_id = vpc_id
def manage_security_groups(
self, master_security_group, slave_security_group, service_access_security_group
):
self,
master_security_group: str,
slave_security_group: str,
service_access_security_group: str,
) -> Tuple[Any, Any, Any]:
group_metadata = [
(
master_security_group,
@ -407,7 +409,7 @@ class EmrSecurityGroupManager(object):
EmrManagedServiceAccessSecurityGroup,
),
]
managed_groups = {}
managed_groups: Dict[str, Any] = {}
for name, kind, defaults in group_metadata:
managed_groups[kind] = self._get_or_create_sg(name, defaults)
self._add_rules_to(managed_groups)
@ -417,7 +419,7 @@ class EmrSecurityGroupManager(object):
managed_groups[EmrManagedSecurityGroup.Kind.SERVICE],
)
def _get_or_create_sg(self, sg_id, defaults):
def _get_or_create_sg(self, sg_id: str, defaults: Any) -> Any:
find_sg = self.ec2.get_security_group_by_name_or_id
create_sg = self.ec2.create_security_group
group_id_or_name = sg_id or defaults.group_name
@ -430,7 +432,7 @@ class EmrSecurityGroupManager(object):
group = create_sg(defaults.group_name, defaults.description(), self.vpc_id)
return group
def _add_rules_to(self, managed_groups):
def _add_rules_to(self, managed_groups: Dict[str, Any]) -> None:
rules_metadata = [
(self.MANAGED_RULES_EGRESS, self.ec2.authorize_security_group_egress),
(self.MANAGED_RULES_INGRESS, self.ec2.authorize_security_group_ingress),
@ -447,7 +449,7 @@ class EmrSecurityGroupManager(object):
pass
@staticmethod
def _render_rules(rules, managed_groups):
def _render_rules(rules: Any, managed_groups: Dict[str, Any]) -> List[Dict[str, Any]]: # type: ignore[misc]
rendered_rules = copy.deepcopy(rules)
for rule in rendered_rules:
rule["group_name_or_id"] = managed_groups[rule["group_name_or_id"]].id

View File

@ -229,7 +229,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2,moto/ec2instanceconnect,moto/ecr,moto/ecs,moto/efs,moto/eks,moto/elasticache,moto/elasticbeanstalk,moto/elastictranscoder,moto/elb,moto/elbv2,moto/es,moto/moto_api,moto/neptune
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2,moto/ec2instanceconnect,moto/ecr,moto/ecs,moto/efs,moto/eks,moto/elasticache,moto/elasticbeanstalk,moto/elastictranscoder,moto/elb,moto/elbv2,moto/emr,moto/es,moto/moto_api,moto/neptune
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract