TechDebt: MyPy Batch (#5592)

This commit is contained in:
Bert Blommers 2022-10-23 13:26:55 +00:00 committed by GitHub
parent a470ea748b
commit 6f3b250fc7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 365 additions and 289 deletions

View File

@ -1,6 +1,7 @@
import re
from itertools import cycle
from time import sleep
from typing import Any, Dict, List, Tuple, Optional, Set
import datetime
import time
import logging
@ -9,10 +10,11 @@ import dateutil.parser
from sys import platform
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.iam import iam_backends
from moto.ec2 import ec2_backends
from moto.ecs import ecs_backends
from moto.logs import logs_backends
from moto.iam.models import iam_backends, IAMBackend
from moto.ec2.models import ec2_backends, EC2Backend
from moto.ec2.models.instances import Instance
from moto.ecs.models import ecs_backends, EC2ContainerServiceBackend
from moto.logs.models import logs_backends, LogsBackend
from moto.utilities.tagging_service import TaggingService
from .exceptions import InvalidParameterValueException, ClientException, ValidationError
@ -39,7 +41,7 @@ COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(
)
def datetime2int_milliseconds(date):
def datetime2int_milliseconds(date: datetime.datetime) -> int:
"""
AWS returns timestamps in milliseconds
We don't use milliseconds timestamps internally,
@ -48,20 +50,20 @@ def datetime2int_milliseconds(date):
return int(date.timestamp() * 1000)
def datetime2int(date):
def datetime2int(date: datetime.datetime) -> int:
return int(time.mktime(date.timetuple()))
class ComputeEnvironment(CloudFormationModel):
def __init__(
self,
compute_environment_name,
_type,
state,
compute_resources,
service_role,
account_id,
region_name,
compute_environment_name: str,
_type: str,
state: str,
compute_resources: Dict[str, Any],
service_role: str,
account_id: str,
region_name: str,
):
self.name = compute_environment_name
self.env_type = _type
@ -72,34 +74,39 @@ class ComputeEnvironment(CloudFormationModel):
account_id, compute_environment_name, region_name
)
self.instances = []
self.ecs_arn = None
self.ecs_name = None
self.instances: List[Instance] = []
self.ecs_arn = ""
self.ecs_name = ""
def add_instance(self, instance):
def add_instance(self, instance: Instance) -> None:
self.instances.append(instance)
def set_ecs(self, arn, name):
def set_ecs(self, arn: str, name: str) -> None:
self.ecs_arn = arn
self.ecs_name = name
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.arn
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "ComputeEnvironmentName"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-computeenvironment.html
return "AWS::Batch::ComputeEnvironment"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
) -> "ComputeEnvironment":
backend = batch_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
@ -118,13 +125,13 @@ class ComputeEnvironment(CloudFormationModel):
class JobQueue(CloudFormationModel):
def __init__(
self,
name,
priority,
state,
environments,
env_order_json,
backend,
tags=None,
name: str,
priority: str,
state: str,
environments: List[ComputeEnvironment],
env_order_json: List[Dict[str, Any]],
backend: "BatchBackend",
tags: Optional[Dict[str, str]] = None,
):
"""
:param name: Job queue name
@ -150,10 +157,10 @@ class JobQueue(CloudFormationModel):
if tags:
backend.tag_resource(self.arn, tags)
self.jobs = []
self.jobs: List[Job] = []
def describe(self):
result = {
def describe(self) -> Dict[str, Any]:
return {
"computeEnvironmentOrder": self.env_order_json,
"jobQueueArn": self.arn,
"jobQueueName": self.name,
@ -163,25 +170,28 @@ class JobQueue(CloudFormationModel):
"tags": self.backend.list_tags_for_resource(self.arn),
}
return result
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.arn
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "JobQueueName"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobqueue.html
return "AWS::Batch::JobQueue"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
) -> "JobQueue":
backend = batch_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
@ -206,25 +216,24 @@ class JobQueue(CloudFormationModel):
class JobDefinition(CloudFormationModel):
def __init__(
self,
name,
parameters,
_type,
container_properties,
tags=None,
revision=0,
retry_strategy=0,
timeout=None,
backend=None,
platform_capabilities=None,
propagate_tags=None,
name: str,
parameters: Optional[Dict[str, Any]],
_type: str,
container_properties: Dict[str, Any],
tags: Dict[str, str],
retry_strategy: Dict[str, str],
timeout: Dict[str, int],
backend: "BatchBackend",
platform_capabilities: List[str],
propagate_tags: bool,
revision: Optional[int] = 0,
):
self.name = name
self.retry_strategy = retry_strategy
self.type = _type
self.revision = revision
self.revision = revision or 0
self._region = backend.region_name
self.container_properties = container_properties
self.arn = None
self.status = "ACTIVE"
self.parameters = parameters or {}
self.timeout = timeout
@ -238,26 +247,23 @@ class JobDefinition(CloudFormationModel):
self.container_properties["secrets"] = []
self._validate()
self._update_arn()
tags = self._format_tags(tags or {})
# Validate the tags before proceeding.
errmsg = self.backend.tagger.validate_tags(tags)
if errmsg:
raise ValidationError(errmsg)
self.backend.tagger.tag_resource(self.arn, tags)
def _format_tags(self, tags):
return [{"Key": k, "Value": v} for k, v in tags.items()]
def _update_arn(self):
self.revision += 1
self.arn = make_arn_for_task_def(
self.backend.account_id, self.name, self.revision, self._region
)
def _get_resource_requirement(self, req_type, default=None):
tag_list = self._format_tags(tags or {})
# Validate the tags before proceeding.
errmsg = self.backend.tagger.validate_tags(tag_list)
if errmsg:
raise ValidationError(errmsg)
self.backend.tagger.tag_resource(self.arn, tag_list)
def _format_tags(self, tags: Dict[str, str]) -> List[Dict[str, str]]:
return [{"Key": k, "Value": v} for k, v in tags.items()]
def _get_resource_requirement(self, req_type: str, default: Any = None) -> Any:
"""
Get resource requirement from container properties.
@ -297,7 +303,7 @@ class JobDefinition(CloudFormationModel):
else:
return self.container_properties.get(req_type, default)
def _validate(self):
def _validate(self) -> None:
# For future use when containers aren't the only thing in batch
if self.type not in ("container",):
raise ClientException('type must be one of "container"')
@ -320,12 +326,18 @@ class JobDefinition(CloudFormationModel):
if vcpus <= 0:
raise ClientException("container vcpus limit must be greater than 0")
def deregister(self):
def deregister(self) -> None:
self.status = "INACTIVE"
def update(
self, parameters, _type, container_properties, retry_strategy, tags, timeout
):
self,
parameters: Optional[Dict[str, Any]],
_type: str,
container_properties: Dict[str, Any],
retry_strategy: Dict[str, Any],
tags: Dict[str, str],
timeout: Dict[str, int],
) -> "JobDefinition":
if self.status != "INACTIVE":
if parameters is None:
parameters = self.parameters
@ -353,7 +365,7 @@ class JobDefinition(CloudFormationModel):
propagate_tags=self.propagate_tags,
)
def describe(self):
def describe(self) -> Dict[str, Any]:
result = {
"jobDefinitionArn": self.arn,
"jobDefinitionName": self.name,
@ -374,22 +386,27 @@ class JobDefinition(CloudFormationModel):
return result
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.arn
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "JobDefinitionName"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobdefinition.html
return "AWS::Batch::JobDefinition"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
) -> "JobDefinition":
backend = batch_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
res = backend.register_job_definition(
@ -411,25 +428,15 @@ class JobDefinition(CloudFormationModel):
class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
def __init__(
self,
name,
job_def,
job_queue,
log_backend,
container_overrides,
depends_on,
all_jobs,
timeout,
name: str,
job_def: JobDefinition,
job_queue: JobQueue,
log_backend: LogsBackend,
container_overrides: Optional[Dict[str, Any]],
depends_on: Optional[List[Dict[str, str]]],
all_jobs: Dict[str, "Job"],
timeout: Optional[Dict[str, int]],
):
"""
Docker Job
:param name: Job Name
:param job_def: Job definition
:type: job_def: JobDefinition
:param job_queue: Job Queue
:param log_backend: Log backend
:type log_backend: moto.logs.models.LogsBackend
"""
threading.Thread.__init__(self)
DockerModel.__init__(self)
ManagedState.__init__(
@ -446,32 +453,32 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.job_name = name
self.job_id = str(mock_random.uuid4())
self.job_definition = job_def
self.container_overrides = container_overrides or {}
self.container_overrides: Dict[str, Any] = container_overrides or {}
self.job_queue = job_queue
self.job_queue.jobs.append(self)
self.job_created_at = datetime.datetime.now()
self.job_started_at = datetime.datetime(1970, 1, 1)
self.job_stopped_at = datetime.datetime(1970, 1, 1)
self.job_stopped = False
self.job_stopped_reason = None
self.job_stopped_reason: Optional[str] = None
self.depends_on = depends_on
self.timeout = timeout
self.all_jobs = all_jobs
self.stop = False
self.exit_code = None
self.exit_code: Optional[int] = None
self.daemon = True
self.name = "MOTO-BATCH-" + self.job_id
self._log_backend = log_backend
self.log_stream_name = None
self.log_stream_name: Optional[str] = None
self.container_details = {}
self.attempts = []
self.latest_attempt = None
self.container_details: Dict[str, Any] = {}
self.attempts: List[Dict[str, Any]] = []
self.latest_attempt: Optional[Dict[str, Any]] = None
def describe_short(self):
def describe_short(self) -> Dict[str, Any]:
result = {
"jobId": self.job_id,
"jobName": self.job_name,
@ -489,10 +496,10 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
result["container"] = {"exitCode": self.exit_code}
return result
def describe(self):
def describe(self) -> Dict[str, Any]:
result = self.describe_short()
result["jobQueue"] = self.job_queue.arn
result["dependsOn"] = self.depends_on if self.depends_on else []
result["dependsOn"] = self.depends_on or []
result["container"] = self.container_details
if self.job_stopped:
result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at)
@ -501,7 +508,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
result["attempts"] = self.attempts
return result
def _get_container_property(self, p, default):
def _get_container_property(self, p: str, default: Any) -> Any:
if p == "environment":
job_env = self.container_overrides.get(p, default)
jd_env = self.job_definition.container_properties.get(p, default)
@ -526,14 +533,14 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
p, self.job_definition.container_properties.get(p, default)
)
def _get_attempt_duration(self):
def _get_attempt_duration(self) -> Optional[int]:
if self.timeout:
return self.timeout["attemptDurationSeconds"]
if self.job_definition.timeout:
return self.job_definition.timeout["attemptDurationSeconds"]
return None
def run(self):
def run(self) -> None:
"""
Run the container.
@ -672,7 +679,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
if self._get_attempt_duration():
attempt_duration = self._get_attempt_duration()
max_time = self.job_started_at + datetime.timedelta(
seconds=attempt_duration
seconds=attempt_duration # type: ignore[arg-type]
)
while container.status == "running" and not self.stop:
@ -740,8 +747,9 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.container_details["logStreamName"] = self.log_stream_name
result = container.wait() or {}
self.exit_code = result.get("StatusCode", 0)
job_failed = self.stop or self.exit_code > 0
exit_code = result.get("StatusCode", 0)
self.exit_code = exit_code
job_failed = self.stop or exit_code > 0
self._mark_stopped(success=not job_failed)
except Exception as err:
@ -762,7 +770,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
)
self._mark_stopped(success=False)
def _mark_stopped(self, success=True):
def _mark_stopped(self, success: bool = True) -> None:
# Ensure that job_stopped/job_stopped_at-attributes are set first
# The describe-method needs them immediately when status is set
self.job_stopped = True
@ -770,7 +778,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.status = "SUCCEEDED" if success else "FAILED"
self._stop_attempt()
def _start_attempt(self):
def _start_attempt(self) -> None:
self.latest_attempt = {
"container": {
"containerInstanceArn": "TBD",
@ -784,21 +792,21 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
)
self.attempts.append(self.latest_attempt)
def _stop_attempt(self):
def _stop_attempt(self) -> None:
if self.latest_attempt:
self.latest_attempt["container"]["logStreamName"] = self.log_stream_name
self.latest_attempt["stoppedAt"] = datetime2int_milliseconds(
self.job_stopped_at
)
def terminate(self, reason):
def terminate(self, reason: str) -> None:
if not self.stop:
self.stop = True
self.job_stopped_reason = reason
def _wait_for_dependencies(self):
dependent_ids = [dependency["jobId"] for dependency in self.depends_on]
successful_dependencies = set()
def _wait_for_dependencies(self) -> bool:
dependent_ids = [dependency["jobId"] for dependency in self.depends_on] # type: ignore[union-attr]
successful_dependencies: Set[str] = set()
while len(successful_dependencies) != len(dependent_ids):
for dependent_id in dependent_ids:
if dependent_id in self.all_jobs:
@ -832,21 +840,21 @@ class BatchBackend(BaseBackend):
With this decorator, jobs are simply marked as 'Success' without trying to execute any commands/scripts.
"""
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.tagger = TaggingService()
self._compute_environments = {}
self._job_queues = {}
self._job_definitions = {}
self._jobs = {}
self._compute_environments: Dict[str, ComputeEnvironment] = {}
self._job_queues: Dict[str, JobQueue] = {}
self._job_definitions: Dict[str, JobDefinition] = {}
self._jobs: Dict[str, Job] = {}
state_manager.register_default_transition(
"batch::job", transition={"progression": "manual", "times": 1}
)
@property
def iam_backend(self):
def iam_backend(self) -> IAMBackend:
"""
:return: IAM Backend
:rtype: moto.iam.models.IAMBackend
@ -854,7 +862,7 @@ class BatchBackend(BaseBackend):
return iam_backends[self.account_id]["global"]
@property
def ec2_backend(self):
def ec2_backend(self) -> EC2Backend:
"""
:return: EC2 Backend
:rtype: moto.ec2.models.EC2Backend
@ -862,7 +870,7 @@ class BatchBackend(BaseBackend):
return ec2_backends[self.account_id][self.region_name]
@property
def ecs_backend(self):
def ecs_backend(self) -> EC2ContainerServiceBackend:
"""
:return: ECS Backend
:rtype: moto.ecs.models.EC2ContainerServiceBackend
@ -870,14 +878,14 @@ class BatchBackend(BaseBackend):
return ecs_backends[self.account_id][self.region_name]
@property
def logs_backend(self):
def logs_backend(self) -> LogsBackend:
"""
:return: Logs Backend
:rtype: moto.logs.models.LogsBackend
"""
return logs_backends[self.account_id][self.region_name]
def reset(self):
def reset(self) -> None:
for job in self._jobs.values():
if job.status not in ("FAILED", "SUCCEEDED"):
job.stop = True
@ -886,16 +894,18 @@ class BatchBackend(BaseBackend):
super().reset()
def get_compute_environment_by_arn(self, arn):
def get_compute_environment_by_arn(self, arn: str) -> Optional[ComputeEnvironment]:
return self._compute_environments.get(arn)
def get_compute_environment_by_name(self, name):
def get_compute_environment_by_name(
self, name: str
) -> Optional[ComputeEnvironment]:
for comp_env in self._compute_environments.values():
if comp_env.name == name:
return comp_env
return None
def get_compute_environment(self, identifier):
def get_compute_environment(self, identifier: str) -> Optional[ComputeEnvironment]:
"""
Get compute environment by name or ARN
:param identifier: Name or ARN
@ -904,21 +914,20 @@ class BatchBackend(BaseBackend):
:return: Compute Environment or None
:rtype: ComputeEnvironment or None
"""
env = self.get_compute_environment_by_arn(identifier)
if env is None:
env = self.get_compute_environment_by_name(identifier)
return env
return self.get_compute_environment_by_arn(
identifier
) or self.get_compute_environment_by_name(identifier)
def get_job_queue_by_arn(self, arn):
def get_job_queue_by_arn(self, arn: str) -> Optional[JobQueue]:
return self._job_queues.get(arn)
def get_job_queue_by_name(self, name):
def get_job_queue_by_name(self, name: str) -> Optional[JobQueue]:
for comp_env in self._job_queues.values():
if comp_env.name == name:
return comp_env
return None
def get_job_queue(self, identifier):
def get_job_queue(self, identifier: str) -> Optional[JobQueue]:
"""
Get job queue by name or ARN
:param identifier: Name or ARN
@ -927,15 +936,14 @@ class BatchBackend(BaseBackend):
:return: Job Queue or None
:rtype: JobQueue or None
"""
env = self.get_job_queue_by_arn(identifier)
if env is None:
env = self.get_job_queue_by_name(identifier)
return env
return self.get_job_queue_by_arn(identifier) or self.get_job_queue_by_name(
identifier
)
def get_job_definition_by_arn(self, arn):
def get_job_definition_by_arn(self, arn: str) -> Optional[JobDefinition]:
return self._job_definitions.get(arn)
def get_job_definition_by_name(self, name):
def get_job_definition_by_name(self, name: str) -> Optional[JobDefinition]:
latest_revision = -1
latest_job = None
for job_def in self._job_definitions.values():
@ -944,13 +952,15 @@ class BatchBackend(BaseBackend):
latest_revision = job_def.revision
return latest_job
def get_job_definition_by_name_revision(self, name, revision):
def get_job_definition_by_name_revision(
self, name: str, revision: str
) -> Optional[JobDefinition]:
for job_def in self._job_definitions.values():
if job_def.name == name and job_def.revision == int(revision):
return job_def
return None
def get_job_definition(self, identifier):
def get_job_definition(self, identifier: str) -> Optional[JobDefinition]:
"""
Get job definitions by name or ARN
:param identifier: Name or ARN
@ -969,7 +979,7 @@ class BatchBackend(BaseBackend):
job_def = self.get_job_definition_by_name(identifier)
return job_def
def get_job_definitions(self, identifier):
def get_job_definitions(self, identifier: str) -> List[JobDefinition]:
"""
Get job definitions by name or ARN
:param identifier: Name or ARN
@ -989,21 +999,15 @@ class BatchBackend(BaseBackend):
return result
def get_job_by_id(self, identifier):
"""
Get job by id
:param identifier: Job ID
:type identifier: str
:return: Job
:rtype: Job
"""
def get_job_by_id(self, identifier: str) -> Optional[Job]:
try:
return self._jobs[identifier]
except KeyError:
return None
def describe_compute_environments(self, environments=None):
def describe_compute_environments(
self, environments: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""
Pagination is not yet implemented
"""
@ -1017,7 +1021,7 @@ class BatchBackend(BaseBackend):
if len(envs) > 0 and arn not in envs and environment.name not in envs:
continue
json_part = {
json_part: Dict[str, Any] = {
"computeEnvironmentArn": arn,
"computeEnvironmentName": environment.name,
"ecsClusterArn": environment.ecs_arn,
@ -1035,8 +1039,13 @@ class BatchBackend(BaseBackend):
return result
def create_compute_environment(
self, compute_environment_name, _type, state, compute_resources, service_role
):
self,
compute_environment_name: str,
_type: str,
state: str,
compute_resources: Dict[str, Any],
service_role: str,
) -> Tuple[str, str]:
# Validate
if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None:
raise InvalidParameterValueException(
@ -1127,7 +1136,7 @@ class BatchBackend(BaseBackend):
return compute_environment_name, new_comp_env.arn
def _validate_compute_resources(self, cr):
def _validate_compute_resources(self, cr: Dict[str, Any]) -> None:
"""
Checks contents of sub dictionary for managed clusters
@ -1195,16 +1204,15 @@ class BatchBackend(BaseBackend):
)
@staticmethod
def find_min_instances_to_meet_vcpus(instance_types, target):
def find_min_instances_to_meet_vcpus(
instance_types: List[str], target: float
) -> List[str]:
"""
Finds the minimum needed instances to meet a vCPU target
:param instance_types: Instance types, like ['t2.medium', 't2.small']
:type instance_types: list of str
:param target: VCPU target
:type target: float
:return: List of instance types
:rtype: list of str
"""
# vcpus = [ (vcpus, instance_type), (vcpus, instance_type), ... ]
instance_vcpus = []
@ -1253,7 +1261,7 @@ class BatchBackend(BaseBackend):
return instances
def delete_compute_environment(self, compute_environment_name):
def delete_compute_environment(self, compute_environment_name: str) -> None:
if compute_environment_name is None:
raise InvalidParameterValueException("Missing computeEnvironment parameter")
@ -1273,8 +1281,12 @@ class BatchBackend(BaseBackend):
self.ec2_backend.terminate_instances(instance_ids)
def update_compute_environment(
self, compute_environment_name, state, compute_resources, service_role
):
self,
compute_environment_name: str,
state: Optional[str],
compute_resources: Optional[Any],
service_role: Optional[str],
) -> Tuple[str, str]:
# Validate
compute_env = self.get_compute_environment(compute_environment_name)
if compute_env is None:
@ -1283,13 +1295,13 @@ class BatchBackend(BaseBackend):
# Look for IAM role
if service_role is not None:
try:
role = self.iam_backend.get_role_by_arn(service_role)
self.iam_backend.get_role_by_arn(service_role)
except IAMNotFoundException:
raise InvalidParameterValueException(
"Could not find IAM role {0}".format(service_role)
)
compute_env.service_role = role
compute_env.service_role = service_role
if state is not None:
if state not in ("ENABLED", "DISABLED"):
@ -1307,8 +1319,13 @@ class BatchBackend(BaseBackend):
return compute_env.name, compute_env.arn
def create_job_queue(
self, queue_name, priority, state, compute_env_order, tags=None
):
self,
queue_name: str,
priority: str,
state: str,
compute_env_order: List[Dict[str, str]],
tags: Optional[Dict[str, str]] = None,
) -> Tuple[str, str]:
for variable, var_name in (
(queue_name, "jobQueueName"),
(priority, "priority"),
@ -1359,7 +1376,9 @@ class BatchBackend(BaseBackend):
return queue_name, queue.arn
def describe_job_queues(self, job_queues=None):
def describe_job_queues(
self, job_queues: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
"""
Pagination is not yet implemented
"""
@ -1377,7 +1396,13 @@ class BatchBackend(BaseBackend):
return result
def update_job_queue(self, queue_name, priority, state, compute_env_order):
def update_job_queue(
self,
queue_name: str,
priority: Optional[str],
state: Optional[str],
compute_env_order: Optional[List[Dict[str, Any]]],
) -> Tuple[str, str]:
if queue_name is None:
raise ClientException("jobQueueName must be provided")
@ -1422,7 +1447,7 @@ class BatchBackend(BaseBackend):
return queue_name, job_queue.arn
def delete_job_queue(self, queue_name):
def delete_job_queue(self, queue_name: str) -> None:
job_queue = self.get_job_queue(queue_name)
if job_queue is not None:
@ -1430,16 +1455,16 @@ class BatchBackend(BaseBackend):
def register_job_definition(
self,
def_name,
parameters,
_type,
tags,
retry_strategy,
container_properties,
timeout,
platform_capabilities,
propagate_tags,
):
def_name: str,
parameters: Dict[str, Any],
_type: str,
tags: Dict[str, str],
retry_strategy: Dict[str, Any],
container_properties: Dict[str, Any],
timeout: Dict[str, int],
platform_capabilities: List[str],
propagate_tags: bool,
) -> Tuple[str, str, int]:
if def_name is None:
raise ClientException("jobDefinitionName must be provided")
@ -1473,7 +1498,7 @@ class BatchBackend(BaseBackend):
return def_name, job_def.arn, job_def.revision
def deregister_job_definition(self, def_name):
def deregister_job_definition(self, def_name: str) -> None:
job_def = self.get_job_definition_by_arn(def_name)
if job_def is None and ":" in def_name:
name, revision = def_name.split(":", 1)
@ -1483,8 +1508,11 @@ class BatchBackend(BaseBackend):
self._job_definitions[job_def.arn].deregister()
def describe_job_definitions(
self, job_def_name=None, job_def_list=None, status=None
):
self,
job_def_name: Optional[str] = None,
job_def_list: Optional[List[str]] = None,
status: Optional[str] = None,
) -> List[JobDefinition]:
"""
Pagination is not yet implemented
"""
@ -1496,8 +1524,8 @@ class BatchBackend(BaseBackend):
if job_def is not None:
jobs.extend(job_def)
elif job_def_list is not None:
for job in job_def_list:
job_def = self.get_job_definitions(job)
for jdn in job_def_list:
job_def = self.get_job_definitions(jdn)
if job_def is not None:
jobs.extend(job_def)
else:
@ -1512,13 +1540,13 @@ class BatchBackend(BaseBackend):
def submit_job(
self,
job_name,
job_def_id,
job_queue,
depends_on=None,
container_overrides=None,
timeout=None,
):
job_name: str,
job_def_id: str,
job_queue: str,
depends_on: Optional[List[Dict[str, str]]] = None,
container_overrides: Optional[Dict[str, Any]] = None,
timeout: Optional[Dict[str, int]] = None,
) -> Tuple[str, str]:
"""
Parameters RetryStrategy and Parameters are not yet implemented.
"""
@ -1550,7 +1578,7 @@ class BatchBackend(BaseBackend):
return job_name, job.job_id
def describe_jobs(self, jobs):
def describe_jobs(self, jobs: Optional[List[str]]) -> List[Dict[str, Any]]:
job_filter = set()
if jobs is not None:
job_filter = set(jobs)
@ -1564,13 +1592,15 @@ class BatchBackend(BaseBackend):
return result
def list_jobs(self, job_queue, job_status=None):
def list_jobs(
self, job_queue_name: str, job_status: Optional[str] = None
) -> List[Job]:
"""
Pagination is not yet implemented
"""
jobs = []
job_queue = self.get_job_queue(job_queue)
job_queue = self.get_job_queue(job_queue_name)
if job_queue is None:
raise ClientException("Job queue {0} does not exist".format(job_queue))
@ -1595,7 +1625,7 @@ class BatchBackend(BaseBackend):
return jobs
def cancel_job(self, job_id, reason):
def cancel_job(self, job_id: str, reason: str) -> None:
if job_id == "":
raise ClientException(
"'reason' is a required field (cannot be an empty string)"
@ -1611,7 +1641,7 @@ class BatchBackend(BaseBackend):
job.terminate(reason)
# No-Op for jobs that have already started - user has to explicitly terminate those
def terminate_job(self, job_id, reason):
def terminate_job(self, job_id: str, reason: str) -> None:
if job_id == "":
raise ClientException(
"'reason' is a required field (cannot be a empty string)"
@ -1625,14 +1655,14 @@ class BatchBackend(BaseBackend):
if job is not None:
job.terminate(reason)
def tag_resource(self, resource_arn, tags):
tags = self.tagger.convert_dict_to_tags_input(tags or {})
self.tagger.tag_resource(resource_arn, tags)
def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None:
tag_list = self.tagger.convert_dict_to_tags_input(tags or {})
self.tagger.tag_resource(resource_arn, tag_list)
def list_tags_for_resource(self, resource_arn):
def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]:
return self.tagger.get_tag_dict_for_resource(resource_arn)
def untag_resource(self, resource_arn, tag_keys):
def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagger.untag_resource_using_names(resource_arn, tag_keys)

View File

@ -1,45 +1,28 @@
from moto.core.responses import BaseResponse
from .models import batch_backends
from .models import batch_backends, BatchBackend
from urllib.parse import urlsplit, unquote
import json
class BatchResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="batch")
def _error(self, code, message):
return json.dumps({"__type": code, "message": message}), dict(status=400)
@property
def batch_backend(self):
def batch_backend(self) -> BatchBackend:
"""
:return: Batch Backend
:rtype: moto.batch.models.BatchBackend
"""
return batch_backends[self.current_account][self.region]
@property
def json(self):
if self.body is None or self.body == "":
self._json = {}
elif not hasattr(self, "_json"):
self._json = json.loads(self.body)
return self._json
def _get_param(self, param_name, if_none=None):
val = self.json.get(param_name)
if val is not None:
return val
return if_none
def _get_action(self):
def _get_action(self) -> str:
# Return element after the /v1/*
return urlsplit(self.uri).path.lstrip("/").split("/")[1]
# CreateComputeEnvironment
def createcomputeenvironment(self):
def createcomputeenvironment(self) -> str:
compute_env_name = self._get_param("computeEnvironmentName")
compute_resource = self._get_param("computeResources")
service_role = self._get_param("serviceRole")
@ -59,7 +42,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# DescribeComputeEnvironments
def describecomputeenvironments(self):
def describecomputeenvironments(self) -> str:
compute_environments = self._get_param("computeEnvironments")
envs = self.batch_backend.describe_compute_environments(compute_environments)
@ -68,7 +51,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# DeleteComputeEnvironment
def deletecomputeenvironment(self):
def deletecomputeenvironment(self) -> str:
compute_environment = self._get_param("computeEnvironment")
self.batch_backend.delete_compute_environment(compute_environment)
@ -76,7 +59,7 @@ class BatchResponse(BaseResponse):
return ""
# UpdateComputeEnvironment
def updatecomputeenvironment(self):
def updatecomputeenvironment(self) -> str:
compute_env_name = self._get_param("computeEnvironment")
compute_resource = self._get_param("computeResources")
service_role = self._get_param("serviceRole")
@ -94,7 +77,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# CreateJobQueue
def createjobqueue(self):
def createjobqueue(self) -> str:
compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueueName")
priority = self._get_param("priority")
@ -114,7 +97,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# DescribeJobQueues
def describejobqueues(self):
def describejobqueues(self) -> str:
job_queues = self._get_param("jobQueues")
queues = self.batch_backend.describe_job_queues(job_queues)
@ -123,7 +106,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# UpdateJobQueue
def updatejobqueue(self):
def updatejobqueue(self) -> str:
compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueue")
priority = self._get_param("priority")
@ -141,7 +124,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# DeleteJobQueue
def deletejobqueue(self):
def deletejobqueue(self) -> str:
queue_name = self._get_param("jobQueue")
self.batch_backend.delete_job_queue(queue_name)
@ -149,7 +132,7 @@ class BatchResponse(BaseResponse):
return ""
# RegisterJobDefinition
def registerjobdefinition(self):
def registerjobdefinition(self) -> str:
container_properties = self._get_param("containerProperties")
def_name = self._get_param("jobDefinitionName")
parameters = self._get_param("parameters")
@ -180,7 +163,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# DeregisterJobDefinition
def deregisterjobdefinition(self):
def deregisterjobdefinition(self) -> str:
queue_name = self._get_param("jobDefinition")
self.batch_backend.deregister_job_definition(queue_name)
@ -188,7 +171,7 @@ class BatchResponse(BaseResponse):
return ""
# DescribeJobDefinitions
def describejobdefinitions(self):
def describejobdefinitions(self) -> str:
job_def_name = self._get_param("jobDefinitionName")
job_def_list = self._get_param("jobDefinitions")
status = self._get_param("status")
@ -201,7 +184,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# SubmitJob
def submitjob(self):
def submitjob(self) -> str:
container_overrides = self._get_param("containerOverrides")
depends_on = self._get_param("dependsOn")
job_def = self._get_param("jobDefinition")
@ -223,13 +206,13 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# DescribeJobs
def describejobs(self) -> str:
    """DescribeJobs: return the details of the requested job ids as JSON."""
    job_ids = self._get_param("jobs")
    described = self.batch_backend.describe_jobs(job_ids)
    return json.dumps({"jobs": described})
# ListJobs
def listjobs(self):
def listjobs(self) -> str:
job_queue = self._get_param("jobQueue")
job_status = self._get_param("jobStatus")
@ -239,7 +222,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result)
# TerminateJob
def terminatejob(self):
def terminatejob(self) -> str:
job_id = self._get_param("jobId")
reason = self._get_param("reason")
@ -248,22 +231,22 @@ class BatchResponse(BaseResponse):
return ""
# CancelJob
def canceljob(self) -> str:
    """CancelJob: cancel the given job id with the supplied reason."""
    target_job = self._get_param("jobId")
    cancel_reason = self._get_param("reason")
    self.batch_backend.cancel_job(target_job, cancel_reason)
    return ""
def tags(self) -> str:
    """Dispatch tagging requests on /v1/tags/{arn} by HTTP verb.

    POST   -> TagResource
    GET    -> ListTagsForResource
    DELETE -> UntagResource
    """
    resource_arn = unquote(self.path).split("/v1/tags/")[-1]
    requested_tags = self._get_param("tags")
    if self.method == "POST":
        self.batch_backend.tag_resource(resource_arn, requested_tags)
        return ""
    if self.method == "GET":
        found = self.batch_backend.list_tags_for_resource(resource_arn)
        return json.dumps({"tags": found})
    if self.method == "DELETE":
        keys = self.querystring.get("tagKeys")
        self.batch_backend.untag_resource(resource_arn, keys)  # type: ignore[arg-type]
        return ""

View File

@ -1,21 +1,26 @@
def make_arn_for_compute_env(account_id, name, region_name):
from typing import Any, Dict
def make_arn_for_compute_env(account_id: str, name: str, region_name: str) -> str:
    """Build the ARN for an AWS Batch compute environment.

    :param account_id: 12-digit AWS account id
    :param name: compute environment name
    :param region_name: AWS region, e.g. ``us-east-1``
    """
    # f-string is the modern, readable equivalent of the old str.format call
    return f"arn:aws:batch:{region_name}:{account_id}:compute-environment/{name}"
def make_arn_for_job_queue(account_id, name, region_name):
def make_arn_for_job_queue(account_id: str, name: str, region_name: str) -> str:
    """Build the ARN for an AWS Batch job queue.

    :param account_id: 12-digit AWS account id
    :param name: job queue name
    :param region_name: AWS region, e.g. ``us-east-1``
    """
    # f-string is the modern, readable equivalent of the old str.format call
    return f"arn:aws:batch:{region_name}:{account_id}:job-queue/{name}"
def make_arn_for_task_def(account_id, name, revision, region_name):
def make_arn_for_task_def(
    account_id: str, name: str, revision: int, region_name: str
) -> str:
    """Build the ARN for an AWS Batch job definition, including its revision.

    :param account_id: 12-digit AWS account id
    :param name: job definition name
    :param revision: integer revision number appended after a colon
    :param region_name: AWS region, e.g. ``us-east-1``
    """
    # f-string is the modern, readable equivalent of the old str.format call
    return f"arn:aws:batch:{region_name}:{account_id}:job-definition/{name}:{revision}"
def lowercase_first_key(some_dict):
new_dict = {}
def lowercase_first_key(some_dict: Dict[str, Any]) -> Dict[str, Any]:
new_dict: Dict[str, Any] = {}
for key, value in some_dict.items():
new_key = key[0].lower() + key[1:]
try:

View File

@ -1,7 +1,14 @@
from ..batch.models import batch_backends, BaseBackend, Job, ClientException
from ..batch.models import (
batch_backends,
BaseBackend,
Job,
ClientException,
BatchBackend,
)
from ..core.utils import BackendDict
import datetime
from typing import Any, Dict, List, Tuple, Optional
class BatchSimpleBackend(BaseBackend):
@ -11,10 +18,10 @@ class BatchSimpleBackend(BaseBackend):
"""
@property
def backend(self):
def backend(self) -> BatchBackend:
    """The real BatchBackend this wrapper delegates to, for the current account/region."""
    per_account = batch_backends[self.account_id]
    return per_account[self.region_name]
def __getattribute__(self, name):
def __getattribute__(self, name: str) -> Any:
"""
Magic part that makes this class behave like a wrapper around the regular batch_backend
We intercept calls to `submit_job` and replace this with our own (non-Docker) implementation
@ -32,7 +39,7 @@ class BatchSimpleBackend(BaseBackend):
return object.__getattribute__(self, name)
if name in ["submit_job"]:
def newfunc(*args, **kwargs):
def newfunc(*args: Any, **kwargs: Any) -> Any:
attr = object.__getattribute__(self, name)
return attr(*args, **kwargs)
@ -42,13 +49,13 @@ class BatchSimpleBackend(BaseBackend):
def submit_job(
self,
job_name,
job_def_id,
job_queue,
depends_on=None,
container_overrides=None,
timeout=None,
):
job_name: str,
job_def_id: str,
job_queue: str,
depends_on: Optional[List[Dict[str, str]]] = None,
container_overrides: Optional[Dict[str, Any]] = None,
timeout: Optional[Dict[str, int]] = None,
) -> Tuple[str, str]:
# Look for job definition
job_def = self.get_job_definition(job_def_id)
if job_def is None:

View File

@ -1,10 +1,10 @@
from ..batch.responses import BatchResponse
from .models import batch_simple_backends
from .models import batch_simple_backends, BatchBackend
class BatchSimpleResponse(BatchResponse):
@property
def batch_backend(self):
def batch_backend(self) -> BatchBackend:
"""
:return: Batch Backend
:rtype: moto.batch.models.BatchBackend

View File

@ -32,7 +32,7 @@ class BaseBackend:
for model in models.values():
model.instances = []
def reset(self):
def reset(self) -> None:
region_name = self.region_name
account_id = self.account_id
self._reset_model_refs()

View File

@ -2,7 +2,7 @@ import copy
import warnings
from collections import OrderedDict
from datetime import datetime
from typing import Any, List, Tuple
from typing import Any, List, Tuple, Optional
from moto import settings
from moto.core import CloudFormationModel
@ -597,7 +597,7 @@ class InstanceBackend:
self,
image_id: str,
count: int,
user_data: str,
user_data: Optional[str],
security_group_names: List[str],
**kwargs: Any
) -> Reservation:

View File

@ -2,7 +2,7 @@ import copy
import itertools
import json
from collections import defaultdict
from typing import Optional
from moto.core import CloudFormationModel
from moto.core.utils import aws_api_matches
from ..exceptions import (
@ -543,7 +543,7 @@ class SecurityGroupBackend:
return self._delete_security_group(None, group.id)
raise InvalidSecurityGroupNotFoundError(name)
def get_security_group_from_id(self, group_id):
def get_security_group_from_id(self, group_id: str) -> Optional[SecurityGroup]:
# 2 levels of chaining necessary since it's a complex structure
all_groups = itertools.chain.from_iterable(
[x.copy().values() for x in self.groups.copy().values()]

View File

@ -229,7 +229,7 @@ class SubnetBackend:
# maps availability zone to dict of (subnet_id, subnet)
self.subnets = defaultdict(dict)
def get_subnet(self, subnet_id):
def get_subnet(self, subnet_id: str) -> Subnet:
for subnets in self.subnets.values():
if subnet_id in subnets:
return subnets[subnet_id]

View File

@ -1,7 +1,7 @@
import re
from copy import copy
from datetime import datetime
from typing import Any
import pytz
from moto import settings
@ -850,7 +850,9 @@ class EC2ContainerServiceBackend(BaseBackend):
else:
raise Exception("{0} is not a task_definition".format(task_definition_name))
def create_cluster(self, cluster_name, tags=None, cluster_settings=None):
def create_cluster(
self, cluster_name: str, tags: Any = None, cluster_settings: Any = None
) -> Cluster:
"""
The following parameters are not yet implemented: configuration, capacityProviders, defaultCapacityProviderStrategy
"""
@ -926,7 +928,7 @@ class EC2ContainerServiceBackend(BaseBackend):
return list_clusters, failures
def delete_cluster(self, cluster_str):
def delete_cluster(self, cluster_str: str) -> Cluster:
    """Remove the cluster identified by name or ARN and return the removed object."""
    target = self._get_cluster(cluster_str)
    return self.clusters.pop(target.name)

View File

@ -10,7 +10,7 @@ from cryptography import x509
from cryptography.hazmat.backends import default_backend
from jinja2 import Template
from typing import Mapping
from typing import List, Mapping
from urllib import parse
from moto.core.exceptions import RESTError
from moto.core import DEFAULT_ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
@ -2188,7 +2188,7 @@ class IAMBackend(BaseBackend):
raise IAMNotFoundException("Instance profile {0} not found".format(profile_arn))
def get_instance_profiles(self):
def get_instance_profiles(self) -> List["InstanceProfile"]:
    """Return all known instance profiles.

    The annotation promises a list, but the original returned a
    ``dict_values`` view (mypy: incompatible return type); materialize
    a real list so the declared contract holds.  Forward reference is
    quoted so the annotation never needs InstanceProfile at def time.
    """
    return list(self.instance_profiles.values())
def get_instance_profiles_for_role(self, role_name):

View File

@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List, Tuple, Optional
from moto.core import BaseBackend, BaseModel
from moto.core import CloudFormationModel
from moto.core.utils import unix_time_millis, BackendDict
@ -627,7 +627,9 @@ class LogsBackend(BaseBackend):
)
return self.groups[log_group_name]
def ensure_log_group(self, log_group_name, tags):
def ensure_log_group(
self, log_group_name: str, tags: Optional[Dict[str, str]]
) -> None:
if log_group_name in self.groups:
return
self.groups[log_group_name] = LogGroup(
@ -653,7 +655,7 @@ class LogsBackend(BaseBackend):
return groups
def create_log_stream(self, log_group_name, log_stream_name):
def create_log_stream(self, log_group_name: str, log_stream_name: str) -> LogStream:
if log_group_name not in self.groups:
raise ResourceNotFoundException()
log_group = self.groups[log_group_name]
@ -702,7 +704,12 @@ class LogsBackend(BaseBackend):
order_by=order_by,
)
def put_log_events(self, log_group_name, log_stream_name, log_events):
def put_log_events(
self,
log_group_name: str,
log_stream_name: str,
log_events: List[Dict[str, Any]],
) -> Tuple[str, Dict[str, Any]]:
"""
The SequenceToken-parameter is not yet implemented
"""

View File

@ -1,5 +1,6 @@
from datetime import datetime, timedelta
from moto.moto_api import state_manager
from typing import List, Tuple
class ManagedState:
@ -7,7 +8,7 @@ class ManagedState:
Subclass this class to configure state-transitions
"""
def __init__(self, model_name, transitions):
def __init__(self, model_name: str, transitions: List[Tuple[str, str]]):
# Indicate the possible transitions for this model
# Example: [(initializing,queued), (queued, starting), (starting, ready)]
self._transitions = transitions
@ -23,7 +24,7 @@ class ManagedState:
# Name of this model. This will be used in the API
self.model_name = model_name
def advance(self):
def advance(self) -> None:
    """Move this model one step along its configured state transitions."""
    self._tick = self._tick + 1
@property

View File

@ -1,3 +1,6 @@
from typing import Any, Dict
DEFAULT_TRANSITION = {"progression": "immediate"}
@ -6,7 +9,9 @@ class StateManager:
self._default_transitions = dict()
self._transitions = dict()
def register_default_transition(self, model_name, transition):
def register_default_transition(
self, model_name: str, transition: Dict[str, Any]
) -> None:
"""
Register the default transition for a specific model.
This should only be called by Moto backends - use the `set_transition` method to override this default transition in your own tests.

View File

@ -106,7 +106,7 @@ class TaggingService:
result[tag[self.key_name]] = None
return result
def validate_tags(self, tags, limit=0):
def validate_tags(self, tags: List[Dict[str, str]], limit: int = 0) -> str:
"""Returns error message if tags in 'tags' list of dicts are invalid.
The validation does not include a check for duplicate keys.

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/budgets
files= moto/a*,moto/b*
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract

View File

@ -323,6 +323,42 @@ def test_update_unmanaged_compute_environment_state():
our_envs[0]["state"].should.equal("DISABLED")
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_update_iam_role():
    # Verify that UpdateComputeEnvironment can change the serviceRole,
    # and that an unknown role ARN is rejected with InvalidParameterValue.
    ec2_client, iam_client, _, _, batch_client = _get_clients()
    _, _, _, iam_arn = _setup(ec2_client, iam_client)
    # Second role to switch the compute environment over to
    iam_arn2 = iam_client.create_role(RoleName="r", AssumeRolePolicyDocument="sp")[
        "Role"
    ]["Arn"]
    compute_name = str(uuid4())
    batch_client.create_compute_environment(
        computeEnvironmentName=compute_name,
        type="UNMANAGED",
        state="ENABLED",
        serviceRole=iam_arn,
    )
    # Switch the environment to the second role
    batch_client.update_compute_environment(
        computeEnvironment=compute_name, serviceRole=iam_arn2
    )
    # The described environment should now report the updated role
    all_envs = batch_client.describe_compute_environments()["computeEnvironments"]
    our_envs = [e for e in all_envs if e["computeEnvironmentName"] == compute_name]
    our_envs.should.have.length_of(1)
    our_envs[0]["serviceRole"].should.equal(iam_arn2)
    # Updating to a non-existent role must fail validation
    with pytest.raises(ClientError) as exc:
        batch_client.update_compute_environment(
            computeEnvironment=compute_name, serviceRole="unknown"
        )
    err = exc.value.response["Error"]
    err["Code"].should.equal("InvalidParameterValue")
@pytest.mark.parametrize("compute_env_type", ["FARGATE", "FARGATE_SPOT"])
@mock_ec2
@mock_ecs