Techdebt: MyPy EMR (#6005)

This commit is contained in:
Bert Blommers 2023-03-03 18:42:11 -01:00 committed by GitHub
parent 0e24b281eb
commit a1a43e3f74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 293 additions and 238 deletions

View File

@ -218,7 +218,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
access_key_regex = re.compile( access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]" r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
) )
aws_service_spec = None aws_service_spec: Optional["AWSServiceSpec"] = None
def __init__(self, service_name: Optional[str] = None): def __init__(self, service_name: Optional[str] = None):
super().__init__() super().__init__()

View File

@ -2,15 +2,15 @@ from moto.core.exceptions import JsonRESTError
class InvalidRequestException(JsonRESTError): class InvalidRequestException(JsonRESTError):
def __init__(self, message, **kwargs): def __init__(self, message: str):
super().__init__("InvalidRequestException", message, **kwargs) super().__init__("InvalidRequestException", message)
class ValidationException(JsonRESTError): class ValidationException(JsonRESTError):
def __init__(self, message, **kwargs): def __init__(self, message: str):
super().__init__("ValidationException", message, **kwargs) super().__init__("ValidationException", message)
class ResourceNotFoundException(JsonRESTError): class ResourceNotFoundException(JsonRESTError):
def __init__(self, message, **kwargs): def __init__(self, message: str):
super().__init__("ResourceNotFoundException", message, **kwargs) super().__init__("ResourceNotFoundException", message)

View File

@ -1,5 +1,5 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple
import warnings import warnings
from dateutil.parser import parse as dtparse from dateutil.parser import parse as dtparse
@ -21,7 +21,9 @@ EXAMPLE_AMI_ID = "ami-12c6146b"
class FakeApplication(BaseModel): class FakeApplication(BaseModel):
def __init__(self, name, version, args=None, additional_info=None): def __init__(
self, name: str, version: str, args: List[str], additional_info: Dict[str, str]
):
self.additional_info = additional_info or {} self.additional_info = additional_info or {}
self.args = args or [] self.args = args or []
self.name = name self.name = name
@ -29,7 +31,7 @@ class FakeApplication(BaseModel):
class FakeBootstrapAction(BaseModel): class FakeBootstrapAction(BaseModel):
def __init__(self, args, name, script_path): def __init__(self, args: List[str], name: str, script_path: str):
self.args = args or [] self.args = args or []
self.name = name self.name = name
self.script_path = script_path self.script_path = script_path
@ -37,7 +39,11 @@ class FakeBootstrapAction(BaseModel):
class FakeInstance(BaseModel): class FakeInstance(BaseModel):
def __init__( def __init__(
self, ec2_instance_id, instance_group, instance_fleet_id=None, instance_id=None self,
ec2_instance_id: str,
instance_group: "FakeInstanceGroup",
instance_fleet_id: Optional[str] = None,
instance_id: Optional[str] = None,
): ):
self.id = instance_id or random_instance_group_id() self.id = instance_id or random_instance_group_id()
self.ec2_instance_id = ec2_instance_id self.ec2_instance_id = ec2_instance_id
@ -48,16 +54,16 @@ class FakeInstance(BaseModel):
class FakeInstanceGroup(BaseModel): class FakeInstanceGroup(BaseModel):
def __init__( def __init__(
self, self,
cluster_id, cluster_id: str,
instance_count, instance_count: int,
instance_role, instance_role: str,
instance_type, instance_type: str,
market="ON_DEMAND", market: str = "ON_DEMAND",
name=None, name: Optional[str] = None,
instance_group_id=None, instance_group_id: Optional[str] = None,
bid_price=None, bid_price: Optional[str] = None,
ebs_configuration=None, ebs_configuration: Optional[Dict[str, Any]] = None,
auto_scaling_policy=None, auto_scaling_policy: Optional[Dict[str, Any]] = None,
): ):
self.id = instance_group_id or random_instance_group_id() self.id = instance_group_id or random_instance_group_id()
self.cluster_id = cluster_id self.cluster_id = cluster_id
@ -83,15 +89,15 @@ class FakeInstanceGroup(BaseModel):
self.end_datetime = None self.end_datetime = None
self.state = "RUNNING" self.state = "RUNNING"
def set_instance_count(self, instance_count): def set_instance_count(self, instance_count: int) -> None:
self.num_instances = instance_count self.num_instances = instance_count
@property @property
def auto_scaling_policy(self): def auto_scaling_policy(self) -> Any: # type: ignore[misc]
return self._auto_scaling_policy return self._auto_scaling_policy
@auto_scaling_policy.setter @auto_scaling_policy.setter
def auto_scaling_policy(self, value): def auto_scaling_policy(self, value: Any) -> None:
if value is None: if value is None:
self._auto_scaling_policy = value self._auto_scaling_policy = value
return return
@ -118,12 +124,12 @@ class FakeInstanceGroup(BaseModel):
class FakeStep(BaseModel): class FakeStep(BaseModel):
def __init__( def __init__(
self, self,
state, state: str,
name="", name: str = "",
jar="", jar: str = "",
args=None, args: Optional[List[str]] = None,
properties=None, properties: Optional[Dict[str, str]] = None,
action_on_failure="TERMINATE_CLUSTER", action_on_failure: str = "TERMINATE_CLUSTER",
): ):
self.id = random_step_id() self.id = random_step_id()
@ -136,63 +142,63 @@ class FakeStep(BaseModel):
self.creation_datetime = datetime.now(timezone.utc) self.creation_datetime = datetime.now(timezone.utc)
self.end_datetime = None self.end_datetime = None
self.ready_datetime = None self.ready_datetime = None
self.start_datetime = None self.start_datetime: Optional[datetime] = None
self.state = state self.state = state
def start(self): def start(self) -> None:
self.start_datetime = datetime.now(timezone.utc) self.start_datetime = datetime.now(timezone.utc)
class FakeCluster(BaseModel): class FakeCluster(BaseModel):
def __init__( def __init__(
self, self,
emr_backend, emr_backend: "ElasticMapReduceBackend",
name, name: str,
log_uri, log_uri: str,
job_flow_role, job_flow_role: str,
service_role, service_role: str,
steps, steps: List[Dict[str, Any]],
instance_attrs, instance_attrs: Dict[str, Any],
bootstrap_actions=None, bootstrap_actions: Optional[List[Dict[str, Any]]] = None,
configurations=None, configurations: Optional[List[Dict[str, Any]]] = None,
cluster_id=None, cluster_id: Optional[str] = None,
visible_to_all_users="false", visible_to_all_users: str = "false",
release_label=None, release_label: Optional[str] = None,
requested_ami_version=None, requested_ami_version: Optional[str] = None,
running_ami_version=None, running_ami_version: Optional[str] = None,
custom_ami_id=None, custom_ami_id: Optional[str] = None,
step_concurrency_level=1, step_concurrency_level: int = 1,
security_configuration=None, security_configuration: Optional[str] = None,
kerberos_attributes=None, kerberos_attributes: Optional[Dict[str, str]] = None,
auto_scaling_role=None, auto_scaling_role: Optional[str] = None,
): ):
self.id = cluster_id or random_cluster_id() self.id = cluster_id or random_cluster_id()
emr_backend.clusters[self.id] = self emr_backend.clusters[self.id] = self
self.emr_backend = emr_backend self.emr_backend = emr_backend
self.applications = [] self.applications: List[FakeApplication] = []
self.bootstrap_actions = [] self.bootstrap_actions: List[FakeBootstrapAction] = []
for bootstrap_action in bootstrap_actions or []: for bootstrap_action in bootstrap_actions or []:
self.add_bootstrap_action(bootstrap_action) self.add_bootstrap_action(bootstrap_action)
self.configurations = configurations or [] self.configurations = configurations or []
self.tags = {} self.tags: Dict[str, str] = {}
self.log_uri = log_uri self.log_uri = log_uri
self.name = name self.name = name
self.normalized_instance_hours = 0 self.normalized_instance_hours = 0
self.steps = [] self.steps: List[FakeStep] = []
self.add_steps(steps) self.add_steps(steps)
self.set_visibility(visible_to_all_users) self.set_visibility(visible_to_all_users)
self.instance_group_ids = [] self.instance_group_ids: List[str] = []
self.instances = [] self.instances: List[FakeInstance] = []
self.master_instance_group_id = None self.master_instance_group_id: Optional[str] = None
self.core_instance_group_id = None self.core_instance_group_id: Optional[str] = None
if ( if (
"master_instance_type" in instance_attrs "master_instance_type" in instance_attrs
and instance_attrs["master_instance_type"] and instance_attrs["master_instance_type"]
@ -259,10 +265,10 @@ class FakeCluster(BaseModel):
self.step_concurrency_level = step_concurrency_level self.step_concurrency_level = step_concurrency_level
self.creation_datetime = datetime.now(timezone.utc) self.creation_datetime = datetime.now(timezone.utc)
self.start_datetime = None self.start_datetime: Optional[datetime] = None
self.ready_datetime = None self.ready_datetime: Optional[datetime] = None
self.end_datetime = None self.end_datetime: Optional[datetime] = None
self.state = None self.state: Optional[str] = None
self.start_cluster() self.start_cluster()
self.run_bootstrap_actions() self.run_bootstrap_actions()
@ -275,30 +281,30 @@ class FakeCluster(BaseModel):
self.auto_scaling_role = auto_scaling_role self.auto_scaling_role = auto_scaling_role
@property @property
def arn(self): def arn(self) -> str:
return f"arn:aws:elasticmapreduce:{self.emr_backend.region_name}:{self.emr_backend.account_id}:cluster/{self.id}" return f"arn:aws:elasticmapreduce:{self.emr_backend.region_name}:{self.emr_backend.account_id}:cluster/{self.id}"
@property @property
def instance_groups(self): def instance_groups(self) -> List[FakeInstanceGroup]:
return self.emr_backend.get_instance_groups(self.instance_group_ids) return self.emr_backend.get_instance_groups(self.instance_group_ids)
@property @property
def master_instance_type(self): def master_instance_type(self) -> str:
return self.emr_backend.instance_groups[self.master_instance_group_id].type return self.emr_backend.instance_groups[self.master_instance_group_id].type # type: ignore
@property @property
def slave_instance_type(self): def slave_instance_type(self) -> str:
return self.emr_backend.instance_groups[self.core_instance_group_id].type return self.emr_backend.instance_groups[self.core_instance_group_id].type # type: ignore
@property @property
def instance_count(self): def instance_count(self) -> int:
return sum(group.num_instances for group in self.instance_groups) return sum(group.num_instances for group in self.instance_groups)
def start_cluster(self): def start_cluster(self) -> None:
self.state = "STARTING" self.state = "STARTING"
self.start_datetime = datetime.now(timezone.utc) self.start_datetime = datetime.now(timezone.utc)
def run_bootstrap_actions(self): def run_bootstrap_actions(self) -> None:
self.state = "BOOTSTRAPPING" self.state = "BOOTSTRAPPING"
self.ready_datetime = datetime.now(timezone.utc) self.ready_datetime = datetime.now(timezone.utc)
self.state = "WAITING" self.state = "WAITING"
@ -306,28 +312,28 @@ class FakeCluster(BaseModel):
if not self.keep_job_flow_alive_when_no_steps: if not self.keep_job_flow_alive_when_no_steps:
self.terminate() self.terminate()
def terminate(self): def terminate(self) -> None:
self.state = "TERMINATING" self.state = "TERMINATING"
self.end_datetime = datetime.now(timezone.utc) self.end_datetime = datetime.now(timezone.utc)
self.state = "TERMINATED" self.state = "TERMINATED"
def add_applications(self, applications): def add_applications(self, applications: List[Dict[str, Any]]) -> None:
self.applications.extend( self.applications.extend(
[ [
FakeApplication( FakeApplication(
name=app.get("name", ""), name=app.get("name", ""),
version=app.get("version", ""), version=app.get("version", ""),
args=app.get("args", []), args=app.get("args", []),
additional_info=app.get("additiona_info", {}), additional_info=app.get("additional_info", {}),
) )
for app in applications for app in applications
] ]
) )
def add_bootstrap_action(self, bootstrap_action): def add_bootstrap_action(self, bootstrap_action: Dict[str, Any]) -> None:
self.bootstrap_actions.append(FakeBootstrapAction(**bootstrap_action)) self.bootstrap_actions.append(FakeBootstrapAction(**bootstrap_action))
def add_instance_group(self, instance_group): def add_instance_group(self, instance_group: FakeInstanceGroup) -> None:
if instance_group.role == "MASTER": if instance_group.role == "MASTER":
if self.master_instance_group_id: if self.master_instance_group_id:
raise Exception("Cannot add another master instance group") raise Exception("Cannot add another master instance group")
@ -347,10 +353,10 @@ class FakeCluster(BaseModel):
self.core_instance_group_id = instance_group.id self.core_instance_group_id = instance_group.id
self.instance_group_ids.append(instance_group.id) self.instance_group_ids.append(instance_group.id)
def add_instance(self, instance): def add_instance(self, instance: FakeInstance) -> None:
self.instances.append(instance) self.instances.append(instance)
def add_steps(self, steps): def add_steps(self, steps: List[Dict[str, Any]]) -> List[FakeStep]:
added_steps = [] added_steps = []
for step in steps: for step in steps:
if self.steps: if self.steps:
@ -363,43 +369,45 @@ class FakeCluster(BaseModel):
self.state = "RUNNING" self.state = "RUNNING"
return added_steps return added_steps
def add_tags(self, tags): def add_tags(self, tags: Dict[str, str]) -> None:
self.tags.update(tags) self.tags.update(tags)
def remove_tags(self, tag_keys): def remove_tags(self, tag_keys: List[str]) -> None:
for key in tag_keys: for key in tag_keys:
self.tags.pop(key, None) self.tags.pop(key, None)
def set_termination_protection(self, value): def set_termination_protection(self, value: bool) -> None:
self.termination_protected = value self.termination_protected = value
def set_visibility(self, visibility): def set_visibility(self, visibility: str) -> None:
self.visible_to_all_users = visibility self.visible_to_all_users = visibility
class FakeSecurityConfiguration(BaseModel): class FakeSecurityConfiguration(BaseModel):
def __init__(self, name, security_configuration): def __init__(self, name: str, security_configuration: str):
self.name = name self.name = name
self.security_configuration = security_configuration self.security_configuration = security_configuration
self.creation_date_time = datetime.now(timezone.utc) self.creation_date_time = datetime.now(timezone.utc)
class ElasticMapReduceBackend(BaseBackend): class ElasticMapReduceBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.clusters = {} self.clusters: Dict[str, FakeCluster] = {}
self.instance_groups = {} self.instance_groups: Dict[str, FakeInstanceGroup] = {}
self.security_configurations = {} self.security_configurations: Dict[str, FakeSecurityConfiguration] = {}
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "elasticmapreduce" service_region, zones, "elasticmapreduce"
) )
@property @property
def ec2_backend(self): def ec2_backend(self) -> Any: # type: ignore[misc]
""" """
:return: EC2 Backend :return: EC2 Backend
:rtype: moto.ec2.models.EC2Backend :rtype: moto.ec2.models.EC2Backend
@ -408,11 +416,15 @@ class ElasticMapReduceBackend(BaseBackend):
return ec2_backends[self.account_id][self.region_name] return ec2_backends[self.account_id][self.region_name]
def add_applications(self, cluster_id, applications): def add_applications(
self, cluster_id: str, applications: List[Dict[str, Any]]
) -> None:
cluster = self.describe_cluster(cluster_id) cluster = self.describe_cluster(cluster_id)
cluster.add_applications(applications) cluster.add_applications(applications)
def add_instance_groups(self, cluster_id, instance_groups): def add_instance_groups(
self, cluster_id: str, instance_groups: List[Dict[str, Any]]
) -> List[FakeInstanceGroup]:
cluster = self.clusters[cluster_id] cluster = self.clusters[cluster_id]
result_groups = [] result_groups = []
for instance_group in instance_groups: for instance_group in instance_groups:
@ -422,7 +434,12 @@ class ElasticMapReduceBackend(BaseBackend):
result_groups.append(group) result_groups.append(group)
return result_groups return result_groups
def add_instances(self, cluster_id, instances, instance_group): def add_instances(
self,
cluster_id: str,
instances: Dict[str, Any],
instance_group: FakeInstanceGroup,
) -> None:
cluster = self.clusters[cluster_id] cluster = self.clusters[cluster_id]
instances["is_instance_type_default"] = not instances.get("instance_type") instances["is_instance_type_default"] = not instances.get("instance_type")
response = self.ec2_backend.add_instances( response = self.ec2_backend.add_instances(
@ -434,23 +451,24 @@ class ElasticMapReduceBackend(BaseBackend):
) )
cluster.add_instance(instance) cluster.add_instance(instance)
def add_job_flow_steps(self, job_flow_id, steps): def add_job_flow_steps(
self, job_flow_id: str, steps: List[Dict[str, Any]]
) -> List[FakeStep]:
cluster = self.clusters[job_flow_id] cluster = self.clusters[job_flow_id]
steps = cluster.add_steps(steps) return cluster.add_steps(steps)
return steps
def add_tags(self, cluster_id, tags): def add_tags(self, cluster_id: str, tags: Dict[str, str]) -> None:
cluster = self.describe_cluster(cluster_id) cluster = self.describe_cluster(cluster_id)
cluster.add_tags(tags) cluster.add_tags(tags)
def describe_job_flows( def describe_job_flows(
self, self,
job_flow_ids=None, job_flow_ids: Optional[List[str]] = None,
job_flow_states=None, job_flow_states: Optional[List[str]] = None,
created_after=None, created_after: Optional[str] = None,
created_before=None, created_before: Optional[str] = None,
): ) -> List[FakeCluster]:
clusters = self.clusters.values() clusters = list(self.clusters.values())
within_two_month = datetime.now(timezone.utc) - timedelta(days=60) within_two_month = datetime.now(timezone.utc) - timedelta(days=60)
clusters = [c for c in clusters if c.creation_datetime >= within_two_month] clusters = [c for c in clusters if c.creation_datetime >= within_two_month]
@ -460,34 +478,41 @@ class ElasticMapReduceBackend(BaseBackend):
if job_flow_states: if job_flow_states:
clusters = [c for c in clusters if c.state in job_flow_states] clusters = [c for c in clusters if c.state in job_flow_states]
if created_after: if created_after:
created_after = dtparse(created_after) clusters = [
clusters = [c for c in clusters if c.creation_datetime > created_after] c for c in clusters if c.creation_datetime > dtparse(created_after)
]
if created_before: if created_before:
created_before = dtparse(created_before) clusters = [
clusters = [c for c in clusters if c.creation_datetime < created_before] c for c in clusters if c.creation_datetime < dtparse(created_before)
]
# Amazon EMR can return a maximum of 512 job flow descriptions # Amazon EMR can return a maximum of 512 job flow descriptions
return sorted(clusters, key=lambda x: x.id)[:512] return sorted(clusters, key=lambda x: x.id)[:512]
def describe_step(self, cluster_id, step_id): def describe_step(self, cluster_id: str, step_id: str) -> Optional[FakeStep]:
cluster = self.clusters[cluster_id] cluster = self.clusters[cluster_id]
for step in cluster.steps: for step in cluster.steps:
if step.id == step_id: if step.id == step_id:
return step return step
return None
def describe_cluster(self, cluster_id): def describe_cluster(self, cluster_id: str) -> FakeCluster:
if cluster_id in self.clusters: if cluster_id in self.clusters:
return self.clusters[cluster_id] return self.clusters[cluster_id]
raise ResourceNotFoundException("") raise ResourceNotFoundException("")
def get_instance_groups(self, instance_group_ids): def get_instance_groups(
self, instance_group_ids: List[str]
) -> List[FakeInstanceGroup]:
return [ return [
group group
for group_id, group in self.instance_groups.items() for group_id, group in self.instance_groups.items()
if group_id in instance_group_ids if group_id in instance_group_ids
] ]
def list_bootstrap_actions(self, cluster_id, marker=None): def list_bootstrap_actions(
self, cluster_id: str, marker: Optional[str] = None
) -> Tuple[List[FakeBootstrapAction], Optional[str]]:
max_items = 50 max_items = 50
actions = self.clusters[cluster_id].bootstrap_actions actions = self.clusters[cluster_id].bootstrap_actions
start_idx = 0 if marker is None else int(marker) start_idx = 0 if marker is None else int(marker)
@ -499,18 +524,24 @@ class ElasticMapReduceBackend(BaseBackend):
return actions[start_idx : start_idx + max_items], marker return actions[start_idx : start_idx + max_items], marker
def list_clusters( def list_clusters(
self, cluster_states=None, created_after=None, created_before=None, marker=None self,
): cluster_states: Optional[List[str]] = None,
created_after: Optional[str] = None,
created_before: Optional[str] = None,
marker: Optional[str] = None,
) -> Tuple[List[FakeCluster], Optional[str]]:
max_items = 50 max_items = 50
clusters = self.clusters.values() clusters = list(self.clusters.values())
if cluster_states: if cluster_states:
clusters = [c for c in clusters if c.state in cluster_states] clusters = [c for c in clusters if c.state in cluster_states]
if created_after: if created_after:
created_after = dtparse(created_after) clusters = [
clusters = [c for c in clusters if c.creation_datetime > created_after] c for c in clusters if c.creation_datetime > dtparse(created_after)
]
if created_before: if created_before:
created_before = dtparse(created_before) clusters = [
clusters = [c for c in clusters if c.creation_datetime < created_before] c for c in clusters if c.creation_datetime < dtparse(created_before)
]
clusters = sorted(clusters, key=lambda x: x.id) clusters = sorted(clusters, key=lambda x: x.id)
start_idx = 0 if marker is None else int(marker) start_idx = 0 if marker is None else int(marker)
marker = ( marker = (
@ -520,7 +551,9 @@ class ElasticMapReduceBackend(BaseBackend):
) )
return clusters[start_idx : start_idx + max_items], marker return clusters[start_idx : start_idx + max_items], marker
def list_instance_groups(self, cluster_id, marker=None): def list_instance_groups(
self, cluster_id: str, marker: Optional[str] = None
) -> Tuple[List[FakeInstanceGroup], Optional[str]]:
max_items = 50 max_items = 50
groups = sorted(self.clusters[cluster_id].instance_groups, key=lambda x: x.id) groups = sorted(self.clusters[cluster_id].instance_groups, key=lambda x: x.id)
start_idx = 0 if marker is None else int(marker) start_idx = 0 if marker is None else int(marker)
@ -530,8 +563,12 @@ class ElasticMapReduceBackend(BaseBackend):
return groups[start_idx : start_idx + max_items], marker return groups[start_idx : start_idx + max_items], marker
def list_instances( def list_instances(
self, cluster_id, marker=None, instance_group_id=None, instance_group_types=None self,
): cluster_id: str,
marker: Optional[str] = None,
instance_group_id: Optional[str] = None,
instance_group_types: Optional[List[str]] = None,
) -> Tuple[List[FakeInstance], Optional[str]]:
max_items = 50 max_items = 50
groups = sorted(self.clusters[cluster_id].instances, key=lambda x: x.id) groups = sorted(self.clusters[cluster_id].instances, key=lambda x: x.id)
start_idx = 0 if marker is None else int(marker) start_idx = 0 if marker is None else int(marker)
@ -545,10 +582,16 @@ class ElasticMapReduceBackend(BaseBackend):
g for g in groups if g.instance_group.role in instance_group_types g for g in groups if g.instance_group.role in instance_group_types
] ]
for g in groups: for g in groups:
g.details = self.ec2_backend.get_instance(g.ec2_instance_id) g.details = self.ec2_backend.get_instance(g.ec2_instance_id) # type: ignore
return groups[start_idx : start_idx + max_items], marker return groups[start_idx : start_idx + max_items], marker
def list_steps(self, cluster_id, marker=None, step_ids=None, step_states=None): def list_steps(
self,
cluster_id: str,
marker: Optional[str] = None,
step_ids: Optional[List[str]] = None,
step_states: Optional[List[str]] = None,
) -> Tuple[List[FakeStep], Optional[str]]:
max_items = 50 max_items = 50
steps = sorted( steps = sorted(
self.clusters[cluster_id].steps, self.clusters[cluster_id].steps,
@ -565,30 +608,30 @@ class ElasticMapReduceBackend(BaseBackend):
) )
return steps[start_idx : start_idx + max_items], marker return steps[start_idx : start_idx + max_items], marker
def modify_cluster(self, cluster_id, step_concurrency_level): def modify_cluster(
self, cluster_id: str, step_concurrency_level: int
) -> FakeCluster:
cluster = self.clusters[cluster_id] cluster = self.clusters[cluster_id]
cluster.step_concurrency_level = step_concurrency_level cluster.step_concurrency_level = step_concurrency_level
return cluster return cluster
def modify_instance_groups(self, instance_groups): def modify_instance_groups(self, instance_groups: List[Dict[str, Any]]) -> None:
result_groups = []
for instance_group in instance_groups: for instance_group in instance_groups:
group = self.instance_groups[instance_group["instance_group_id"]] group = self.instance_groups[instance_group["instance_group_id"]]
group.set_instance_count(int(instance_group["instance_count"])) group.set_instance_count(int(instance_group["instance_count"]))
return result_groups
def remove_tags(self, cluster_id, tag_keys): def remove_tags(self, cluster_id: str, tag_keys: List[str]) -> None:
cluster = self.describe_cluster(cluster_id) cluster = self.describe_cluster(cluster_id)
cluster.remove_tags(tag_keys) cluster.remove_tags(tag_keys)
def _manage_security_groups( def _manage_security_groups(
self, self,
ec2_subnet_id, ec2_subnet_id: str,
emr_managed_master_security_group, emr_managed_master_security_group: str,
emr_managed_slave_security_group, emr_managed_slave_security_group: str,
service_access_security_group, service_access_security_group: str,
**_, **_: Any,
): ) -> Tuple[str, str, str]:
default_return_value = ( default_return_value = (
emr_managed_master_security_group, emr_managed_master_security_group,
emr_managed_slave_security_group, emr_managed_slave_security_group,
@ -601,7 +644,7 @@ class ElasticMapReduceBackend(BaseBackend):
from moto.ec2.exceptions import InvalidSubnetIdError from moto.ec2.exceptions import InvalidSubnetIdError
try: try:
subnet = self.ec2_backend.get_subnet(ec2_subnet_id) subnet = self.ec2_backend.get_subnet(ec2_subnet_id) # type: ignore
except InvalidSubnetIdError: except InvalidSubnetIdError:
warnings.warn( warnings.warn(
f"Could not find Subnet with id: {ec2_subnet_id}\n" f"Could not find Subnet with id: {ec2_subnet_id}\n"
@ -620,7 +663,7 @@ class ElasticMapReduceBackend(BaseBackend):
) )
return master.id, slave.id, service.id return master.id, slave.id, service.id
def run_job_flow(self, **kwargs): def run_job_flow(self, **kwargs: Any) -> FakeCluster:
( (
kwargs["instance_attrs"]["emr_managed_master_security_group"], kwargs["instance_attrs"]["emr_managed_master_security_group"],
kwargs["instance_attrs"]["emr_managed_slave_security_group"], kwargs["instance_attrs"]["emr_managed_slave_security_group"],
@ -628,17 +671,19 @@ class ElasticMapReduceBackend(BaseBackend):
) = self._manage_security_groups(**kwargs["instance_attrs"]) ) = self._manage_security_groups(**kwargs["instance_attrs"])
return FakeCluster(self, **kwargs) return FakeCluster(self, **kwargs)
def set_visible_to_all_users(self, job_flow_ids, visible_to_all_users): def set_visible_to_all_users(
self, job_flow_ids: List[str], visible_to_all_users: str
) -> None:
for job_flow_id in job_flow_ids: for job_flow_id in job_flow_ids:
cluster = self.clusters[job_flow_id] cluster = self.clusters[job_flow_id]
cluster.set_visibility(visible_to_all_users) cluster.set_visibility(visible_to_all_users)
def set_termination_protection(self, job_flow_ids, value): def set_termination_protection(self, job_flow_ids: List[str], value: bool) -> None:
for job_flow_id in job_flow_ids: for job_flow_id in job_flow_ids:
cluster = self.clusters[job_flow_id] cluster = self.clusters[job_flow_id]
cluster.set_termination_protection(value) cluster.set_termination_protection(value)
def terminate_job_flows(self, job_flow_ids): def terminate_job_flows(self, job_flow_ids: List[str]) -> List[FakeCluster]:
clusters_terminated = [] clusters_terminated = []
clusters_protected = [] clusters_protected = []
for job_flow_id in job_flow_ids: for job_flow_id in job_flow_ids:
@ -654,7 +699,9 @@ class ElasticMapReduceBackend(BaseBackend):
) )
return clusters_terminated return clusters_terminated
def put_auto_scaling_policy(self, instance_group_id, auto_scaling_policy): def put_auto_scaling_policy(
self, instance_group_id: str, auto_scaling_policy: Optional[Dict[str, Any]]
) -> Optional[FakeInstanceGroup]:
instance_groups = self.get_instance_groups( instance_groups = self.get_instance_groups(
instance_group_ids=[instance_group_id] instance_group_ids=[instance_group_id]
) )
@ -664,28 +711,30 @@ class ElasticMapReduceBackend(BaseBackend):
instance_group.auto_scaling_policy = auto_scaling_policy instance_group.auto_scaling_policy = auto_scaling_policy
return instance_group return instance_group
def remove_auto_scaling_policy(self, instance_group_id): def remove_auto_scaling_policy(self, instance_group_id: str) -> None:
self.put_auto_scaling_policy(instance_group_id, auto_scaling_policy=None) self.put_auto_scaling_policy(instance_group_id, auto_scaling_policy=None)
def create_security_configuration(self, name, security_configuration): def create_security_configuration(
self, name: str, security_configuration: str
) -> FakeSecurityConfiguration:
if name in self.security_configurations: if name in self.security_configurations:
raise InvalidRequestException( raise InvalidRequestException(
message=f"SecurityConfiguration with name '{name}' already exists." message=f"SecurityConfiguration with name '{name}' already exists."
) )
security_configuration = FakeSecurityConfiguration( config = FakeSecurityConfiguration(
name=name, security_configuration=security_configuration name=name, security_configuration=security_configuration
) )
self.security_configurations[name] = security_configuration self.security_configurations[name] = config
return security_configuration return config
def get_security_configuration(self, name): def get_security_configuration(self, name: str) -> FakeSecurityConfiguration:
if name not in self.security_configurations: if name not in self.security_configurations:
raise InvalidRequestException( raise InvalidRequestException(
message=f"Security configuration with name '{name}' does not exist." message=f"Security configuration with name '{name}' does not exist."
) )
return self.security_configurations[name] return self.security_configurations[name]
def delete_security_configuration(self, name): def delete_security_configuration(self, name: str) -> None:
if name not in self.security_configurations: if name not in self.security_configurations:
raise InvalidRequestException( raise InvalidRequestException(
message=f"Security configuration with name '{name}' does not exist." message=f"Security configuration with name '{name}' does not exist."

View File

@ -2,6 +2,7 @@ import json
import re import re
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, List, Pattern
from urllib.parse import urlparse from urllib.parse import urlparse
from moto.core.responses import AWSServiceSpec from moto.core.responses import AWSServiceSpec
@ -9,20 +10,27 @@ from moto.core.responses import BaseResponse
from moto.core.responses import xml_to_json_response from moto.core.responses import xml_to_json_response
from moto.core.utils import tags_from_query_string from moto.core.utils import tags_from_query_string
from .exceptions import ValidationException from .exceptions import ValidationException
from .models import emr_backends from .models import emr_backends, ElasticMapReduceBackend
from .utils import steps_from_query_string, Unflattener, ReleaseLabel from .utils import steps_from_query_string, Unflattener, ReleaseLabel
def generate_boto3_response(operation): def generate_boto3_response(
operation: str,
) -> Callable[
[Callable[["ElasticMapReduceResponse"], str]],
Callable[["ElasticMapReduceResponse"], str],
]:
"""The decorator to convert an XML response to JSON, if the request is """The decorator to convert an XML response to JSON, if the request is
determined to be from boto3. Pass the API action as a parameter. determined to be from boto3. Pass the API action as a parameter.
""" """
def _boto3_request(method): def _boto3_request(
method: Callable[["ElasticMapReduceResponse"], str]
) -> Callable[["ElasticMapReduceResponse"], str]:
@wraps(method) @wraps(method)
def f(self, *args, **kwargs): def f(self: "ElasticMapReduceResponse") -> str:
rendered = method(self, *args, **kwargs) rendered = method(self)
if "json" in self.headers.get("Content-Type", []): if "json" in self.headers.get("Content-Type", []):
self.response_headers.update( self.response_headers.update(
{ {
@ -46,30 +54,30 @@ class ElasticMapReduceResponse(BaseResponse):
# EMR end points are inconsistent in the placement of region name # EMR end points are inconsistent in the placement of region name
# in the URL, so parsing it out needs to be handled differently # in the URL, so parsing it out needs to be handled differently
region_regex = [ emr_region_regex: List[Pattern[str]] = [
re.compile(r"elasticmapreduce\.(.+?)\.amazonaws\.com"), re.compile(r"elasticmapreduce\.(.+?)\.amazonaws\.com"),
re.compile(r"(.+?)\.elasticmapreduce\.amazonaws\.com"), re.compile(r"(.+?)\.elasticmapreduce\.amazonaws\.com"),
] ]
aws_service_spec = AWSServiceSpec("data/emr/2009-03-31/service-2.json") aws_service_spec = AWSServiceSpec("data/emr/2009-03-31/service-2.json")
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="emr") super().__init__(service_name="emr")
def get_region_from_url(self, request, full_url): def get_region_from_url(self, request: Any, full_url: str) -> str:
parsed = urlparse(full_url) parsed = urlparse(full_url)
for regex in self.region_regex: for regex in ElasticMapReduceResponse.emr_region_regex:
match = regex.search(parsed.netloc) match = regex.search(parsed.netloc)
if match: if match:
return match.group(1) return match.group(1)
return self.default_region return self.default_region
@property @property
def backend(self): def backend(self) -> ElasticMapReduceBackend:
return emr_backends[self.current_account][self.region] return emr_backends[self.current_account][self.region]
@generate_boto3_response("AddInstanceGroups") @generate_boto3_response("AddInstanceGroups")
def add_instance_groups(self): def add_instance_groups(self) -> str:
jobflow_id = self._get_param("JobFlowId") jobflow_id = self._get_param("JobFlowId")
instance_groups = self._get_list_prefix("InstanceGroups.member") instance_groups = self._get_list_prefix("InstanceGroups.member")
for item in instance_groups: for item in instance_groups:
@ -78,12 +86,12 @@ class ElasticMapReduceResponse(BaseResponse):
self._parse_ebs_configuration(item) self._parse_ebs_configuration(item)
# Adding support for auto_scaling_policy # Adding support for auto_scaling_policy
Unflattener.unflatten_complex_params(item, "auto_scaling_policy") Unflattener.unflatten_complex_params(item, "auto_scaling_policy")
instance_groups = self.backend.add_instance_groups(jobflow_id, instance_groups) fake_groups = self.backend.add_instance_groups(jobflow_id, instance_groups)
template = self.response_template(ADD_INSTANCE_GROUPS_TEMPLATE) template = self.response_template(ADD_INSTANCE_GROUPS_TEMPLATE)
return template.render(instance_groups=instance_groups) return template.render(instance_groups=fake_groups)
@generate_boto3_response("AddJobFlowSteps") @generate_boto3_response("AddJobFlowSteps")
def add_job_flow_steps(self): def add_job_flow_steps(self) -> str:
job_flow_id = self._get_param("JobFlowId") job_flow_id = self._get_param("JobFlowId")
steps = self.backend.add_job_flow_steps( steps = self.backend.add_job_flow_steps(
job_flow_id, steps_from_query_string(self._get_list_prefix("Steps.member")) job_flow_id, steps_from_query_string(self._get_list_prefix("Steps.member"))
@ -92,18 +100,15 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(steps=steps) return template.render(steps=steps)
@generate_boto3_response("AddTags") @generate_boto3_response("AddTags")
def add_tags(self): def add_tags(self) -> str:
cluster_id = self._get_param("ResourceId") cluster_id = self._get_param("ResourceId")
tags = tags_from_query_string(self.querystring, prefix="Tags") tags = tags_from_query_string(self.querystring, prefix="Tags")
self.backend.add_tags(cluster_id, tags) self.backend.add_tags(cluster_id, tags)
template = self.response_template(ADD_TAGS_TEMPLATE) template = self.response_template(ADD_TAGS_TEMPLATE)
return template.render() return template.render()
def cancel_steps(self):
raise NotImplementedError
@generate_boto3_response("CreateSecurityConfiguration") @generate_boto3_response("CreateSecurityConfiguration")
def create_security_configuration(self): def create_security_configuration(self) -> str:
name = self._get_param("Name") name = self._get_param("Name")
security_configuration = self._get_param("SecurityConfiguration") security_configuration = self._get_param("SecurityConfiguration")
resp = self.backend.create_security_configuration( resp = self.backend.create_security_configuration(
@ -113,28 +118,28 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(name=name, creation_date_time=resp.creation_date_time) return template.render(name=name, creation_date_time=resp.creation_date_time)
@generate_boto3_response("DescribeSecurityConfiguration") @generate_boto3_response("DescribeSecurityConfiguration")
def describe_security_configuration(self): def describe_security_configuration(self) -> str:
name = self._get_param("Name") name = self._get_param("Name")
security_configuration = self.backend.get_security_configuration(name=name) security_configuration = self.backend.get_security_configuration(name=name)
template = self.response_template(DESCRIBE_SECURITY_CONFIGURATION_TEMPLATE) template = self.response_template(DESCRIBE_SECURITY_CONFIGURATION_TEMPLATE)
return template.render(security_configuration=security_configuration) return template.render(security_configuration=security_configuration)
@generate_boto3_response("DeleteSecurityConfiguration") @generate_boto3_response("DeleteSecurityConfiguration")
def delete_security_configuration(self): def delete_security_configuration(self) -> str:
name = self._get_param("Name") name = self._get_param("Name")
self.backend.delete_security_configuration(name=name) self.backend.delete_security_configuration(name=name)
template = self.response_template(DELETE_SECURITY_CONFIGURATION_TEMPLATE) template = self.response_template(DELETE_SECURITY_CONFIGURATION_TEMPLATE)
return template.render() return template.render()
@generate_boto3_response("DescribeCluster") @generate_boto3_response("DescribeCluster")
def describe_cluster(self): def describe_cluster(self) -> str:
cluster_id = self._get_param("ClusterId") cluster_id = self._get_param("ClusterId")
cluster = self.backend.describe_cluster(cluster_id) cluster = self.backend.describe_cluster(cluster_id)
template = self.response_template(DESCRIBE_CLUSTER_TEMPLATE) template = self.response_template(DESCRIBE_CLUSTER_TEMPLATE)
return template.render(cluster=cluster) return template.render(cluster=cluster)
@generate_boto3_response("DescribeJobFlows") @generate_boto3_response("DescribeJobFlows")
def describe_job_flows(self): def describe_job_flows(self) -> str:
created_after = self._get_param("CreatedAfter") created_after = self._get_param("CreatedAfter")
created_before = self._get_param("CreatedBefore") created_before = self._get_param("CreatedBefore")
job_flow_ids = self._get_multi_param("JobFlowIds.member") job_flow_ids = self._get_multi_param("JobFlowIds.member")
@ -146,7 +151,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(clusters=clusters) return template.render(clusters=clusters)
@generate_boto3_response("DescribeStep") @generate_boto3_response("DescribeStep")
def describe_step(self): def describe_step(self) -> str:
cluster_id = self._get_param("ClusterId") cluster_id = self._get_param("ClusterId")
step_id = self._get_param("StepId") step_id = self._get_param("StepId")
step = self.backend.describe_step(cluster_id, step_id) step = self.backend.describe_step(cluster_id, step_id)
@ -154,7 +159,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(step=step) return template.render(step=step)
@generate_boto3_response("ListBootstrapActions") @generate_boto3_response("ListBootstrapActions")
def list_bootstrap_actions(self): def list_bootstrap_actions(self) -> str:
cluster_id = self._get_param("ClusterId") cluster_id = self._get_param("ClusterId")
marker = self._get_param("Marker") marker = self._get_param("Marker")
bootstrap_actions, marker = self.backend.list_bootstrap_actions( bootstrap_actions, marker = self.backend.list_bootstrap_actions(
@ -164,7 +169,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(bootstrap_actions=bootstrap_actions, marker=marker) return template.render(bootstrap_actions=bootstrap_actions, marker=marker)
@generate_boto3_response("ListClusters") @generate_boto3_response("ListClusters")
def list_clusters(self): def list_clusters(self) -> str:
cluster_states = self._get_multi_param("ClusterStates.member") cluster_states = self._get_multi_param("ClusterStates.member")
created_after = self._get_param("CreatedAfter") created_after = self._get_param("CreatedAfter")
created_before = self._get_param("CreatedBefore") created_before = self._get_param("CreatedBefore")
@ -176,7 +181,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(clusters=clusters, marker=marker) return template.render(clusters=clusters, marker=marker)
@generate_boto3_response("ListInstanceGroups") @generate_boto3_response("ListInstanceGroups")
def list_instance_groups(self): def list_instance_groups(self) -> str:
cluster_id = self._get_param("ClusterId") cluster_id = self._get_param("ClusterId")
marker = self._get_param("Marker") marker = self._get_param("Marker")
instance_groups, marker = self.backend.list_instance_groups( instance_groups, marker = self.backend.list_instance_groups(
@ -186,7 +191,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(instance_groups=instance_groups, marker=marker) return template.render(instance_groups=instance_groups, marker=marker)
@generate_boto3_response("ListInstances") @generate_boto3_response("ListInstances")
def list_instances(self): def list_instances(self) -> str:
cluster_id = self._get_param("ClusterId") cluster_id = self._get_param("ClusterId")
marker = self._get_param("Marker") marker = self._get_param("Marker")
instance_group_id = self._get_param("InstanceGroupId") instance_group_id = self._get_param("InstanceGroupId")
@ -201,7 +206,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(instances=instances, marker=marker) return template.render(instances=instances, marker=marker)
@generate_boto3_response("ListSteps") @generate_boto3_response("ListSteps")
def list_steps(self): def list_steps(self) -> str:
cluster_id = self._get_param("ClusterId") cluster_id = self._get_param("ClusterId")
marker = self._get_param("Marker") marker = self._get_param("Marker")
step_ids = self._get_multi_param("StepIds.member") step_ids = self._get_multi_param("StepIds.member")
@ -213,7 +218,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(steps=steps, marker=marker) return template.render(steps=steps, marker=marker)
@generate_boto3_response("ModifyCluster") @generate_boto3_response("ModifyCluster")
def modify_cluster(self): def modify_cluster(self) -> str:
cluster_id = self._get_param("ClusterId") cluster_id = self._get_param("ClusterId")
step_concurrency_level = self._get_param("StepConcurrencyLevel") step_concurrency_level = self._get_param("StepConcurrencyLevel")
cluster = self.backend.modify_cluster(cluster_id, step_concurrency_level) cluster = self.backend.modify_cluster(cluster_id, step_concurrency_level)
@ -221,16 +226,16 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render(cluster=cluster) return template.render(cluster=cluster)
@generate_boto3_response("ModifyInstanceGroups") @generate_boto3_response("ModifyInstanceGroups")
def modify_instance_groups(self): def modify_instance_groups(self) -> str:
instance_groups = self._get_list_prefix("InstanceGroups.member") instance_groups = self._get_list_prefix("InstanceGroups.member")
for item in instance_groups: for item in instance_groups:
item["instance_count"] = int(item["instance_count"]) item["instance_count"] = int(item["instance_count"])
instance_groups = self.backend.modify_instance_groups(instance_groups) self.backend.modify_instance_groups(instance_groups)
template = self.response_template(MODIFY_INSTANCE_GROUPS_TEMPLATE) template = self.response_template(MODIFY_INSTANCE_GROUPS_TEMPLATE)
return template.render(instance_groups=instance_groups) return template.render()
@generate_boto3_response("RemoveTags") @generate_boto3_response("RemoveTags")
def remove_tags(self): def remove_tags(self) -> str:
cluster_id = self._get_param("ResourceId") cluster_id = self._get_param("ResourceId")
tag_keys = self._get_multi_param("TagKeys.member") tag_keys = self._get_multi_param("TagKeys.member")
self.backend.remove_tags(cluster_id, tag_keys) self.backend.remove_tags(cluster_id, tag_keys)
@ -238,7 +243,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render() return template.render()
@generate_boto3_response("RunJobFlow") @generate_boto3_response("RunJobFlow")
def run_job_flow(self): def run_job_flow(self) -> str:
instance_attrs = dict( instance_attrs = dict(
master_instance_type=self._get_param("Instances.MasterInstanceType"), master_instance_type=self._get_param("Instances.MasterInstanceType"),
slave_instance_type=self._get_param("Instances.SlaveInstanceType"), slave_instance_type=self._get_param("Instances.SlaveInstanceType"),
@ -349,7 +354,7 @@ class ElasticMapReduceResponse(BaseResponse):
if security_configuration: if security_configuration:
kwargs["security_configuration"] = security_configuration kwargs["security_configuration"] = security_configuration
kerberos_attributes = {} kerberos_attributes: Dict[str, Any] = {}
kwargs["kerberos_attributes"] = kerberos_attributes kwargs["kerberos_attributes"] = kerberos_attributes
realm = self._get_param("KerberosAttributes.Realm") realm = self._get_param("KerberosAttributes.Realm")
@ -413,13 +418,13 @@ class ElasticMapReduceResponse(BaseResponse):
template = self.response_template(RUN_JOB_FLOW_TEMPLATE) template = self.response_template(RUN_JOB_FLOW_TEMPLATE)
return template.render(cluster=cluster) return template.render(cluster=cluster)
def _has_key_prefix(self, key_prefix, value): def _has_key_prefix(self, key_prefix: str, value: Dict[str, Any]) -> bool:
for key in value: # iter on both keys and values for key in value: # iter on both keys and values
if key.startswith(key_prefix): if key.startswith(key_prefix):
return True return True
return False return False
def _parse_ebs_configuration(self, instance_group): def _parse_ebs_configuration(self, instance_group: Dict[str, Any]) -> None:
key_ebs_config = "ebs_configuration" key_ebs_config = "ebs_configuration"
ebs_configuration = dict() ebs_configuration = dict()
# Filter only EBS config keys # Filter only EBS config keys
@ -456,7 +461,7 @@ class ElasticMapReduceResponse(BaseResponse):
vol_iops = vlespc_keyfmt.format(iops) vol_iops = vlespc_keyfmt.format(iops)
vol_type = vlespc_keyfmt.format(volume_type) vol_type = vlespc_keyfmt.format(volume_type)
ebs_block = dict() ebs_block: Dict[str, Any] = dict()
ebs_block[volume_specification] = dict() ebs_block[volume_specification] = dict()
if vol_size in ebs_configuration: if vol_size in ebs_configuration:
instance_group.pop(vol_size) instance_group.pop(vol_size)
@ -491,7 +496,7 @@ class ElasticMapReduceResponse(BaseResponse):
instance_group[key_ebs_config] = ebs_configuration instance_group[key_ebs_config] = ebs_configuration
@generate_boto3_response("SetTerminationProtection") @generate_boto3_response("SetTerminationProtection")
def set_termination_protection(self): def set_termination_protection(self) -> str:
termination_protection = self._get_bool_param("TerminationProtected") termination_protection = self._get_bool_param("TerminationProtected")
job_ids = self._get_multi_param("JobFlowIds.member") job_ids = self._get_multi_param("JobFlowIds.member")
self.backend.set_termination_protection(job_ids, termination_protection) self.backend.set_termination_protection(job_ids, termination_protection)
@ -499,7 +504,7 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render() return template.render()
@generate_boto3_response("SetVisibleToAllUsers") @generate_boto3_response("SetVisibleToAllUsers")
def set_visible_to_all_users(self): def set_visible_to_all_users(self) -> str:
visible_to_all_users = self._get_param("VisibleToAllUsers") visible_to_all_users = self._get_param("VisibleToAllUsers")
job_ids = self._get_multi_param("JobFlowIds.member") job_ids = self._get_multi_param("JobFlowIds.member")
self.backend.set_visible_to_all_users(job_ids, visible_to_all_users) self.backend.set_visible_to_all_users(job_ids, visible_to_all_users)
@ -507,14 +512,14 @@ class ElasticMapReduceResponse(BaseResponse):
return template.render() return template.render()
@generate_boto3_response("TerminateJobFlows") @generate_boto3_response("TerminateJobFlows")
def terminate_job_flows(self): def terminate_job_flows(self) -> str:
job_ids = self._get_multi_param("JobFlowIds.member.") job_ids = self._get_multi_param("JobFlowIds.member.")
self.backend.terminate_job_flows(job_ids) self.backend.terminate_job_flows(job_ids)
template = self.response_template(TERMINATE_JOB_FLOWS_TEMPLATE) template = self.response_template(TERMINATE_JOB_FLOWS_TEMPLATE)
return template.render() return template.render()
@generate_boto3_response("PutAutoScalingPolicy") @generate_boto3_response("PutAutoScalingPolicy")
def put_auto_scaling_policy(self): def put_auto_scaling_policy(self) -> str:
cluster_id = self._get_param("ClusterId") cluster_id = self._get_param("ClusterId")
cluster = self.backend.describe_cluster(cluster_id) cluster = self.backend.describe_cluster(cluster_id)
instance_group_id = self._get_param("InstanceGroupId") instance_group_id = self._get_param("InstanceGroupId")
@ -528,12 +533,11 @@ class ElasticMapReduceResponse(BaseResponse):
) )
@generate_boto3_response("RemoveAutoScalingPolicy") @generate_boto3_response("RemoveAutoScalingPolicy")
def remove_auto_scaling_policy(self): def remove_auto_scaling_policy(self) -> str:
cluster_id = self._get_param("ClusterId")
instance_group_id = self._get_param("InstanceGroupId") instance_group_id = self._get_param("InstanceGroupId")
instance_group = self.backend.remove_auto_scaling_policy(instance_group_id) self.backend.remove_auto_scaling_policy(instance_group_id)
template = self.response_template(REMOVE_AUTO_SCALING_POLICY) template = self.response_template(REMOVE_AUTO_SCALING_POLICY)
return template.render(cluster_id=cluster_id, instance_group=instance_group) return template.render()
ADD_INSTANCE_GROUPS_TEMPLATE = """<AddInstanceGroupsResponse xmlns="http://elasticmapreduce.amazonaws.com/doc/2009-03-31"> ADD_INSTANCE_GROUPS_TEMPLATE = """<AddInstanceGroupsResponse xmlns="http://elasticmapreduce.amazonaws.com/doc/2009-03-31">

View File

@ -2,6 +2,7 @@ import copy
import datetime import datetime
import re import re
import string import string
from typing import Any, List, Dict, Tuple, Iterator
from moto.core.utils import ( from moto.core.utils import (
camelcase_to_underscores, camelcase_to_underscores,
iso_8601_datetime_with_milliseconds, iso_8601_datetime_with_milliseconds,
@ -9,24 +10,26 @@ from moto.core.utils import (
from moto.moto_api._internal import mock_random as random from moto.moto_api._internal import mock_random as random
def random_id(size=13): def random_id(size: int = 13) -> str:
chars = list(range(10)) + list(string.ascii_uppercase) chars = list(range(10)) + list(string.ascii_uppercase)
return "".join(str(random.choice(chars)) for x in range(size)) return "".join(str(random.choice(chars)) for x in range(size))
def random_cluster_id(): def random_cluster_id() -> str:
return f"j-{random_id()}" return f"j-{random_id()}"
def random_step_id(): def random_step_id() -> str:
return f"s-{random_id()}" return f"s-{random_id()}"
def random_instance_group_id(): def random_instance_group_id() -> str:
return f"i-{random_id()}" return f"i-{random_id()}"
def steps_from_query_string(querystring_dict): def steps_from_query_string(
querystring_dict: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
steps = [] steps = []
for step in querystring_dict: for step in querystring_dict:
step["jar"] = step.pop("hadoop_jar_step._jar") step["jar"] = step.pop("hadoop_jar_step._jar")
@ -45,7 +48,7 @@ def steps_from_query_string(querystring_dict):
class Unflattener: class Unflattener:
@staticmethod @staticmethod
def unflatten_complex_params(input_dict, param_name): def unflatten_complex_params(input_dict: Dict[str, Any], param_name: str) -> None: # type: ignore[misc]
"""Function to unflatten (portions of) dicts with complex keys. The moto request parser flattens the incoming """Function to unflatten (portions of) dicts with complex keys. The moto request parser flattens the incoming
request bodies, which is generally helpful, but for nested dicts/lists can result in a hard-to-manage request bodies, which is generally helpful, but for nested dicts/lists can result in a hard-to-manage
parameter exposion. This function allows one to selectively unflatten a set of dict keys, replacing them parameter exposion. This function allows one to selectively unflatten a set of dict keys, replacing them
@ -68,7 +71,7 @@ class Unflattener:
Unflattener._set_deep(k, input_dict, items_to_process[k]) Unflattener._set_deep(k, input_dict, items_to_process[k])
@staticmethod @staticmethod
def _set_deep(complex_key, container, value): def _set_deep(complex_key: Any, container: Any, value: Any) -> None: # type: ignore[misc]
keys = complex_key.split(".") keys = complex_key.split(".")
keys.reverse() keys.reverse()
@ -91,7 +94,7 @@ class Unflattener:
container = Unflattener._get_child(container, key) container = Unflattener._get_child(container, key)
@staticmethod @staticmethod
def _add_to_container(container, key, value): def _add_to_container(container: Any, key: Any, value: Any) -> Any: # type: ignore[misc]
if type(container) is dict: if type(container) is dict:
container[key] = value container[key] = value
elif type(container) is list: elif type(container) is list:
@ -102,7 +105,7 @@ class Unflattener:
return value return value
@staticmethod @staticmethod
def _get_child(container, key): def _get_child(container: Any, key: Any) -> Any: # type: ignore[misc]
if type(container) is dict: if type(container) is dict:
return container[key] return container[key]
elif type(container) is list: elif type(container) is list:
@ -110,7 +113,7 @@ class Unflattener:
return container[i - 1] return container[i - 1]
@staticmethod @staticmethod
def _key_in_container(container, key): def _key_in_container(container: Any, key: Any) -> bool: # type: ignore
if type(container) is dict: if type(container) is dict:
return key in container return key in container
elif type(container) is list: elif type(container) is list:
@ -122,7 +125,7 @@ class CamelToUnderscoresWalker:
"""A class to convert the keys in dict/list hierarchical data structures from CamelCase to snake_case (underscores)""" """A class to convert the keys in dict/list hierarchical data structures from CamelCase to snake_case (underscores)"""
@staticmethod @staticmethod
def parse(x): def parse(x: Any) -> Any: # type: ignore[misc]
if isinstance(x, dict): if isinstance(x, dict):
return CamelToUnderscoresWalker.parse_dict(x) return CamelToUnderscoresWalker.parse_dict(x)
elif isinstance(x, list): elif isinstance(x, list):
@ -131,29 +134,29 @@ class CamelToUnderscoresWalker:
return CamelToUnderscoresWalker.parse_scalar(x) return CamelToUnderscoresWalker.parse_scalar(x)
@staticmethod @staticmethod
def parse_dict(x): def parse_dict(x: Dict[str, Any]) -> Dict[str, Any]: # type: ignore[misc]
temp = {} temp = {}
for key in x.keys(): for key in x.keys():
temp[camelcase_to_underscores(key)] = CamelToUnderscoresWalker.parse(x[key]) temp[camelcase_to_underscores(key)] = CamelToUnderscoresWalker.parse(x[key])
return temp return temp
@staticmethod @staticmethod
def parse_list(x): def parse_list(x: Any) -> Any: # type: ignore[misc]
temp = [] temp = []
for i in x: for i in x:
temp.append(CamelToUnderscoresWalker.parse(i)) temp.append(CamelToUnderscoresWalker.parse(i))
return temp return temp
@staticmethod @staticmethod
def parse_scalar(x): def parse_scalar(x: Any) -> Any: # type: ignore[misc]
return x return x
class ReleaseLabel(object): class ReleaseLabel:
version_re = re.compile(r"^emr-(\d+)\.(\d+)\.(\d+)$") version_re = re.compile(r"^emr-(\d+)\.(\d+)\.(\d+)$")
def __init__(self, release_label): def __init__(self, release_label: str):
major, minor, patch = self.parse(release_label) major, minor, patch = self.parse(release_label)
self.major = major self.major = major
@ -161,7 +164,7 @@ class ReleaseLabel(object):
self.patch = patch self.patch = patch
@classmethod @classmethod
def parse(cls, release_label): def parse(cls, release_label: str) -> Tuple[int, int, int]:
if not release_label: if not release_label:
raise ValueError(f"Invalid empty ReleaseLabel: {release_label}") raise ValueError(f"Invalid empty ReleaseLabel: {release_label}")
@ -171,23 +174,19 @@ class ReleaseLabel(object):
major, minor, patch = match.groups() major, minor, patch = match.groups()
major = int(major) return int(major), int(minor), int(patch)
minor = int(minor)
patch = int(patch)
return major, minor, patch def __str__(self) -> str:
def __str__(self):
version = f"emr-{self.major}.{self.minor}.{self.patch}" version = f"emr-{self.major}.{self.minor}.{self.patch}"
return version return version
def __repr__(self): def __repr__(self) -> str:
return f"{self.__class__.__name__}({str(self)})" return f"{self.__class__.__name__}({str(self)})"
def __iter__(self): def __iter__(self) -> Iterator[Tuple[int, int, int]]:
return iter((self.major, self.minor, self.patch)) return iter((self.major, self.minor, self.patch)) # type: ignore
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return ( return (
@ -196,46 +195,46 @@ class ReleaseLabel(object):
and self.patch == other.patch and self.patch == other.patch
) )
def __ne__(self, other): def __ne__(self, other: Any) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return tuple(self) != tuple(other) return tuple(self) != tuple(other)
def __lt__(self, other): def __lt__(self, other: Any) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return tuple(self) < tuple(other) return tuple(self) < tuple(other)
def __le__(self, other): def __le__(self, other: Any) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return tuple(self) <= tuple(other) return tuple(self) <= tuple(other)
def __gt__(self, other): def __gt__(self, other: Any) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return tuple(self) > tuple(other) return tuple(self) > tuple(other)
def __ge__(self, other): def __ge__(self, other: Any) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return NotImplemented return NotImplemented
return tuple(self) >= tuple(other) return tuple(self) >= tuple(other)
class EmrManagedSecurityGroup(object): class EmrManagedSecurityGroup:
class Kind: class Kind:
MASTER = "Master" MASTER = "Master"
SLAVE = "Slave" SLAVE = "Slave"
SERVICE = "Service" SERVICE = "Service"
kind = None kind = ""
group_name = "" group_name = ""
short_name = "" short_name = ""
desc_fmt = "{short_name} for Elastic MapReduce created on {created}" desc_fmt = "{short_name} for Elastic MapReduce created on {created}"
@classmethod @classmethod
def description(cls): def description(cls) -> str:
created = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) created = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
return cls.desc_fmt.format(short_name=cls.short_name, created=created) return cls.desc_fmt.format(short_name=cls.short_name, created=created)
@ -258,7 +257,7 @@ class EmrManagedServiceAccessSecurityGroup(EmrManagedSecurityGroup):
short_name = "Service access" short_name = "Service access"
class EmrSecurityGroupManager(object): class EmrSecurityGroupManager:
MANAGED_RULES_EGRESS = [ MANAGED_RULES_EGRESS = [
{ {
@ -383,13 +382,16 @@ class EmrSecurityGroupManager(object):
}, },
] ]
def __init__(self, ec2_backend, vpc_id): def __init__(self, ec2_backend: Any, vpc_id: str):
self.ec2 = ec2_backend self.ec2 = ec2_backend
self.vpc_id = vpc_id self.vpc_id = vpc_id
def manage_security_groups( def manage_security_groups(
self, master_security_group, slave_security_group, service_access_security_group self,
): master_security_group: str,
slave_security_group: str,
service_access_security_group: str,
) -> Tuple[Any, Any, Any]:
group_metadata = [ group_metadata = [
( (
master_security_group, master_security_group,
@ -407,7 +409,7 @@ class EmrSecurityGroupManager(object):
EmrManagedServiceAccessSecurityGroup, EmrManagedServiceAccessSecurityGroup,
), ),
] ]
managed_groups = {} managed_groups: Dict[str, Any] = {}
for name, kind, defaults in group_metadata: for name, kind, defaults in group_metadata:
managed_groups[kind] = self._get_or_create_sg(name, defaults) managed_groups[kind] = self._get_or_create_sg(name, defaults)
self._add_rules_to(managed_groups) self._add_rules_to(managed_groups)
@ -417,7 +419,7 @@ class EmrSecurityGroupManager(object):
managed_groups[EmrManagedSecurityGroup.Kind.SERVICE], managed_groups[EmrManagedSecurityGroup.Kind.SERVICE],
) )
def _get_or_create_sg(self, sg_id, defaults): def _get_or_create_sg(self, sg_id: str, defaults: Any) -> Any:
find_sg = self.ec2.get_security_group_by_name_or_id find_sg = self.ec2.get_security_group_by_name_or_id
create_sg = self.ec2.create_security_group create_sg = self.ec2.create_security_group
group_id_or_name = sg_id or defaults.group_name group_id_or_name = sg_id or defaults.group_name
@ -430,7 +432,7 @@ class EmrSecurityGroupManager(object):
group = create_sg(defaults.group_name, defaults.description(), self.vpc_id) group = create_sg(defaults.group_name, defaults.description(), self.vpc_id)
return group return group
def _add_rules_to(self, managed_groups): def _add_rules_to(self, managed_groups: Dict[str, Any]) -> None:
rules_metadata = [ rules_metadata = [
(self.MANAGED_RULES_EGRESS, self.ec2.authorize_security_group_egress), (self.MANAGED_RULES_EGRESS, self.ec2.authorize_security_group_egress),
(self.MANAGED_RULES_INGRESS, self.ec2.authorize_security_group_ingress), (self.MANAGED_RULES_INGRESS, self.ec2.authorize_security_group_ingress),
@ -447,7 +449,7 @@ class EmrSecurityGroupManager(object):
pass pass
@staticmethod @staticmethod
def _render_rules(rules, managed_groups): def _render_rules(rules: Any, managed_groups: Dict[str, Any]) -> List[Dict[str, Any]]: # type: ignore[misc]
rendered_rules = copy.deepcopy(rules) rendered_rules = copy.deepcopy(rules)
for rule in rendered_rules: for rule in rendered_rules:
rule["group_name_or_id"] = managed_groups[rule["group_name_or_id"]].id rule["group_name_or_id"] = managed_groups[rule["group_name_or_id"]].id

View File

@ -229,7 +229,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2,moto/ec2instanceconnect,moto/ecr,moto/ecs,moto/efs,moto/eks,moto/elasticache,moto/elasticbeanstalk,moto/elastictranscoder,moto/elb,moto/elbv2,moto/es,moto/moto_api,moto/neptune files= moto/a*,moto/b*,moto/c*,moto/d*,moto/ebs/,moto/ec2,moto/ec2instanceconnect,moto/ecr,moto/ecs,moto/efs,moto/eks,moto/elasticache,moto/elasticbeanstalk,moto/elastictranscoder,moto/elb,moto/elbv2,moto/emr,moto/es,moto/moto_api,moto/neptune
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract