TechDebt: MyPy Batch (#5592)

This commit is contained in:
Bert Blommers 2022-10-23 13:26:55 +00:00 committed by GitHub
parent a470ea748b
commit 6f3b250fc7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 365 additions and 289 deletions

View File

@ -1,6 +1,7 @@
import re import re
from itertools import cycle from itertools import cycle
from time import sleep from time import sleep
from typing import Any, Dict, List, Tuple, Optional, Set
import datetime import datetime
import time import time
import logging import logging
@ -9,10 +10,11 @@ import dateutil.parser
from sys import platform from sys import platform
from moto.core import BaseBackend, BaseModel, CloudFormationModel from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.iam import iam_backends from moto.iam.models import iam_backends, IAMBackend
from moto.ec2 import ec2_backends from moto.ec2.models import ec2_backends, EC2Backend
from moto.ecs import ecs_backends from moto.ec2.models.instances import Instance
from moto.logs import logs_backends from moto.ecs.models import ecs_backends, EC2ContainerServiceBackend
from moto.logs.models import logs_backends, LogsBackend
from moto.utilities.tagging_service import TaggingService from moto.utilities.tagging_service import TaggingService
from .exceptions import InvalidParameterValueException, ClientException, ValidationError from .exceptions import InvalidParameterValueException, ClientException, ValidationError
@ -39,7 +41,7 @@ COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(
) )
def datetime2int_milliseconds(date): def datetime2int_milliseconds(date: datetime.datetime) -> int:
""" """
AWS returns timestamps in milliseconds AWS returns timestamps in milliseconds
We don't use milliseconds timestamps internally, We don't use milliseconds timestamps internally,
@ -48,20 +50,20 @@ def datetime2int_milliseconds(date):
return int(date.timestamp() * 1000) return int(date.timestamp() * 1000)
def datetime2int(date): def datetime2int(date: datetime.datetime) -> int:
return int(time.mktime(date.timetuple())) return int(time.mktime(date.timetuple()))
class ComputeEnvironment(CloudFormationModel): class ComputeEnvironment(CloudFormationModel):
def __init__( def __init__(
self, self,
compute_environment_name, compute_environment_name: str,
_type, _type: str,
state, state: str,
compute_resources, compute_resources: Dict[str, Any],
service_role, service_role: str,
account_id, account_id: str,
region_name, region_name: str,
): ):
self.name = compute_environment_name self.name = compute_environment_name
self.env_type = _type self.env_type = _type
@ -72,34 +74,39 @@ class ComputeEnvironment(CloudFormationModel):
account_id, compute_environment_name, region_name account_id, compute_environment_name, region_name
) )
self.instances = [] self.instances: List[Instance] = []
self.ecs_arn = None self.ecs_arn = ""
self.ecs_name = None self.ecs_name = ""
def add_instance(self, instance): def add_instance(self, instance: Instance) -> None:
self.instances.append(instance) self.instances.append(instance)
def set_ecs(self, arn, name): def set_ecs(self, arn: str, name: str) -> None:
self.ecs_arn = arn self.ecs_arn = arn
self.ecs_name = name self.ecs_name = name
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.arn return self.arn
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return "ComputeEnvironmentName" return "ComputeEnvironmentName"
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-computeenvironment.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-computeenvironment.html
return "AWS::Batch::ComputeEnvironment" return "AWS::Batch::ComputeEnvironment"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
) -> "ComputeEnvironment":
backend = batch_backends[account_id][region_name] backend = batch_backends[account_id][region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -118,13 +125,13 @@ class ComputeEnvironment(CloudFormationModel):
class JobQueue(CloudFormationModel): class JobQueue(CloudFormationModel):
def __init__( def __init__(
self, self,
name, name: str,
priority, priority: str,
state, state: str,
environments, environments: List[ComputeEnvironment],
env_order_json, env_order_json: List[Dict[str, Any]],
backend, backend: "BatchBackend",
tags=None, tags: Optional[Dict[str, str]] = None,
): ):
""" """
:param name: Job queue name :param name: Job queue name
@ -150,10 +157,10 @@ class JobQueue(CloudFormationModel):
if tags: if tags:
backend.tag_resource(self.arn, tags) backend.tag_resource(self.arn, tags)
self.jobs = [] self.jobs: List[Job] = []
def describe(self): def describe(self) -> Dict[str, Any]:
result = { return {
"computeEnvironmentOrder": self.env_order_json, "computeEnvironmentOrder": self.env_order_json,
"jobQueueArn": self.arn, "jobQueueArn": self.arn,
"jobQueueName": self.name, "jobQueueName": self.name,
@ -163,25 +170,28 @@ class JobQueue(CloudFormationModel):
"tags": self.backend.list_tags_for_resource(self.arn), "tags": self.backend.list_tags_for_resource(self.arn),
} }
return result
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.arn return self.arn
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return "JobQueueName" return "JobQueueName"
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobqueue.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobqueue.html
return "AWS::Batch::JobQueue" return "AWS::Batch::JobQueue"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
) -> "JobQueue":
backend = batch_backends[account_id][region_name] backend = batch_backends[account_id][region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -206,25 +216,24 @@ class JobQueue(CloudFormationModel):
class JobDefinition(CloudFormationModel): class JobDefinition(CloudFormationModel):
def __init__( def __init__(
self, self,
name, name: str,
parameters, parameters: Optional[Dict[str, Any]],
_type, _type: str,
container_properties, container_properties: Dict[str, Any],
tags=None, tags: Dict[str, str],
revision=0, retry_strategy: Dict[str, str],
retry_strategy=0, timeout: Dict[str, int],
timeout=None, backend: "BatchBackend",
backend=None, platform_capabilities: List[str],
platform_capabilities=None, propagate_tags: bool,
propagate_tags=None, revision: Optional[int] = 0,
): ):
self.name = name self.name = name
self.retry_strategy = retry_strategy self.retry_strategy = retry_strategy
self.type = _type self.type = _type
self.revision = revision self.revision = revision or 0
self._region = backend.region_name self._region = backend.region_name
self.container_properties = container_properties self.container_properties = container_properties
self.arn = None
self.status = "ACTIVE" self.status = "ACTIVE"
self.parameters = parameters or {} self.parameters = parameters or {}
self.timeout = timeout self.timeout = timeout
@ -238,26 +247,23 @@ class JobDefinition(CloudFormationModel):
self.container_properties["secrets"] = [] self.container_properties["secrets"] = []
self._validate() self._validate()
self._update_arn()
tags = self._format_tags(tags or {})
# Validate the tags before proceeding.
errmsg = self.backend.tagger.validate_tags(tags)
if errmsg:
raise ValidationError(errmsg)
self.backend.tagger.tag_resource(self.arn, tags)
def _format_tags(self, tags):
return [{"Key": k, "Value": v} for k, v in tags.items()]
def _update_arn(self):
self.revision += 1 self.revision += 1
self.arn = make_arn_for_task_def( self.arn = make_arn_for_task_def(
self.backend.account_id, self.name, self.revision, self._region self.backend.account_id, self.name, self.revision, self._region
) )
def _get_resource_requirement(self, req_type, default=None): tag_list = self._format_tags(tags or {})
# Validate the tags before proceeding.
errmsg = self.backend.tagger.validate_tags(tag_list)
if errmsg:
raise ValidationError(errmsg)
self.backend.tagger.tag_resource(self.arn, tag_list)
def _format_tags(self, tags: Dict[str, str]) -> List[Dict[str, str]]:
return [{"Key": k, "Value": v} for k, v in tags.items()]
def _get_resource_requirement(self, req_type: str, default: Any = None) -> Any:
""" """
Get resource requirement from container properties. Get resource requirement from container properties.
@ -297,7 +303,7 @@ class JobDefinition(CloudFormationModel):
else: else:
return self.container_properties.get(req_type, default) return self.container_properties.get(req_type, default)
def _validate(self): def _validate(self) -> None:
# For future use when containers arnt the only thing in batch # For future use when containers arnt the only thing in batch
if self.type not in ("container",): if self.type not in ("container",):
raise ClientException('type must be one of "container"') raise ClientException('type must be one of "container"')
@ -320,12 +326,18 @@ class JobDefinition(CloudFormationModel):
if vcpus <= 0: if vcpus <= 0:
raise ClientException("container vcpus limit must be greater than 0") raise ClientException("container vcpus limit must be greater than 0")
def deregister(self): def deregister(self) -> None:
self.status = "INACTIVE" self.status = "INACTIVE"
def update( def update(
self, parameters, _type, container_properties, retry_strategy, tags, timeout self,
): parameters: Optional[Dict[str, Any]],
_type: str,
container_properties: Dict[str, Any],
retry_strategy: Dict[str, Any],
tags: Dict[str, str],
timeout: Dict[str, int],
) -> "JobDefinition":
if self.status != "INACTIVE": if self.status != "INACTIVE":
if parameters is None: if parameters is None:
parameters = self.parameters parameters = self.parameters
@ -353,7 +365,7 @@ class JobDefinition(CloudFormationModel):
propagate_tags=self.propagate_tags, propagate_tags=self.propagate_tags,
) )
def describe(self): def describe(self) -> Dict[str, Any]:
result = { result = {
"jobDefinitionArn": self.arn, "jobDefinitionArn": self.arn,
"jobDefinitionName": self.name, "jobDefinitionName": self.name,
@ -374,22 +386,27 @@ class JobDefinition(CloudFormationModel):
return result return result
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.arn return self.arn
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return "JobDefinitionName" return "JobDefinitionName"
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobdefinition.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobdefinition.html
return "AWS::Batch::JobDefinition" return "AWS::Batch::JobDefinition"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
) -> "JobDefinition":
backend = batch_backends[account_id][region_name] backend = batch_backends[account_id][region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
res = backend.register_job_definition( res = backend.register_job_definition(
@ -411,25 +428,15 @@ class JobDefinition(CloudFormationModel):
class Job(threading.Thread, BaseModel, DockerModel, ManagedState): class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
def __init__( def __init__(
self, self,
name, name: str,
job_def, job_def: JobDefinition,
job_queue, job_queue: JobQueue,
log_backend, log_backend: LogsBackend,
container_overrides, container_overrides: Optional[Dict[str, Any]],
depends_on, depends_on: Optional[List[Dict[str, str]]],
all_jobs, all_jobs: Dict[str, "Job"],
timeout, timeout: Optional[Dict[str, int]],
): ):
"""
Docker Job
:param name: Job Name
:param job_def: Job definition
:type: job_def: JobDefinition
:param job_queue: Job Queue
:param log_backend: Log backend
:type log_backend: moto.logs.models.LogsBackend
"""
threading.Thread.__init__(self) threading.Thread.__init__(self)
DockerModel.__init__(self) DockerModel.__init__(self)
ManagedState.__init__( ManagedState.__init__(
@ -446,32 +453,32 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.job_name = name self.job_name = name
self.job_id = str(mock_random.uuid4()) self.job_id = str(mock_random.uuid4())
self.job_definition = job_def self.job_definition = job_def
self.container_overrides = container_overrides or {} self.container_overrides: Dict[str, Any] = container_overrides or {}
self.job_queue = job_queue self.job_queue = job_queue
self.job_queue.jobs.append(self) self.job_queue.jobs.append(self)
self.job_created_at = datetime.datetime.now() self.job_created_at = datetime.datetime.now()
self.job_started_at = datetime.datetime(1970, 1, 1) self.job_started_at = datetime.datetime(1970, 1, 1)
self.job_stopped_at = datetime.datetime(1970, 1, 1) self.job_stopped_at = datetime.datetime(1970, 1, 1)
self.job_stopped = False self.job_stopped = False
self.job_stopped_reason = None self.job_stopped_reason: Optional[str] = None
self.depends_on = depends_on self.depends_on = depends_on
self.timeout = timeout self.timeout = timeout
self.all_jobs = all_jobs self.all_jobs = all_jobs
self.stop = False self.stop = False
self.exit_code = None self.exit_code: Optional[int] = None
self.daemon = True self.daemon = True
self.name = "MOTO-BATCH-" + self.job_id self.name = "MOTO-BATCH-" + self.job_id
self._log_backend = log_backend self._log_backend = log_backend
self.log_stream_name = None self.log_stream_name: Optional[str] = None
self.container_details = {} self.container_details: Dict[str, Any] = {}
self.attempts = [] self.attempts: List[Dict[str, Any]] = []
self.latest_attempt = None self.latest_attempt: Optional[Dict[str, Any]] = None
def describe_short(self): def describe_short(self) -> Dict[str, Any]:
result = { result = {
"jobId": self.job_id, "jobId": self.job_id,
"jobName": self.job_name, "jobName": self.job_name,
@ -489,10 +496,10 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
result["container"] = {"exitCode": self.exit_code} result["container"] = {"exitCode": self.exit_code}
return result return result
def describe(self): def describe(self) -> Dict[str, Any]:
result = self.describe_short() result = self.describe_short()
result["jobQueue"] = self.job_queue.arn result["jobQueue"] = self.job_queue.arn
result["dependsOn"] = self.depends_on if self.depends_on else [] result["dependsOn"] = self.depends_on or []
result["container"] = self.container_details result["container"] = self.container_details
if self.job_stopped: if self.job_stopped:
result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at) result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at)
@ -501,7 +508,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
result["attempts"] = self.attempts result["attempts"] = self.attempts
return result return result
def _get_container_property(self, p, default): def _get_container_property(self, p: str, default: Any) -> Any:
if p == "environment": if p == "environment":
job_env = self.container_overrides.get(p, default) job_env = self.container_overrides.get(p, default)
jd_env = self.job_definition.container_properties.get(p, default) jd_env = self.job_definition.container_properties.get(p, default)
@ -526,14 +533,14 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
p, self.job_definition.container_properties.get(p, default) p, self.job_definition.container_properties.get(p, default)
) )
def _get_attempt_duration(self): def _get_attempt_duration(self) -> Optional[int]:
if self.timeout: if self.timeout:
return self.timeout["attemptDurationSeconds"] return self.timeout["attemptDurationSeconds"]
if self.job_definition.timeout: if self.job_definition.timeout:
return self.job_definition.timeout["attemptDurationSeconds"] return self.job_definition.timeout["attemptDurationSeconds"]
return None return None
def run(self): def run(self) -> None:
""" """
Run the container. Run the container.
@ -672,7 +679,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
if self._get_attempt_duration(): if self._get_attempt_duration():
attempt_duration = self._get_attempt_duration() attempt_duration = self._get_attempt_duration()
max_time = self.job_started_at + datetime.timedelta( max_time = self.job_started_at + datetime.timedelta(
seconds=attempt_duration seconds=attempt_duration # type: ignore[arg-type]
) )
while container.status == "running" and not self.stop: while container.status == "running" and not self.stop:
@ -740,8 +747,9 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.container_details["logStreamName"] = self.log_stream_name self.container_details["logStreamName"] = self.log_stream_name
result = container.wait() or {} result = container.wait() or {}
self.exit_code = result.get("StatusCode", 0) exit_code = result.get("StatusCode", 0)
job_failed = self.stop or self.exit_code > 0 self.exit_code = exit_code
job_failed = self.stop or exit_code > 0
self._mark_stopped(success=not job_failed) self._mark_stopped(success=not job_failed)
except Exception as err: except Exception as err:
@ -762,7 +770,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
) )
self._mark_stopped(success=False) self._mark_stopped(success=False)
def _mark_stopped(self, success=True): def _mark_stopped(self, success: bool = True) -> None:
# Ensure that job_stopped/job_stopped_at-attributes are set first # Ensure that job_stopped/job_stopped_at-attributes are set first
# The describe-method needs them immediately when status is set # The describe-method needs them immediately when status is set
self.job_stopped = True self.job_stopped = True
@ -770,7 +778,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
self.status = "SUCCEEDED" if success else "FAILED" self.status = "SUCCEEDED" if success else "FAILED"
self._stop_attempt() self._stop_attempt()
def _start_attempt(self): def _start_attempt(self) -> None:
self.latest_attempt = { self.latest_attempt = {
"container": { "container": {
"containerInstanceArn": "TBD", "containerInstanceArn": "TBD",
@ -784,21 +792,21 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
) )
self.attempts.append(self.latest_attempt) self.attempts.append(self.latest_attempt)
def _stop_attempt(self): def _stop_attempt(self) -> None:
if self.latest_attempt: if self.latest_attempt:
self.latest_attempt["container"]["logStreamName"] = self.log_stream_name self.latest_attempt["container"]["logStreamName"] = self.log_stream_name
self.latest_attempt["stoppedAt"] = datetime2int_milliseconds( self.latest_attempt["stoppedAt"] = datetime2int_milliseconds(
self.job_stopped_at self.job_stopped_at
) )
def terminate(self, reason): def terminate(self, reason: str) -> None:
if not self.stop: if not self.stop:
self.stop = True self.stop = True
self.job_stopped_reason = reason self.job_stopped_reason = reason
def _wait_for_dependencies(self): def _wait_for_dependencies(self) -> bool:
dependent_ids = [dependency["jobId"] for dependency in self.depends_on] dependent_ids = [dependency["jobId"] for dependency in self.depends_on] # type: ignore[union-attr]
successful_dependencies = set() successful_dependencies: Set[str] = set()
while len(successful_dependencies) != len(dependent_ids): while len(successful_dependencies) != len(dependent_ids):
for dependent_id in dependent_ids: for dependent_id in dependent_ids:
if dependent_id in self.all_jobs: if dependent_id in self.all_jobs:
@ -832,21 +840,21 @@ class BatchBackend(BaseBackend):
With this decorator, jobs are simply marked as 'Success' without trying to execute any commands/scripts. With this decorator, jobs are simply marked as 'Success' without trying to execute any commands/scripts.
""" """
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.tagger = TaggingService() self.tagger = TaggingService()
self._compute_environments = {} self._compute_environments: Dict[str, ComputeEnvironment] = {}
self._job_queues = {} self._job_queues: Dict[str, JobQueue] = {}
self._job_definitions = {} self._job_definitions: Dict[str, JobDefinition] = {}
self._jobs = {} self._jobs: Dict[str, Job] = {}
state_manager.register_default_transition( state_manager.register_default_transition(
"batch::job", transition={"progression": "manual", "times": 1} "batch::job", transition={"progression": "manual", "times": 1}
) )
@property @property
def iam_backend(self): def iam_backend(self) -> IAMBackend:
""" """
:return: IAM Backend :return: IAM Backend
:rtype: moto.iam.models.IAMBackend :rtype: moto.iam.models.IAMBackend
@ -854,7 +862,7 @@ class BatchBackend(BaseBackend):
return iam_backends[self.account_id]["global"] return iam_backends[self.account_id]["global"]
@property @property
def ec2_backend(self): def ec2_backend(self) -> EC2Backend:
""" """
:return: EC2 Backend :return: EC2 Backend
:rtype: moto.ec2.models.EC2Backend :rtype: moto.ec2.models.EC2Backend
@ -862,7 +870,7 @@ class BatchBackend(BaseBackend):
return ec2_backends[self.account_id][self.region_name] return ec2_backends[self.account_id][self.region_name]
@property @property
def ecs_backend(self): def ecs_backend(self) -> EC2ContainerServiceBackend:
""" """
:return: ECS Backend :return: ECS Backend
:rtype: moto.ecs.models.EC2ContainerServiceBackend :rtype: moto.ecs.models.EC2ContainerServiceBackend
@ -870,14 +878,14 @@ class BatchBackend(BaseBackend):
return ecs_backends[self.account_id][self.region_name] return ecs_backends[self.account_id][self.region_name]
@property @property
def logs_backend(self): def logs_backend(self) -> LogsBackend:
""" """
:return: ECS Backend :return: ECS Backend
:rtype: moto.logs.models.LogsBackend :rtype: moto.logs.models.LogsBackend
""" """
return logs_backends[self.account_id][self.region_name] return logs_backends[self.account_id][self.region_name]
def reset(self): def reset(self) -> None:
for job in self._jobs.values(): for job in self._jobs.values():
if job.status not in ("FAILED", "SUCCEEDED"): if job.status not in ("FAILED", "SUCCEEDED"):
job.stop = True job.stop = True
@ -886,16 +894,18 @@ class BatchBackend(BaseBackend):
super().reset() super().reset()
def get_compute_environment_by_arn(self, arn): def get_compute_environment_by_arn(self, arn: str) -> Optional[ComputeEnvironment]:
return self._compute_environments.get(arn) return self._compute_environments.get(arn)
def get_compute_environment_by_name(self, name): def get_compute_environment_by_name(
self, name: str
) -> Optional[ComputeEnvironment]:
for comp_env in self._compute_environments.values(): for comp_env in self._compute_environments.values():
if comp_env.name == name: if comp_env.name == name:
return comp_env return comp_env
return None return None
def get_compute_environment(self, identifier): def get_compute_environment(self, identifier: str) -> Optional[ComputeEnvironment]:
""" """
Get compute environment by name or ARN Get compute environment by name or ARN
:param identifier: Name or ARN :param identifier: Name or ARN
@ -904,21 +914,20 @@ class BatchBackend(BaseBackend):
:return: Compute Environment or None :return: Compute Environment or None
:rtype: ComputeEnvironment or None :rtype: ComputeEnvironment or None
""" """
env = self.get_compute_environment_by_arn(identifier) return self.get_compute_environment_by_arn(
if env is None: identifier
env = self.get_compute_environment_by_name(identifier) ) or self.get_compute_environment_by_name(identifier)
return env
def get_job_queue_by_arn(self, arn): def get_job_queue_by_arn(self, arn: str) -> Optional[JobQueue]:
return self._job_queues.get(arn) return self._job_queues.get(arn)
def get_job_queue_by_name(self, name): def get_job_queue_by_name(self, name: str) -> Optional[JobQueue]:
for comp_env in self._job_queues.values(): for comp_env in self._job_queues.values():
if comp_env.name == name: if comp_env.name == name:
return comp_env return comp_env
return None return None
def get_job_queue(self, identifier): def get_job_queue(self, identifier: str) -> Optional[JobQueue]:
""" """
Get job queue by name or ARN Get job queue by name or ARN
:param identifier: Name or ARN :param identifier: Name or ARN
@ -927,15 +936,14 @@ class BatchBackend(BaseBackend):
:return: Job Queue or None :return: Job Queue or None
:rtype: JobQueue or None :rtype: JobQueue or None
""" """
env = self.get_job_queue_by_arn(identifier) return self.get_job_queue_by_arn(identifier) or self.get_job_queue_by_name(
if env is None: identifier
env = self.get_job_queue_by_name(identifier) )
return env
def get_job_definition_by_arn(self, arn): def get_job_definition_by_arn(self, arn: str) -> Optional[JobDefinition]:
return self._job_definitions.get(arn) return self._job_definitions.get(arn)
def get_job_definition_by_name(self, name): def get_job_definition_by_name(self, name: str) -> Optional[JobDefinition]:
latest_revision = -1 latest_revision = -1
latest_job = None latest_job = None
for job_def in self._job_definitions.values(): for job_def in self._job_definitions.values():
@ -944,13 +952,15 @@ class BatchBackend(BaseBackend):
latest_revision = job_def.revision latest_revision = job_def.revision
return latest_job return latest_job
def get_job_definition_by_name_revision(self, name, revision): def get_job_definition_by_name_revision(
self, name: str, revision: str
) -> Optional[JobDefinition]:
for job_def in self._job_definitions.values(): for job_def in self._job_definitions.values():
if job_def.name == name and job_def.revision == int(revision): if job_def.name == name and job_def.revision == int(revision):
return job_def return job_def
return None return None
def get_job_definition(self, identifier): def get_job_definition(self, identifier: str) -> Optional[JobDefinition]:
""" """
Get job definitions by name or ARN Get job definitions by name or ARN
:param identifier: Name or ARN :param identifier: Name or ARN
@ -969,7 +979,7 @@ class BatchBackend(BaseBackend):
job_def = self.get_job_definition_by_name(identifier) job_def = self.get_job_definition_by_name(identifier)
return job_def return job_def
def get_job_definitions(self, identifier): def get_job_definitions(self, identifier: str) -> List[JobDefinition]:
""" """
Get job definitions by name or ARN Get job definitions by name or ARN
:param identifier: Name or ARN :param identifier: Name or ARN
@ -989,21 +999,15 @@ class BatchBackend(BaseBackend):
return result return result
def get_job_by_id(self, identifier): def get_job_by_id(self, identifier: str) -> Optional[Job]:
"""
Get job by id
:param identifier: Job ID
:type identifier: str
:return: Job
:rtype: Job
"""
try: try:
return self._jobs[identifier] return self._jobs[identifier]
except KeyError: except KeyError:
return None return None
def describe_compute_environments(self, environments=None): def describe_compute_environments(
self, environments: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
@ -1017,7 +1021,7 @@ class BatchBackend(BaseBackend):
if len(envs) > 0 and arn not in envs and environment.name not in envs: if len(envs) > 0 and arn not in envs and environment.name not in envs:
continue continue
json_part = { json_part: Dict[str, Any] = {
"computeEnvironmentArn": arn, "computeEnvironmentArn": arn,
"computeEnvironmentName": environment.name, "computeEnvironmentName": environment.name,
"ecsClusterArn": environment.ecs_arn, "ecsClusterArn": environment.ecs_arn,
@ -1035,8 +1039,13 @@ class BatchBackend(BaseBackend):
return result return result
def create_compute_environment( def create_compute_environment(
self, compute_environment_name, _type, state, compute_resources, service_role self,
): compute_environment_name: str,
_type: str,
state: str,
compute_resources: Dict[str, Any],
service_role: str,
) -> Tuple[str, str]:
# Validate # Validate
if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None: if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None:
raise InvalidParameterValueException( raise InvalidParameterValueException(
@ -1127,7 +1136,7 @@ class BatchBackend(BaseBackend):
return compute_environment_name, new_comp_env.arn return compute_environment_name, new_comp_env.arn
def _validate_compute_resources(self, cr): def _validate_compute_resources(self, cr: Dict[str, Any]) -> None:
""" """
Checks contents of sub dictionary for managed clusters Checks contents of sub dictionary for managed clusters
@ -1195,16 +1204,15 @@ class BatchBackend(BaseBackend):
) )
@staticmethod @staticmethod
def find_min_instances_to_meet_vcpus(instance_types, target): def find_min_instances_to_meet_vcpus(
instance_types: List[str], target: float
) -> List[str]:
""" """
Finds the minimum needed instances to meed a vcpu target Finds the minimum needed instances to meed a vcpu target
:param instance_types: Instance types, like ['t2.medium', 't2.small'] :param instance_types: Instance types, like ['t2.medium', 't2.small']
:type instance_types: list of str
:param target: VCPU target :param target: VCPU target
:type target: float
:return: List of instance types :return: List of instance types
:rtype: list of str
""" """
# vcpus = [ (vcpus, instance_type), (vcpus, instance_type), ... ] # vcpus = [ (vcpus, instance_type), (vcpus, instance_type), ... ]
instance_vcpus = [] instance_vcpus = []
@ -1253,7 +1261,7 @@ class BatchBackend(BaseBackend):
return instances return instances
def delete_compute_environment(self, compute_environment_name): def delete_compute_environment(self, compute_environment_name: str) -> None:
if compute_environment_name is None: if compute_environment_name is None:
raise InvalidParameterValueException("Missing computeEnvironment parameter") raise InvalidParameterValueException("Missing computeEnvironment parameter")
@ -1273,8 +1281,12 @@ class BatchBackend(BaseBackend):
self.ec2_backend.terminate_instances(instance_ids) self.ec2_backend.terminate_instances(instance_ids)
def update_compute_environment( def update_compute_environment(
self, compute_environment_name, state, compute_resources, service_role self,
): compute_environment_name: str,
state: Optional[str],
compute_resources: Optional[Any],
service_role: Optional[str],
) -> Tuple[str, str]:
# Validate # Validate
compute_env = self.get_compute_environment(compute_environment_name) compute_env = self.get_compute_environment(compute_environment_name)
if compute_env is None: if compute_env is None:
@ -1283,13 +1295,13 @@ class BatchBackend(BaseBackend):
# Look for IAM role # Look for IAM role
if service_role is not None: if service_role is not None:
try: try:
role = self.iam_backend.get_role_by_arn(service_role) self.iam_backend.get_role_by_arn(service_role)
except IAMNotFoundException: except IAMNotFoundException:
raise InvalidParameterValueException( raise InvalidParameterValueException(
"Could not find IAM role {0}".format(service_role) "Could not find IAM role {0}".format(service_role)
) )
compute_env.service_role = role compute_env.service_role = service_role
if state is not None: if state is not None:
if state not in ("ENABLED", "DISABLED"): if state not in ("ENABLED", "DISABLED"):
@ -1307,8 +1319,13 @@ class BatchBackend(BaseBackend):
return compute_env.name, compute_env.arn return compute_env.name, compute_env.arn
def create_job_queue( def create_job_queue(
self, queue_name, priority, state, compute_env_order, tags=None self,
): queue_name: str,
priority: str,
state: str,
compute_env_order: List[Dict[str, str]],
tags: Optional[Dict[str, str]] = None,
) -> Tuple[str, str]:
for variable, var_name in ( for variable, var_name in (
(queue_name, "jobQueueName"), (queue_name, "jobQueueName"),
(priority, "priority"), (priority, "priority"),
@ -1359,7 +1376,9 @@ class BatchBackend(BaseBackend):
return queue_name, queue.arn return queue_name, queue.arn
def describe_job_queues(self, job_queues=None): def describe_job_queues(
self, job_queues: Optional[List[str]] = None
) -> List[Dict[str, Any]]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
@ -1377,7 +1396,13 @@ class BatchBackend(BaseBackend):
return result return result
def update_job_queue(self, queue_name, priority, state, compute_env_order): def update_job_queue(
self,
queue_name: str,
priority: Optional[str],
state: Optional[str],
compute_env_order: Optional[List[Dict[str, Any]]],
) -> Tuple[str, str]:
if queue_name is None: if queue_name is None:
raise ClientException("jobQueueName must be provided") raise ClientException("jobQueueName must be provided")
@ -1422,7 +1447,7 @@ class BatchBackend(BaseBackend):
return queue_name, job_queue.arn return queue_name, job_queue.arn
def delete_job_queue(self, queue_name): def delete_job_queue(self, queue_name: str) -> None:
job_queue = self.get_job_queue(queue_name) job_queue = self.get_job_queue(queue_name)
if job_queue is not None: if job_queue is not None:
@ -1430,16 +1455,16 @@ class BatchBackend(BaseBackend):
def register_job_definition( def register_job_definition(
self, self,
def_name, def_name: str,
parameters, parameters: Dict[str, Any],
_type, _type: str,
tags, tags: Dict[str, str],
retry_strategy, retry_strategy: Dict[str, Any],
container_properties, container_properties: Dict[str, Any],
timeout, timeout: Dict[str, int],
platform_capabilities, platform_capabilities: List[str],
propagate_tags, propagate_tags: bool,
): ) -> Tuple[str, str, int]:
if def_name is None: if def_name is None:
raise ClientException("jobDefinitionName must be provided") raise ClientException("jobDefinitionName must be provided")
@ -1473,7 +1498,7 @@ class BatchBackend(BaseBackend):
return def_name, job_def.arn, job_def.revision return def_name, job_def.arn, job_def.revision
def deregister_job_definition(self, def_name): def deregister_job_definition(self, def_name: str) -> None:
job_def = self.get_job_definition_by_arn(def_name) job_def = self.get_job_definition_by_arn(def_name)
if job_def is None and ":" in def_name: if job_def is None and ":" in def_name:
name, revision = def_name.split(":", 1) name, revision = def_name.split(":", 1)
@ -1483,8 +1508,11 @@ class BatchBackend(BaseBackend):
self._job_definitions[job_def.arn].deregister() self._job_definitions[job_def.arn].deregister()
def describe_job_definitions( def describe_job_definitions(
self, job_def_name=None, job_def_list=None, status=None self,
): job_def_name: Optional[str] = None,
job_def_list: Optional[List[str]] = None,
status: Optional[str] = None,
) -> List[JobDefinition]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
@ -1496,8 +1524,8 @@ class BatchBackend(BaseBackend):
if job_def is not None: if job_def is not None:
jobs.extend(job_def) jobs.extend(job_def)
elif job_def_list is not None: elif job_def_list is not None:
for job in job_def_list: for jdn in job_def_list:
job_def = self.get_job_definitions(job) job_def = self.get_job_definitions(jdn)
if job_def is not None: if job_def is not None:
jobs.extend(job_def) jobs.extend(job_def)
else: else:
@ -1512,13 +1540,13 @@ class BatchBackend(BaseBackend):
def submit_job( def submit_job(
self, self,
job_name, job_name: str,
job_def_id, job_def_id: str,
job_queue, job_queue: str,
depends_on=None, depends_on: Optional[List[Dict[str, str]]] = None,
container_overrides=None, container_overrides: Optional[Dict[str, Any]] = None,
timeout=None, timeout: Optional[Dict[str, int]] = None,
): ) -> Tuple[str, str]:
""" """
Parameters RetryStrategy and Parameters are not yet implemented. Parameters RetryStrategy and Parameters are not yet implemented.
""" """
@ -1550,7 +1578,7 @@ class BatchBackend(BaseBackend):
return job_name, job.job_id return job_name, job.job_id
def describe_jobs(self, jobs): def describe_jobs(self, jobs: Optional[List[str]]) -> List[Dict[str, Any]]:
job_filter = set() job_filter = set()
if jobs is not None: if jobs is not None:
job_filter = set(jobs) job_filter = set(jobs)
@ -1564,13 +1592,15 @@ class BatchBackend(BaseBackend):
return result return result
def list_jobs(self, job_queue, job_status=None): def list_jobs(
self, job_queue_name: str, job_status: Optional[str] = None
) -> List[Job]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
jobs = [] jobs = []
job_queue = self.get_job_queue(job_queue) job_queue = self.get_job_queue(job_queue_name)
if job_queue is None: if job_queue is None:
raise ClientException("Job queue {0} does not exist".format(job_queue)) raise ClientException("Job queue {0} does not exist".format(job_queue))
@ -1595,7 +1625,7 @@ class BatchBackend(BaseBackend):
return jobs return jobs
def cancel_job(self, job_id, reason): def cancel_job(self, job_id: str, reason: str) -> None:
if job_id == "": if job_id == "":
raise ClientException( raise ClientException(
"'reason' is a required field (cannot be an empty string)" "'reason' is a required field (cannot be an empty string)"
@ -1611,7 +1641,7 @@ class BatchBackend(BaseBackend):
job.terminate(reason) job.terminate(reason)
# No-Op for jobs that have already started - user has to explicitly terminate those # No-Op for jobs that have already started - user has to explicitly terminate those
def terminate_job(self, job_id, reason): def terminate_job(self, job_id: str, reason: str) -> None:
if job_id == "": if job_id == "":
raise ClientException( raise ClientException(
"'reason' is a required field (cannot be a empty string)" "'reason' is a required field (cannot be a empty string)"
@ -1625,14 +1655,14 @@ class BatchBackend(BaseBackend):
if job is not None: if job is not None:
job.terminate(reason) job.terminate(reason)
def tag_resource(self, resource_arn, tags): def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None:
tags = self.tagger.convert_dict_to_tags_input(tags or {}) tag_list = self.tagger.convert_dict_to_tags_input(tags or {})
self.tagger.tag_resource(resource_arn, tags) self.tagger.tag_resource(resource_arn, tag_list)
def list_tags_for_resource(self, resource_arn): def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]:
return self.tagger.get_tag_dict_for_resource(resource_arn) return self.tagger.get_tag_dict_for_resource(resource_arn)
def untag_resource(self, resource_arn, tag_keys): def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagger.untag_resource_using_names(resource_arn, tag_keys) self.tagger.untag_resource_using_names(resource_arn, tag_keys)

View File

@ -1,45 +1,28 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import batch_backends from .models import batch_backends, BatchBackend
from urllib.parse import urlsplit, unquote from urllib.parse import urlsplit, unquote
import json import json
class BatchResponse(BaseResponse): class BatchResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="batch") super().__init__(service_name="batch")
def _error(self, code, message):
return json.dumps({"__type": code, "message": message}), dict(status=400)
@property @property
def batch_backend(self): def batch_backend(self) -> BatchBackend:
""" """
:return: Batch Backend :return: Batch Backend
:rtype: moto.batch.models.BatchBackend :rtype: moto.batch.models.BatchBackend
""" """
return batch_backends[self.current_account][self.region] return batch_backends[self.current_account][self.region]
@property def _get_action(self) -> str:
def json(self):
if self.body is None or self.body == "":
self._json = {}
elif not hasattr(self, "_json"):
self._json = json.loads(self.body)
return self._json
def _get_param(self, param_name, if_none=None):
val = self.json.get(param_name)
if val is not None:
return val
return if_none
def _get_action(self):
# Return element after the /v1/* # Return element after the /v1/*
return urlsplit(self.uri).path.lstrip("/").split("/")[1] return urlsplit(self.uri).path.lstrip("/").split("/")[1]
# CreateComputeEnvironment # CreateComputeEnvironment
def createcomputeenvironment(self): def createcomputeenvironment(self) -> str:
compute_env_name = self._get_param("computeEnvironmentName") compute_env_name = self._get_param("computeEnvironmentName")
compute_resource = self._get_param("computeResources") compute_resource = self._get_param("computeResources")
service_role = self._get_param("serviceRole") service_role = self._get_param("serviceRole")
@ -59,7 +42,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# DescribeComputeEnvironments # DescribeComputeEnvironments
def describecomputeenvironments(self): def describecomputeenvironments(self) -> str:
compute_environments = self._get_param("computeEnvironments") compute_environments = self._get_param("computeEnvironments")
envs = self.batch_backend.describe_compute_environments(compute_environments) envs = self.batch_backend.describe_compute_environments(compute_environments)
@ -68,7 +51,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# DeleteComputeEnvironment # DeleteComputeEnvironment
def deletecomputeenvironment(self): def deletecomputeenvironment(self) -> str:
compute_environment = self._get_param("computeEnvironment") compute_environment = self._get_param("computeEnvironment")
self.batch_backend.delete_compute_environment(compute_environment) self.batch_backend.delete_compute_environment(compute_environment)
@ -76,7 +59,7 @@ class BatchResponse(BaseResponse):
return "" return ""
# UpdateComputeEnvironment # UpdateComputeEnvironment
def updatecomputeenvironment(self): def updatecomputeenvironment(self) -> str:
compute_env_name = self._get_param("computeEnvironment") compute_env_name = self._get_param("computeEnvironment")
compute_resource = self._get_param("computeResources") compute_resource = self._get_param("computeResources")
service_role = self._get_param("serviceRole") service_role = self._get_param("serviceRole")
@ -94,7 +77,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# CreateJobQueue # CreateJobQueue
def createjobqueue(self): def createjobqueue(self) -> str:
compute_env_order = self._get_param("computeEnvironmentOrder") compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueueName") queue_name = self._get_param("jobQueueName")
priority = self._get_param("priority") priority = self._get_param("priority")
@ -114,7 +97,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# DescribeJobQueues # DescribeJobQueues
def describejobqueues(self): def describejobqueues(self) -> str:
job_queues = self._get_param("jobQueues") job_queues = self._get_param("jobQueues")
queues = self.batch_backend.describe_job_queues(job_queues) queues = self.batch_backend.describe_job_queues(job_queues)
@ -123,7 +106,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# UpdateJobQueue # UpdateJobQueue
def updatejobqueue(self): def updatejobqueue(self) -> str:
compute_env_order = self._get_param("computeEnvironmentOrder") compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueue") queue_name = self._get_param("jobQueue")
priority = self._get_param("priority") priority = self._get_param("priority")
@ -141,7 +124,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# DeleteJobQueue # DeleteJobQueue
def deletejobqueue(self): def deletejobqueue(self) -> str:
queue_name = self._get_param("jobQueue") queue_name = self._get_param("jobQueue")
self.batch_backend.delete_job_queue(queue_name) self.batch_backend.delete_job_queue(queue_name)
@ -149,7 +132,7 @@ class BatchResponse(BaseResponse):
return "" return ""
# RegisterJobDefinition # RegisterJobDefinition
def registerjobdefinition(self): def registerjobdefinition(self) -> str:
container_properties = self._get_param("containerProperties") container_properties = self._get_param("containerProperties")
def_name = self._get_param("jobDefinitionName") def_name = self._get_param("jobDefinitionName")
parameters = self._get_param("parameters") parameters = self._get_param("parameters")
@ -180,7 +163,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# DeregisterJobDefinition # DeregisterJobDefinition
def deregisterjobdefinition(self): def deregisterjobdefinition(self) -> str:
queue_name = self._get_param("jobDefinition") queue_name = self._get_param("jobDefinition")
self.batch_backend.deregister_job_definition(queue_name) self.batch_backend.deregister_job_definition(queue_name)
@ -188,7 +171,7 @@ class BatchResponse(BaseResponse):
return "" return ""
# DescribeJobDefinitions # DescribeJobDefinitions
def describejobdefinitions(self): def describejobdefinitions(self) -> str:
job_def_name = self._get_param("jobDefinitionName") job_def_name = self._get_param("jobDefinitionName")
job_def_list = self._get_param("jobDefinitions") job_def_list = self._get_param("jobDefinitions")
status = self._get_param("status") status = self._get_param("status")
@ -201,7 +184,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# SubmitJob # SubmitJob
def submitjob(self): def submitjob(self) -> str:
container_overrides = self._get_param("containerOverrides") container_overrides = self._get_param("containerOverrides")
depends_on = self._get_param("dependsOn") depends_on = self._get_param("dependsOn")
job_def = self._get_param("jobDefinition") job_def = self._get_param("jobDefinition")
@ -223,13 +206,13 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# DescribeJobs # DescribeJobs
def describejobs(self): def describejobs(self) -> str:
jobs = self._get_param("jobs") jobs = self._get_param("jobs")
return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)}) return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)})
# ListJobs # ListJobs
def listjobs(self): def listjobs(self) -> str:
job_queue = self._get_param("jobQueue") job_queue = self._get_param("jobQueue")
job_status = self._get_param("jobStatus") job_status = self._get_param("jobStatus")
@ -239,7 +222,7 @@ class BatchResponse(BaseResponse):
return json.dumps(result) return json.dumps(result)
# TerminateJob # TerminateJob
def terminatejob(self): def terminatejob(self) -> str:
job_id = self._get_param("jobId") job_id = self._get_param("jobId")
reason = self._get_param("reason") reason = self._get_param("reason")
@ -248,22 +231,22 @@ class BatchResponse(BaseResponse):
return "" return ""
# CancelJob # CancelJob
def canceljob(self): def canceljob(self) -> str:
job_id = self._get_param("jobId") job_id = self._get_param("jobId")
reason = self._get_param("reason") reason = self._get_param("reason")
self.batch_backend.cancel_job(job_id, reason) self.batch_backend.cancel_job(job_id, reason)
return "" return ""
def tags(self): def tags(self) -> str:
resource_arn = unquote(self.path).split("/v1/tags/")[-1] resource_arn = unquote(self.path).split("/v1/tags/")[-1]
tags = self._get_param("tags") tags = self._get_param("tags")
if self.method == "POST": if self.method == "POST":
self.batch_backend.tag_resource(resource_arn, tags) self.batch_backend.tag_resource(resource_arn, tags)
return ""
if self.method == "GET": if self.method == "GET":
tags = self.batch_backend.list_tags_for_resource(resource_arn) tags = self.batch_backend.list_tags_for_resource(resource_arn)
return json.dumps({"tags": tags}) return json.dumps({"tags": tags})
if self.method == "DELETE": if self.method == "DELETE":
tag_keys = self.querystring.get("tagKeys") tag_keys = self.querystring.get("tagKeys")
self.batch_backend.untag_resource(resource_arn, tag_keys) self.batch_backend.untag_resource(resource_arn, tag_keys) # type: ignore[arg-type]
return ""

View File

@ -1,21 +1,26 @@
def make_arn_for_compute_env(account_id, name, region_name): from typing import Any, Dict
def make_arn_for_compute_env(account_id: str, name: str, region_name: str) -> str:
return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format( return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(
region_name, account_id, name region_name, account_id, name
) )
def make_arn_for_job_queue(account_id, name, region_name): def make_arn_for_job_queue(account_id: str, name: str, region_name: str) -> str:
return "arn:aws:batch:{0}:{1}:job-queue/{2}".format(region_name, account_id, name) return "arn:aws:batch:{0}:{1}:job-queue/{2}".format(region_name, account_id, name)
def make_arn_for_task_def(account_id, name, revision, region_name): def make_arn_for_task_def(
account_id: str, name: str, revision: int, region_name: str
) -> str:
return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format( return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(
region_name, account_id, name, revision region_name, account_id, name, revision
) )
def lowercase_first_key(some_dict): def lowercase_first_key(some_dict: Dict[str, Any]) -> Dict[str, Any]:
new_dict = {} new_dict: Dict[str, Any] = {}
for key, value in some_dict.items(): for key, value in some_dict.items():
new_key = key[0].lower() + key[1:] new_key = key[0].lower() + key[1:]
try: try:

View File

@ -1,7 +1,14 @@
from ..batch.models import batch_backends, BaseBackend, Job, ClientException from ..batch.models import (
batch_backends,
BaseBackend,
Job,
ClientException,
BatchBackend,
)
from ..core.utils import BackendDict from ..core.utils import BackendDict
import datetime import datetime
from typing import Any, Dict, List, Tuple, Optional
class BatchSimpleBackend(BaseBackend): class BatchSimpleBackend(BaseBackend):
@ -11,10 +18,10 @@ class BatchSimpleBackend(BaseBackend):
""" """
@property @property
def backend(self): def backend(self) -> BatchBackend:
return batch_backends[self.account_id][self.region_name] return batch_backends[self.account_id][self.region_name]
def __getattribute__(self, name): def __getattribute__(self, name: str) -> Any:
""" """
Magic part that makes this class behave like a wrapper around the regular batch_backend Magic part that makes this class behave like a wrapper around the regular batch_backend
We intercept calls to `submit_job` and replace this with our own (non-Docker) implementation We intercept calls to `submit_job` and replace this with our own (non-Docker) implementation
@ -32,7 +39,7 @@ class BatchSimpleBackend(BaseBackend):
return object.__getattribute__(self, name) return object.__getattribute__(self, name)
if name in ["submit_job"]: if name in ["submit_job"]:
def newfunc(*args, **kwargs): def newfunc(*args: Any, **kwargs: Any) -> Any:
attr = object.__getattribute__(self, name) attr = object.__getattribute__(self, name)
return attr(*args, **kwargs) return attr(*args, **kwargs)
@ -42,13 +49,13 @@ class BatchSimpleBackend(BaseBackend):
def submit_job( def submit_job(
self, self,
job_name, job_name: str,
job_def_id, job_def_id: str,
job_queue, job_queue: str,
depends_on=None, depends_on: Optional[List[Dict[str, str]]] = None,
container_overrides=None, container_overrides: Optional[Dict[str, Any]] = None,
timeout=None, timeout: Optional[Dict[str, int]] = None,
): ) -> Tuple[str, str]:
# Look for job definition # Look for job definition
job_def = self.get_job_definition(job_def_id) job_def = self.get_job_definition(job_def_id)
if job_def is None: if job_def is None:

View File

@ -1,10 +1,10 @@
from ..batch.responses import BatchResponse from ..batch.responses import BatchResponse
from .models import batch_simple_backends from .models import batch_simple_backends, BatchBackend
class BatchSimpleResponse(BatchResponse): class BatchSimpleResponse(BatchResponse):
@property @property
def batch_backend(self): def batch_backend(self) -> BatchBackend:
""" """
:return: Batch Backend :return: Batch Backend
:rtype: moto.batch.models.BatchBackend :rtype: moto.batch.models.BatchBackend

View File

@ -32,7 +32,7 @@ class BaseBackend:
for model in models.values(): for model in models.values():
model.instances = [] model.instances = []
def reset(self): def reset(self) -> None:
region_name = self.region_name region_name = self.region_name
account_id = self.account_id account_id = self.account_id
self._reset_model_refs() self._reset_model_refs()

View File

@ -2,7 +2,7 @@ import copy
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from datetime import datetime from datetime import datetime
from typing import Any, List, Tuple from typing import Any, List, Tuple, Optional
from moto import settings from moto import settings
from moto.core import CloudFormationModel from moto.core import CloudFormationModel
@ -597,7 +597,7 @@ class InstanceBackend:
self, self,
image_id: str, image_id: str,
count: int, count: int,
user_data: str, user_data: Optional[str],
security_group_names: List[str], security_group_names: List[str],
**kwargs: Any **kwargs: Any
) -> Reservation: ) -> Reservation:

View File

@ -2,7 +2,7 @@ import copy
import itertools import itertools
import json import json
from collections import defaultdict from collections import defaultdict
from typing import Optional
from moto.core import CloudFormationModel from moto.core import CloudFormationModel
from moto.core.utils import aws_api_matches from moto.core.utils import aws_api_matches
from ..exceptions import ( from ..exceptions import (
@ -543,7 +543,7 @@ class SecurityGroupBackend:
return self._delete_security_group(None, group.id) return self._delete_security_group(None, group.id)
raise InvalidSecurityGroupNotFoundError(name) raise InvalidSecurityGroupNotFoundError(name)
def get_security_group_from_id(self, group_id): def get_security_group_from_id(self, group_id: str) -> Optional[SecurityGroup]:
# 2 levels of chaining necessary since it's a complex structure # 2 levels of chaining necessary since it's a complex structure
all_groups = itertools.chain.from_iterable( all_groups = itertools.chain.from_iterable(
[x.copy().values() for x in self.groups.copy().values()] [x.copy().values() for x in self.groups.copy().values()]

View File

@ -229,7 +229,7 @@ class SubnetBackend:
# maps availability zone to dict of (subnet_id, subnet) # maps availability zone to dict of (subnet_id, subnet)
self.subnets = defaultdict(dict) self.subnets = defaultdict(dict)
def get_subnet(self, subnet_id): def get_subnet(self, subnet_id: str) -> Subnet:
for subnets in self.subnets.values(): for subnets in self.subnets.values():
if subnet_id in subnets: if subnet_id in subnets:
return subnets[subnet_id] return subnets[subnet_id]

View File

@ -1,7 +1,7 @@
import re import re
from copy import copy from copy import copy
from datetime import datetime from datetime import datetime
from typing import Any
import pytz import pytz
from moto import settings from moto import settings
@ -850,7 +850,9 @@ class EC2ContainerServiceBackend(BaseBackend):
else: else:
raise Exception("{0} is not a task_definition".format(task_definition_name)) raise Exception("{0} is not a task_definition".format(task_definition_name))
def create_cluster(self, cluster_name, tags=None, cluster_settings=None): def create_cluster(
self, cluster_name: str, tags: Any = None, cluster_settings: Any = None
) -> Cluster:
""" """
The following parameters are not yet implemented: configuration, capacityProviders, defaultCapacityProviderStrategy The following parameters are not yet implemented: configuration, capacityProviders, defaultCapacityProviderStrategy
""" """
@ -926,7 +928,7 @@ class EC2ContainerServiceBackend(BaseBackend):
return list_clusters, failures return list_clusters, failures
def delete_cluster(self, cluster_str): def delete_cluster(self, cluster_str: str) -> Cluster:
cluster = self._get_cluster(cluster_str) cluster = self._get_cluster(cluster_str)
return self.clusters.pop(cluster.name) return self.clusters.pop(cluster.name)

View File

@ -10,7 +10,7 @@ from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from jinja2 import Template from jinja2 import Template
from typing import Mapping from typing import List, Mapping
from urllib import parse from urllib import parse
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.core import DEFAULT_ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel from moto.core import DEFAULT_ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
@ -2188,7 +2188,7 @@ class IAMBackend(BaseBackend):
raise IAMNotFoundException("Instance profile {0} not found".format(profile_arn)) raise IAMNotFoundException("Instance profile {0} not found".format(profile_arn))
def get_instance_profiles(self): def get_instance_profiles(self) -> List[InstanceProfile]:
return self.instance_profiles.values() return self.instance_profiles.values()
def get_instance_profiles_for_role(self, role_name): def get_instance_profiles_for_role(self, role_name):

View File

@ -1,5 +1,5 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Tuple, Optional
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core import CloudFormationModel from moto.core import CloudFormationModel
from moto.core.utils import unix_time_millis, BackendDict from moto.core.utils import unix_time_millis, BackendDict
@ -627,7 +627,9 @@ class LogsBackend(BaseBackend):
) )
return self.groups[log_group_name] return self.groups[log_group_name]
def ensure_log_group(self, log_group_name, tags): def ensure_log_group(
self, log_group_name: str, tags: Optional[Dict[str, str]]
) -> None:
if log_group_name in self.groups: if log_group_name in self.groups:
return return
self.groups[log_group_name] = LogGroup( self.groups[log_group_name] = LogGroup(
@ -653,7 +655,7 @@ class LogsBackend(BaseBackend):
return groups return groups
def create_log_stream(self, log_group_name, log_stream_name): def create_log_stream(self, log_group_name: str, log_stream_name: str) -> LogStream:
if log_group_name not in self.groups: if log_group_name not in self.groups:
raise ResourceNotFoundException() raise ResourceNotFoundException()
log_group = self.groups[log_group_name] log_group = self.groups[log_group_name]
@ -702,7 +704,12 @@ class LogsBackend(BaseBackend):
order_by=order_by, order_by=order_by,
) )
def put_log_events(self, log_group_name, log_stream_name, log_events): def put_log_events(
self,
log_group_name: str,
log_stream_name: str,
log_events: List[Dict[str, Any]],
) -> Tuple[str, Dict[str, Any]]:
""" """
The SequenceToken-parameter is not yet implemented The SequenceToken-parameter is not yet implemented
""" """

View File

@ -1,5 +1,6 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from moto.moto_api import state_manager from moto.moto_api import state_manager
from typing import List, Tuple
class ManagedState: class ManagedState:
@ -7,7 +8,7 @@ class ManagedState:
Subclass this class to configure state-transitions Subclass this class to configure state-transitions
""" """
def __init__(self, model_name, transitions): def __init__(self, model_name: str, transitions: List[Tuple[str, str]]):
# Indicate the possible transitions for this model # Indicate the possible transitions for this model
# Example: [(initializing,queued), (queued, starting), (starting, ready)] # Example: [(initializing,queued), (queued, starting), (starting, ready)]
self._transitions = transitions self._transitions = transitions
@ -23,7 +24,7 @@ class ManagedState:
# Name of this model. This will be used in the API # Name of this model. This will be used in the API
self.model_name = model_name self.model_name = model_name
def advance(self): def advance(self) -> None:
self._tick += 1 self._tick += 1
@property @property

View File

@ -1,3 +1,6 @@
from typing import Any, Dict
DEFAULT_TRANSITION = {"progression": "immediate"} DEFAULT_TRANSITION = {"progression": "immediate"}
@ -6,7 +9,9 @@ class StateManager:
self._default_transitions = dict() self._default_transitions = dict()
self._transitions = dict() self._transitions = dict()
def register_default_transition(self, model_name, transition): def register_default_transition(
self, model_name: str, transition: Dict[str, Any]
) -> None:
""" """
Register the default transition for a specific model. Register the default transition for a specific model.
This should only be called by Moto backends - use the `set_transition` method to override this default transition in your own tests. This should only be called by Moto backends - use the `set_transition` method to override this default transition in your own tests.

View File

@ -106,7 +106,7 @@ class TaggingService:
result[tag[self.key_name]] = None result[tag[self.key_name]] = None
return result return result
def validate_tags(self, tags, limit=0): def validate_tags(self, tags: List[Dict[str, str]], limit: int = 0) -> str:
"""Returns error message if tags in 'tags' list of dicts are invalid. """Returns error message if tags in 'tags' list of dicts are invalid.
The validation does not include a check for duplicate keys. The validation does not include a check for duplicate keys.

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/budgets files= moto/a*,moto/b*
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract

View File

@ -323,6 +323,42 @@ def test_update_unmanaged_compute_environment_state():
our_envs[0]["state"].should.equal("DISABLED") our_envs[0]["state"].should.equal("DISABLED")
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_update_iam_role():
ec2_client, iam_client, _, _, batch_client = _get_clients()
_, _, _, iam_arn = _setup(ec2_client, iam_client)
iam_arn2 = iam_client.create_role(RoleName="r", AssumeRolePolicyDocument="sp")[
"Role"
]["Arn"]
compute_name = str(uuid4())
batch_client.create_compute_environment(
computeEnvironmentName=compute_name,
type="UNMANAGED",
state="ENABLED",
serviceRole=iam_arn,
)
batch_client.update_compute_environment(
computeEnvironment=compute_name, serviceRole=iam_arn2
)
all_envs = batch_client.describe_compute_environments()["computeEnvironments"]
our_envs = [e for e in all_envs if e["computeEnvironmentName"] == compute_name]
our_envs.should.have.length_of(1)
our_envs[0]["serviceRole"].should.equal(iam_arn2)
with pytest.raises(ClientError) as exc:
batch_client.update_compute_environment(
computeEnvironment=compute_name, serviceRole="unknown"
)
err = exc.value.response["Error"]
err["Code"].should.equal("InvalidParameterValue")
@pytest.mark.parametrize("compute_env_type", ["FARGATE", "FARGATE_SPOT"]) @pytest.mark.parametrize("compute_env_type", ["FARGATE", "FARGATE_SPOT"])
@mock_ec2 @mock_ec2
@mock_ecs @mock_ecs