From 6f3b250fc72a9614a6647741929c8ec57a452689 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 23 Oct 2022 13:26:55 +0000 Subject: [PATCH] TechDebt: MyPy Batch (#5592) --- moto/batch/models.py | 452 ++++++++++-------- moto/batch/responses.py | 63 +-- moto/batch/utils.py | 15 +- moto/batch_simple/models.py | 29 +- moto/batch_simple/responses.py | 4 +- moto/core/base_backend.py | 2 +- moto/ec2/models/instances.py | 4 +- moto/ec2/models/security_groups.py | 4 +- moto/ec2/models/subnets.py | 2 +- moto/ecs/models.py | 8 +- moto/iam/models.py | 4 +- moto/logs/models.py | 15 +- .../moto_api/_internal/managed_state_model.py | 5 +- moto/moto_api/_internal/state_manager.py | 7 +- moto/utilities/tagging_service.py | 2 +- setup.cfg | 2 +- tests/test_batch/test_batch_compute_envs.py | 36 ++ 17 files changed, 365 insertions(+), 289 deletions(-) diff --git a/moto/batch/models.py b/moto/batch/models.py index 5965d9f3a..cbe76af80 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -1,6 +1,7 @@ import re from itertools import cycle from time import sleep +from typing import Any, Dict, List, Tuple, Optional, Set import datetime import time import logging @@ -9,10 +10,11 @@ import dateutil.parser from sys import platform from moto.core import BaseBackend, BaseModel, CloudFormationModel -from moto.iam import iam_backends -from moto.ec2 import ec2_backends -from moto.ecs import ecs_backends -from moto.logs import logs_backends +from moto.iam.models import iam_backends, IAMBackend +from moto.ec2.models import ec2_backends, EC2Backend +from moto.ec2.models.instances import Instance +from moto.ecs.models import ecs_backends, EC2ContainerServiceBackend +from moto.logs.models import logs_backends, LogsBackend from moto.utilities.tagging_service import TaggingService from .exceptions import InvalidParameterValueException, ClientException, ValidationError @@ -39,7 +41,7 @@ COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile( ) -def datetime2int_milliseconds(date): +def datetime2int_milliseconds(date: datetime.datetime) -> int: """ AWS returns timestamps in milliseconds We don't use milliseconds timestamps internally, @@ -48,20 +50,20 @@ def datetime2int_milliseconds(date): return int(date.timestamp() * 1000) -def datetime2int(date): +def datetime2int(date: datetime.datetime) -> int: return int(time.mktime(date.timetuple())) class ComputeEnvironment(CloudFormationModel): def __init__( self, - compute_environment_name, - _type, - state, - compute_resources, - service_role, - account_id, - region_name, + compute_environment_name: str, + _type: str, + state: str, + compute_resources: Dict[str, Any], + service_role: str, + account_id: str, + region_name: str, ): self.name = compute_environment_name self.env_type = _type @@ -72,34 +74,39 @@ class ComputeEnvironment(CloudFormationModel): account_id, compute_environment_name, region_name ) - self.instances = [] - self.ecs_arn = None - self.ecs_name = None + self.instances: List[Instance] = [] + self.ecs_arn = "" + self.ecs_name = "" - def add_instance(self, instance): + def add_instance(self, instance: Instance) -> None: self.instances.append(instance) - def set_ecs(self, arn, name): + def set_ecs(self, arn: str, name: str) -> None: self.ecs_arn = arn self.ecs_name = name @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.arn @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "ComputeEnvironmentName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-computeenvironment.html return "AWS::Batch::ComputeEnvironment" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "ComputeEnvironment": backend = batch_backends[account_id][region_name] properties = cloudformation_json["Properties"] @@ -118,13 +125,13 @@ class ComputeEnvironment(CloudFormationModel): class JobQueue(CloudFormationModel): def __init__( self, - name, - priority, - state, - environments, - env_order_json, - backend, - tags=None, + name: str, + priority: str, + state: str, + environments: List[ComputeEnvironment], + env_order_json: List[Dict[str, Any]], + backend: "BatchBackend", + tags: Optional[Dict[str, str]] = None, ): """ :param name: Job queue name @@ -150,10 +157,10 @@ class JobQueue(CloudFormationModel): if tags: backend.tag_resource(self.arn, tags) - self.jobs = [] + self.jobs: List[Job] = [] - def describe(self): - result = { + def describe(self) -> Dict[str, Any]: + return { "computeEnvironmentOrder": self.env_order_json, "jobQueueArn": self.arn, "jobQueueName": self.name, @@ -163,25 +170,28 @@ class JobQueue(CloudFormationModel): "tags": self.backend.list_tags_for_resource(self.arn), } - return result - @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.arn @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "JobQueueName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobqueue.html return "AWS::Batch::JobQueue" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "JobQueue": backend = batch_backends[account_id][region_name] properties = cloudformation_json["Properties"] @@ -206,25 +216,24 @@ class JobQueue(CloudFormationModel): class JobDefinition(CloudFormationModel): def __init__( self, - name, - parameters, - _type, - container_properties, - tags=None, - revision=0, - retry_strategy=0, - timeout=None, - backend=None, - platform_capabilities=None, - propagate_tags=None, + name: str, + parameters: Optional[Dict[str, Any]], + _type: str, + container_properties: Dict[str, Any], + tags: Dict[str, str], + retry_strategy: Dict[str, str], + timeout: Dict[str, int], + backend: "BatchBackend", + platform_capabilities: List[str], + propagate_tags: bool, + revision: Optional[int] = 0, ): self.name = name self.retry_strategy = retry_strategy self.type = _type - self.revision = revision + self.revision = revision or 0 self._region = backend.region_name self.container_properties = container_properties - self.arn = None self.status = "ACTIVE" self.parameters = parameters or {} self.timeout = timeout @@ -238,26 +247,23 @@ class JobDefinition(CloudFormationModel): self.container_properties["secrets"] = [] self._validate() - self._update_arn() - - tags = self._format_tags(tags or {}) - # Validate the tags before proceeding. - errmsg = self.backend.tagger.validate_tags(tags) - if errmsg: - raise ValidationError(errmsg) - - self.backend.tagger.tag_resource(self.arn, tags) - - def _format_tags(self, tags): - return [{"Key": k, "Value": v} for k, v in tags.items()] - - def _update_arn(self): self.revision += 1 self.arn = make_arn_for_task_def( self.backend.account_id, self.name, self.revision, self._region ) - def _get_resource_requirement(self, req_type, default=None): + tag_list = self._format_tags(tags or {}) + # Validate the tags before proceeding. + errmsg = self.backend.tagger.validate_tags(tag_list) + if errmsg: + raise ValidationError(errmsg) + + self.backend.tagger.tag_resource(self.arn, tag_list) + + def _format_tags(self, tags: Dict[str, str]) -> List[Dict[str, str]]: + return [{"Key": k, "Value": v} for k, v in tags.items()] + + def _get_resource_requirement(self, req_type: str, default: Any = None) -> Any: """ Get resource requirement from container properties. @@ -297,7 +303,7 @@ class JobDefinition(CloudFormationModel): else: return self.container_properties.get(req_type, default) - def _validate(self): + def _validate(self) -> None: # For future use when containers arnt the only thing in batch if self.type not in ("container",): raise ClientException('type must be one of "container"') @@ -320,12 +326,18 @@ class JobDefinition(CloudFormationModel): if vcpus <= 0: raise ClientException("container vcpus limit must be greater than 0") - def deregister(self): + def deregister(self) -> None: self.status = "INACTIVE" def update( - self, parameters, _type, container_properties, retry_strategy, tags, timeout - ): + self, + parameters: Optional[Dict[str, Any]], + _type: str, + container_properties: Dict[str, Any], + retry_strategy: Dict[str, Any], + tags: Dict[str, str], + timeout: Dict[str, int], + ) -> "JobDefinition": if self.status != "INACTIVE": if parameters is None: parameters = self.parameters @@ -353,7 +365,7 @@ class JobDefinition(CloudFormationModel): propagate_tags=self.propagate_tags, ) - def describe(self): + def describe(self) -> Dict[str, Any]: result = { "jobDefinitionArn": self.arn, "jobDefinitionName": self.name, @@ -374,22 +386,27 @@ class JobDefinition(CloudFormationModel): return result @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.arn @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "JobDefinitionName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobdefinition.html return "AWS::Batch::JobDefinition" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "JobDefinition": backend = batch_backends[account_id][region_name] properties = cloudformation_json["Properties"] res = backend.register_job_definition( @@ -411,25 +428,15 @@ class JobDefinition(CloudFormationModel): class Job(threading.Thread, BaseModel, DockerModel, ManagedState): def __init__( self, - name, - job_def, - job_queue, - log_backend, - container_overrides, - depends_on, - all_jobs, - timeout, + name: str, + job_def: JobDefinition, + job_queue: JobQueue, + log_backend: LogsBackend, + container_overrides: Optional[Dict[str, Any]], + depends_on: Optional[List[Dict[str, str]]], + all_jobs: Dict[str, "Job"], + timeout: Optional[Dict[str, int]], ): - """ - Docker Job - - :param name: Job Name - :param job_def: Job definition - :type: job_def: JobDefinition - :param job_queue: Job Queue - :param log_backend: Log backend - :type log_backend: moto.logs.models.LogsBackend - """ threading.Thread.__init__(self) DockerModel.__init__(self) ManagedState.__init__( @@ -446,32 +453,32 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): self.job_name = name self.job_id = str(mock_random.uuid4()) self.job_definition = job_def - self.container_overrides = container_overrides or {} + self.container_overrides: Dict[str, Any] = container_overrides or {} self.job_queue = job_queue self.job_queue.jobs.append(self) self.job_created_at = datetime.datetime.now() self.job_started_at = datetime.datetime(1970, 1, 1) self.job_stopped_at = datetime.datetime(1970, 1, 1) self.job_stopped = False - self.job_stopped_reason = None + self.job_stopped_reason: Optional[str] = None self.depends_on = depends_on self.timeout = timeout self.all_jobs = all_jobs self.stop = False - self.exit_code = None + self.exit_code: Optional[int] = None self.daemon = True self.name = "MOTO-BATCH-" + self.job_id self._log_backend = log_backend - self.log_stream_name = None + self.log_stream_name: Optional[str] = None - self.container_details = {} - self.attempts = [] - self.latest_attempt = None + self.container_details: Dict[str, Any] = {} + self.attempts: List[Dict[str, Any]] = [] + self.latest_attempt: Optional[Dict[str, Any]] = None - def describe_short(self): + def describe_short(self) -> Dict[str, Any]: result = { "jobId": self.job_id, "jobName": self.job_name, @@ -489,10 +496,10 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): result["container"] = {"exitCode": self.exit_code} return result - def describe(self): + def describe(self) -> Dict[str, Any]: result = self.describe_short() result["jobQueue"] = self.job_queue.arn - result["dependsOn"] = self.depends_on if self.depends_on else [] + result["dependsOn"] = self.depends_on or [] result["container"] = self.container_details if self.job_stopped: result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at) @@ -501,7 +508,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): result["attempts"] = self.attempts return result - def _get_container_property(self, p, default): + def _get_container_property(self, p: str, default: Any) -> Any: if p == "environment": job_env = self.container_overrides.get(p, default) jd_env = self.job_definition.container_properties.get(p, default) @@ -526,14 +533,14 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): p, self.job_definition.container_properties.get(p, default) ) - def _get_attempt_duration(self): + def _get_attempt_duration(self) -> Optional[int]: if self.timeout: return self.timeout["attemptDurationSeconds"] if self.job_definition.timeout: return self.job_definition.timeout["attemptDurationSeconds"] return None - def run(self): + def run(self) -> None: """ Run the container. @@ -672,7 +679,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): if self._get_attempt_duration(): attempt_duration = self._get_attempt_duration() max_time = self.job_started_at + datetime.timedelta( - seconds=attempt_duration + seconds=attempt_duration # type: ignore[arg-type] ) while container.status == "running" and not self.stop: @@ -740,8 +747,9 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): self.container_details["logStreamName"] = self.log_stream_name result = container.wait() or {} - self.exit_code = result.get("StatusCode", 0) - job_failed = self.stop or self.exit_code > 0 + exit_code = result.get("StatusCode", 0) + self.exit_code = exit_code + job_failed = self.stop or exit_code > 0 self._mark_stopped(success=not job_failed) except Exception as err: @@ -762,7 +770,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): ) self._mark_stopped(success=False) - def _mark_stopped(self, success=True): + def _mark_stopped(self, success: bool = True) -> None: # Ensure that job_stopped/job_stopped_at-attributes are set first # The describe-method needs them immediately when status is set self.job_stopped = True @@ -770,7 +778,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): self.status = "SUCCEEDED" if success else "FAILED" self._stop_attempt() - def _start_attempt(self): + def _start_attempt(self) -> None: self.latest_attempt = { "container": { "containerInstanceArn": "TBD", @@ -784,21 +792,21 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState): ) self.attempts.append(self.latest_attempt) - def _stop_attempt(self): + def _stop_attempt(self) -> None: if self.latest_attempt: self.latest_attempt["container"]["logStreamName"] = self.log_stream_name self.latest_attempt["stoppedAt"] = datetime2int_milliseconds( self.job_stopped_at ) - def terminate(self, reason): + def terminate(self, reason: str) -> None: if not self.stop: self.stop = True self.job_stopped_reason = reason - def _wait_for_dependencies(self): - dependent_ids = [dependency["jobId"] for dependency in self.depends_on] - successful_dependencies = set() + def _wait_for_dependencies(self) -> bool: + dependent_ids = [dependency["jobId"] for dependency in self.depends_on] # type: ignore[union-attr] + successful_dependencies: Set[str] = set() while len(successful_dependencies) != len(dependent_ids): for dependent_id in dependent_ids: if dependent_id in self.all_jobs: @@ -832,21 +840,21 @@ class BatchBackend(BaseBackend): With this decorator, jobs are simply marked as 'Success' without trying to execute any commands/scripts. """ - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self.tagger = TaggingService() - self._compute_environments = {} - self._job_queues = {} - self._job_definitions = {} - self._jobs = {} + self._compute_environments: Dict[str, ComputeEnvironment] = {} + self._job_queues: Dict[str, JobQueue] = {} + self._job_definitions: Dict[str, JobDefinition] = {} + self._jobs: Dict[str, Job] = {} state_manager.register_default_transition( "batch::job", transition={"progression": "manual", "times": 1} ) @property - def iam_backend(self): + def iam_backend(self) -> IAMBackend: """ :return: IAM Backend :rtype: moto.iam.models.IAMBackend @@ -854,7 +862,7 @@ class BatchBackend(BaseBackend): return iam_backends[self.account_id]["global"] @property - def ec2_backend(self): + def ec2_backend(self) -> EC2Backend: """ :return: EC2 Backend :rtype: moto.ec2.models.EC2Backend @@ -862,7 +870,7 @@ class BatchBackend(BaseBackend): return ec2_backends[self.account_id][self.region_name] @property - def ecs_backend(self): + def ecs_backend(self) -> EC2ContainerServiceBackend: """ :return: ECS Backend :rtype: moto.ecs.models.EC2ContainerServiceBackend @@ -870,14 +878,14 @@ class BatchBackend(BaseBackend): return ecs_backends[self.account_id][self.region_name] @property - def logs_backend(self): + def logs_backend(self) -> LogsBackend: """ :return: ECS Backend :rtype: moto.logs.models.LogsBackend """ return logs_backends[self.account_id][self.region_name] - def reset(self): + def reset(self) -> None: for job in self._jobs.values(): if job.status not in ("FAILED", "SUCCEEDED"): job.stop = True @@ -886,16 +894,18 @@ class BatchBackend(BaseBackend): super().reset() - def get_compute_environment_by_arn(self, arn): + def get_compute_environment_by_arn(self, arn: str) -> Optional[ComputeEnvironment]: return self._compute_environments.get(arn) - def get_compute_environment_by_name(self, name): + def get_compute_environment_by_name( + self, name: str + ) -> Optional[ComputeEnvironment]: for comp_env in self._compute_environments.values(): if comp_env.name == name: return comp_env return None - def get_compute_environment(self, identifier): + def get_compute_environment(self, identifier: str) -> Optional[ComputeEnvironment]: """ Get compute environment by name or ARN :param identifier: Name or ARN @@ -904,21 +914,20 @@ class BatchBackend(BaseBackend): :return: Compute Environment or None :rtype: ComputeEnvironment or None """ - env = self.get_compute_environment_by_arn(identifier) - if env is None: - env = self.get_compute_environment_by_name(identifier) - return env + return self.get_compute_environment_by_arn( + identifier + ) or self.get_compute_environment_by_name(identifier) - def get_job_queue_by_arn(self, arn): + def get_job_queue_by_arn(self, arn: str) -> Optional[JobQueue]: return self._job_queues.get(arn) - def get_job_queue_by_name(self, name): + def get_job_queue_by_name(self, name: str) -> Optional[JobQueue]: for comp_env in self._job_queues.values(): if comp_env.name == name: return comp_env return None - def get_job_queue(self, identifier): + def get_job_queue(self, identifier: str) -> Optional[JobQueue]: """ Get job queue by name or ARN :param identifier: Name or ARN @@ -927,15 +936,14 @@ class BatchBackend(BaseBackend): :return: Job Queue or None :rtype: JobQueue or None """ - env = self.get_job_queue_by_arn(identifier) - if env is None: - env = self.get_job_queue_by_name(identifier) - return env + return self.get_job_queue_by_arn(identifier) or self.get_job_queue_by_name( + identifier + ) - def get_job_definition_by_arn(self, arn): + def get_job_definition_by_arn(self, arn: str) -> Optional[JobDefinition]: return self._job_definitions.get(arn) - def get_job_definition_by_name(self, name): + def get_job_definition_by_name(self, name: str) -> Optional[JobDefinition]: latest_revision = -1 latest_job = None for job_def in self._job_definitions.values(): @@ -944,13 +952,15 @@ class BatchBackend(BaseBackend): latest_revision = job_def.revision return latest_job - def get_job_definition_by_name_revision(self, name, revision): + def get_job_definition_by_name_revision( + self, name: str, revision: str + ) -> Optional[JobDefinition]: for job_def in self._job_definitions.values(): if job_def.name == name and job_def.revision == int(revision): return job_def return None - def get_job_definition(self, identifier): + def get_job_definition(self, identifier: str) -> Optional[JobDefinition]: """ Get job definitions by name or ARN :param identifier: Name or ARN @@ -969,7 +979,7 @@ class BatchBackend(BaseBackend): job_def = self.get_job_definition_by_name(identifier) return job_def - def get_job_definitions(self, identifier): + def get_job_definitions(self, identifier: str) -> List[JobDefinition]: """ Get job definitions by name or ARN :param identifier: Name or ARN @@ -989,21 +999,15 @@ class BatchBackend(BaseBackend): return result - def get_job_by_id(self, identifier): - """ - Get job by id - :param identifier: Job ID - :type identifier: str - - :return: Job - :rtype: Job - """ + def get_job_by_id(self, identifier: str) -> Optional[Job]: try: return self._jobs[identifier] except KeyError: return None - def describe_compute_environments(self, environments=None): + def describe_compute_environments( + self, environments: Optional[List[str]] = None + ) -> List[Dict[str, Any]]: """ Pagination is not yet implemented """ @@ -1017,7 +1021,7 @@ class BatchBackend(BaseBackend): if len(envs) > 0 and arn not in envs and environment.name not in envs: continue - json_part = { + json_part: Dict[str, Any] = { "computeEnvironmentArn": arn, "computeEnvironmentName": environment.name, "ecsClusterArn": environment.ecs_arn, @@ -1035,8 +1039,13 @@ class BatchBackend(BaseBackend): return result def create_compute_environment( - self, compute_environment_name, _type, state, compute_resources, service_role - ): + self, + compute_environment_name: str, + _type: str, + state: str, + compute_resources: Dict[str, Any], + service_role: str, + ) -> Tuple[str, str]: # Validate if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None: raise InvalidParameterValueException( @@ -1127,7 +1136,7 @@ class BatchBackend(BaseBackend): return compute_environment_name, new_comp_env.arn - def _validate_compute_resources(self, cr): + def _validate_compute_resources(self, cr: Dict[str, Any]) -> None: """ Checks contents of sub dictionary for managed clusters @@ -1195,16 +1204,15 @@ class BatchBackend(BaseBackend): ) @staticmethod - def find_min_instances_to_meet_vcpus(instance_types, target): + def find_min_instances_to_meet_vcpus( + instance_types: List[str], target: float + ) -> List[str]: """ Finds the minimum needed instances to meed a vcpu target :param instance_types: Instance types, like ['t2.medium', 't2.small'] - :type instance_types: list of str :param target: VCPU target - :type target: float :return: List of instance types - :rtype: list of str """ # vcpus = [ (vcpus, instance_type), (vcpus, instance_type), ... ] instance_vcpus = [] @@ -1253,7 +1261,7 @@ class BatchBackend(BaseBackend): return instances - def delete_compute_environment(self, compute_environment_name): + def delete_compute_environment(self, compute_environment_name: str) -> None: if compute_environment_name is None: raise InvalidParameterValueException("Missing computeEnvironment parameter") @@ -1273,8 +1281,12 @@ class BatchBackend(BaseBackend): self.ec2_backend.terminate_instances(instance_ids) def update_compute_environment( - self, compute_environment_name, state, compute_resources, service_role - ): + self, + compute_environment_name: str, + state: Optional[str], + compute_resources: Optional[Any], + service_role: Optional[str], + ) -> Tuple[str, str]: # Validate compute_env = self.get_compute_environment(compute_environment_name) if compute_env is None: @@ -1283,13 +1295,13 @@ class BatchBackend(BaseBackend): # Look for IAM role if service_role is not None: try: - role = self.iam_backend.get_role_by_arn(service_role) + self.iam_backend.get_role_by_arn(service_role) except IAMNotFoundException: raise InvalidParameterValueException( "Could not find IAM role {0}".format(service_role) ) - compute_env.service_role = role + compute_env.service_role = service_role if state is not None: if state not in ("ENABLED", "DISABLED"): @@ -1307,8 +1319,13 @@ class BatchBackend(BaseBackend): return compute_env.name, compute_env.arn def create_job_queue( - self, queue_name, priority, state, compute_env_order, tags=None - ): + self, + queue_name: str, + priority: str, + state: str, + compute_env_order: List[Dict[str, str]], + tags: Optional[Dict[str, str]] = None, + ) -> Tuple[str, str]: for variable, var_name in ( (queue_name, "jobQueueName"), (priority, "priority"), @@ -1359,7 +1376,9 @@ class BatchBackend(BaseBackend): return queue_name, queue.arn - def describe_job_queues(self, job_queues=None): + def describe_job_queues( + self, job_queues: Optional[List[str]] = None + ) -> List[Dict[str, Any]]: """ Pagination is not yet implemented """ @@ -1377,7 +1396,13 @@ class BatchBackend(BaseBackend): return result - def update_job_queue(self, queue_name, priority, state, compute_env_order): + def update_job_queue( + self, + queue_name: str, + priority: Optional[str], + state: Optional[str], + compute_env_order: Optional[List[Dict[str, Any]]], + ) -> Tuple[str, str]: if queue_name is None: raise ClientException("jobQueueName must be provided") @@ -1422,7 +1447,7 @@ class BatchBackend(BaseBackend): return queue_name, job_queue.arn - def delete_job_queue(self, queue_name): + def delete_job_queue(self, queue_name: str) -> None: job_queue = self.get_job_queue(queue_name) if job_queue is not None: @@ -1430,16 +1455,16 @@ class BatchBackend(BaseBackend): def register_job_definition( self, - def_name, - parameters, - _type, - tags, - retry_strategy, - container_properties, - timeout, - platform_capabilities, - propagate_tags, - ): + def_name: str, + parameters: Dict[str, Any], + _type: str, + tags: Dict[str, str], + retry_strategy: Dict[str, Any], + container_properties: Dict[str, Any], + timeout: Dict[str, int], + platform_capabilities: List[str], + propagate_tags: bool, + ) -> Tuple[str, str, int]: if def_name is None: raise ClientException("jobDefinitionName must be provided") @@ -1473,7 +1498,7 @@ class BatchBackend(BaseBackend): return def_name, job_def.arn, job_def.revision - def deregister_job_definition(self, def_name): + def deregister_job_definition(self, def_name: str) -> None: job_def = self.get_job_definition_by_arn(def_name) if job_def is None and ":" in def_name: name, revision = def_name.split(":", 1) @@ -1483,8 +1508,11 @@ class BatchBackend(BaseBackend): self._job_definitions[job_def.arn].deregister() def describe_job_definitions( - self, job_def_name=None, job_def_list=None, status=None - ): + self, + job_def_name: Optional[str] = None, + job_def_list: Optional[List[str]] = None, + status: Optional[str] = None, + ) -> List[JobDefinition]: """ Pagination is not yet implemented """ @@ -1496,8 +1524,8 @@ class BatchBackend(BaseBackend): if job_def is not None: jobs.extend(job_def) elif job_def_list is not None: - for job in job_def_list: - job_def = self.get_job_definitions(job) + for jdn in job_def_list: + job_def = self.get_job_definitions(jdn) if job_def is not None: jobs.extend(job_def) else: @@ -1512,13 +1540,13 @@ class BatchBackend(BaseBackend): def submit_job( self, - job_name, - job_def_id, - job_queue, - depends_on=None, - container_overrides=None, - timeout=None, - ): + job_name: str, + job_def_id: str, + job_queue: str, + depends_on: Optional[List[Dict[str, str]]] = None, + container_overrides: Optional[Dict[str, Any]] = None, + timeout: Optional[Dict[str, int]] = None, + ) -> Tuple[str, str]: """ Parameters RetryStrategy and Parameters are not yet implemented. """ @@ -1550,7 +1578,7 @@ class BatchBackend(BaseBackend): return job_name, job.job_id - def describe_jobs(self, jobs): + def describe_jobs(self, jobs: Optional[List[str]]) -> List[Dict[str, Any]]: job_filter = set() if jobs is not None: job_filter = set(jobs) @@ -1564,13 +1592,15 @@ class BatchBackend(BaseBackend): return result - def list_jobs(self, job_queue, job_status=None): + def list_jobs( + self, job_queue_name: str, job_status: Optional[str] = None + ) -> List[Job]: """ Pagination is not yet implemented """ jobs = [] - job_queue = self.get_job_queue(job_queue) + job_queue = self.get_job_queue(job_queue_name) if job_queue is None: raise ClientException("Job queue {0} does not exist".format(job_queue)) @@ -1595,7 +1625,7 @@ class BatchBackend(BaseBackend): return jobs - def cancel_job(self, job_id, reason): + def cancel_job(self, job_id: str, reason: str) -> None: if job_id == "": raise ClientException( "'reason' is a required field (cannot be an empty string)" @@ -1611,7 +1641,7 @@ class BatchBackend(BaseBackend): job.terminate(reason) # No-Op for jobs that have already started - user has to explicitly terminate those - def terminate_job(self, job_id, reason): + def terminate_job(self, job_id: str, reason: str) -> None: if job_id == "": raise ClientException( "'reason' is a required field (cannot be a empty string)" @@ -1625,14 +1655,14 @@ class BatchBackend(BaseBackend): if job is not None: job.terminate(reason) - def tag_resource(self, resource_arn, tags): - tags = self.tagger.convert_dict_to_tags_input(tags or {}) - self.tagger.tag_resource(resource_arn, tags) + def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None: + tag_list = self.tagger.convert_dict_to_tags_input(tags or {}) + self.tagger.tag_resource(resource_arn, tag_list) - def list_tags_for_resource(self, resource_arn): + def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]: return self.tagger.get_tag_dict_for_resource(resource_arn) - def untag_resource(self, resource_arn, tag_keys): + def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: self.tagger.untag_resource_using_names(resource_arn, tag_keys) diff --git a/moto/batch/responses.py b/moto/batch/responses.py index fa86458b4..be4375dd5 100644 --- a/moto/batch/responses.py +++ b/moto/batch/responses.py @@ -1,45 +1,28 @@ from moto.core.responses import BaseResponse -from .models import batch_backends +from .models import batch_backends, BatchBackend from urllib.parse import urlsplit, unquote import json class BatchResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="batch") - def _error(self, code, message): - return json.dumps({"__type": code, "message": message}), dict(status=400) - @property - def batch_backend(self): + def batch_backend(self) -> BatchBackend: """ :return: Batch Backend :rtype: moto.batch.models.BatchBackend """ return batch_backends[self.current_account][self.region] - @property - def json(self): - if self.body is None or self.body == "": - self._json = {} - elif not hasattr(self, "_json"): - self._json = json.loads(self.body) - return self._json - - def _get_param(self, param_name, if_none=None): - val = self.json.get(param_name) - if val is not None: - return val - return if_none - - def _get_action(self): + def _get_action(self) -> str: # Return element after the /v1/* return urlsplit(self.uri).path.lstrip("/").split("/")[1] # CreateComputeEnvironment - def createcomputeenvironment(self): + def createcomputeenvironment(self) -> str: compute_env_name = self._get_param("computeEnvironmentName") compute_resource = self._get_param("computeResources") service_role = self._get_param("serviceRole") @@ -59,7 +42,7 @@ class BatchResponse(BaseResponse): return json.dumps(result) # DescribeComputeEnvironments - def describecomputeenvironments(self): + def describecomputeenvironments(self) -> str: compute_environments = self._get_param("computeEnvironments") envs = self.batch_backend.describe_compute_environments(compute_environments) @@ -68,7 +51,7 @@ class BatchResponse(BaseResponse): return json.dumps(result) # DeleteComputeEnvironment - def deletecomputeenvironment(self): + def deletecomputeenvironment(self) -> str: compute_environment = self._get_param("computeEnvironment") self.batch_backend.delete_compute_environment(compute_environment) @@ -76,7 +59,7 @@ class BatchResponse(BaseResponse): return "" # UpdateComputeEnvironment - def updatecomputeenvironment(self): + def updatecomputeenvironment(self) -> str: compute_env_name = self._get_param("computeEnvironment") compute_resource = self._get_param("computeResources") service_role = self._get_param("serviceRole") @@ -94,7 +77,7 @@ class BatchResponse(BaseResponse): return json.dumps(result) # CreateJobQueue - def createjobqueue(self): + def createjobqueue(self) -> str: compute_env_order = self._get_param("computeEnvironmentOrder") queue_name = self._get_param("jobQueueName") priority = self._get_param("priority") @@ -114,7 +97,7 @@ class BatchResponse(BaseResponse): return json.dumps(result) # DescribeJobQueues - def describejobqueues(self): + def describejobqueues(self) -> str: job_queues = self._get_param("jobQueues") queues = self.batch_backend.describe_job_queues(job_queues) @@ -123,7 +106,7 @@ class BatchResponse(BaseResponse): return json.dumps(result) # UpdateJobQueue - def updatejobqueue(self): + def updatejobqueue(self) -> str: compute_env_order = self._get_param("computeEnvironmentOrder") queue_name = self._get_param("jobQueue") priority = self._get_param("priority") @@ -141,7 +124,7 @@ class BatchResponse(BaseResponse): return json.dumps(result) # DeleteJobQueue - def deletejobqueue(self): + def deletejobqueue(self) -> str: queue_name = self._get_param("jobQueue") self.batch_backend.delete_job_queue(queue_name) @@ -149,7 +132,7 @@ class BatchResponse(BaseResponse): return "" # RegisterJobDefinition - def registerjobdefinition(self): + def registerjobdefinition(self) -> str: container_properties = self._get_param("containerProperties") def_name = self._get_param("jobDefinitionName") parameters = self._get_param("parameters") @@ -180,7 +163,7 @@ class BatchResponse(BaseResponse): return json.dumps(result) # DeregisterJobDefinition - def deregisterjobdefinition(self): + def deregisterjobdefinition(self) -> str: queue_name = self._get_param("jobDefinition") self.batch_backend.deregister_job_definition(queue_name) @@ -188,7 +171,7 @@ class BatchResponse(BaseResponse): return "" # DescribeJobDefinitions - def describejobdefinitions(self): + def describejobdefinitions(self) -> str: job_def_name = self._get_param("jobDefinitionName") job_def_list = self._get_param("jobDefinitions") status = self._get_param("status") @@ -201,7 +184,7 @@ class BatchResponse(BaseResponse): return json.dumps(result) # SubmitJob - def submitjob(self): + def submitjob(self) -> str: container_overrides = self._get_param("containerOverrides") depends_on = self._get_param("dependsOn") job_def = self._get_param("jobDefinition") @@ -223,13 +206,13 @@ class BatchResponse(BaseResponse): return json.dumps(result) # DescribeJobs - def describejobs(self): + def describejobs(self) -> str: jobs = self._get_param("jobs") return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)}) # ListJobs - def listjobs(self): + def listjobs(self) -> str: job_queue = self._get_param("jobQueue") job_status = self._get_param("jobStatus") @@ -239,7 +222,7 @@ class BatchResponse(BaseResponse): return json.dumps(result) # TerminateJob - def terminatejob(self): + def terminatejob(self) -> str: job_id = self._get_param("jobId") reason = self._get_param("reason") @@ -248,22 +231,22 @@ class BatchResponse(BaseResponse): return "" # CancelJob - def canceljob(self): + def canceljob(self) -> str: job_id = self._get_param("jobId") reason = self._get_param("reason") self.batch_backend.cancel_job(job_id, reason) return "" - def tags(self): + def tags(self) -> str: resource_arn = unquote(self.path).split("/v1/tags/")[-1] tags = self._get_param("tags") if self.method == "POST": self.batch_backend.tag_resource(resource_arn, tags) - return "" if self.method == "GET": tags = self.batch_backend.list_tags_for_resource(resource_arn) return json.dumps({"tags": tags}) if self.method == "DELETE": tag_keys = self.querystring.get("tagKeys") - self.batch_backend.untag_resource(resource_arn, tag_keys) + self.batch_backend.untag_resource(resource_arn, tag_keys) # type: ignore[arg-type] + return "" diff --git a/moto/batch/utils.py b/moto/batch/utils.py index 820e3d0d8..e9ca56c65 100644 --- a/moto/batch/utils.py +++ b/moto/batch/utils.py @@ -1,21 +1,26 @@ -def make_arn_for_compute_env(account_id, name, region_name): +from typing import Any, Dict + + +def make_arn_for_compute_env(account_id: str, name: str, region_name: str) -> str: return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format( region_name, account_id, name ) -def make_arn_for_job_queue(account_id, name, region_name): +def make_arn_for_job_queue(account_id: str, name: str, region_name: str) -> str: return "arn:aws:batch:{0}:{1}:job-queue/{2}".format(region_name, account_id, name) -def make_arn_for_task_def(account_id, name, revision, region_name): +def make_arn_for_task_def( + account_id: str, name: str, revision: int, region_name: str +) -> str: return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format( region_name, account_id, name, revision ) -def lowercase_first_key(some_dict): - new_dict = {} +def lowercase_first_key(some_dict: Dict[str, Any]) -> Dict[str, Any]: + new_dict: Dict[str, Any] = {} for key, value in some_dict.items(): new_key = key[0].lower() + key[1:] try: diff --git a/moto/batch_simple/models.py b/moto/batch_simple/models.py index e5afc85aa..1c752a7ad 100644 --- a/moto/batch_simple/models.py +++ b/moto/batch_simple/models.py @@ -1,7 +1,14 @@ -from ..batch.models import batch_backends, BaseBackend, Job, ClientException +from ..batch.models import ( + batch_backends, + BaseBackend, + Job, + ClientException, + BatchBackend, +) from ..core.utils import BackendDict import datetime +from typing import Any, Dict, List, Tuple, Optional class BatchSimpleBackend(BaseBackend): @@ -11,10 +18,10 @@ class BatchSimpleBackend(BaseBackend): """ @property - def backend(self): + def backend(self) -> BatchBackend: return batch_backends[self.account_id][self.region_name] - def __getattribute__(self, name): + def __getattribute__(self, name: str) -> Any: """ Magic part that makes this class behave like a wrapper around the regular batch_backend We intercept calls to `submit_job` and replace this with our own (non-Docker) implementation @@ -32,7 +39,7 @@ class BatchSimpleBackend(BaseBackend): return object.__getattribute__(self, name) if name in ["submit_job"]: - def newfunc(*args, **kwargs): + def newfunc(*args: Any, **kwargs: Any) -> Any: attr = object.__getattribute__(self, name) return attr(*args, **kwargs) @@ -42,13 +49,13 @@ class BatchSimpleBackend(BaseBackend): def submit_job( self, - job_name, - job_def_id, - job_queue, - depends_on=None, - container_overrides=None, - timeout=None, - ): + job_name: str, + job_def_id: str, + job_queue: str, + depends_on: Optional[List[Dict[str, str]]] = None, + container_overrides: Optional[Dict[str, Any]] = None, + timeout: Optional[Dict[str, int]] = None, + ) -> Tuple[str, str]: # Look for job definition job_def = self.get_job_definition(job_def_id) if job_def is None: diff --git a/moto/batch_simple/responses.py b/moto/batch_simple/responses.py index f882b77c3..b4e7b7adb 100644 --- a/moto/batch_simple/responses.py +++ b/moto/batch_simple/responses.py @@ -1,10 +1,10 @@ from ..batch.responses import BatchResponse -from .models import batch_simple_backends +from .models import batch_simple_backends, BatchBackend class BatchSimpleResponse(BatchResponse): @property - def batch_backend(self): + def batch_backend(self) -> BatchBackend: """ :return: Batch Backend :rtype: moto.batch.models.BatchBackend diff --git a/moto/core/base_backend.py b/moto/core/base_backend.py index 4e9fdc283..04f31d028 100644 --- a/moto/core/base_backend.py +++ b/moto/core/base_backend.py @@ -32,7 +32,7 @@ class BaseBackend: for model in models.values(): model.instances = [] - def reset(self): + def reset(self) -> None: region_name = self.region_name account_id = self.account_id self._reset_model_refs() diff --git a/moto/ec2/models/instances.py b/moto/ec2/models/instances.py index c159055cf..d6afd75a5 100644 --- a/moto/ec2/models/instances.py +++ b/moto/ec2/models/instances.py @@ -2,7 +2,7 @@ import copy import warnings from collections import OrderedDict from datetime import datetime -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Optional from moto import settings from moto.core import CloudFormationModel @@ -597,7 +597,7 @@ class InstanceBackend: self, image_id: str, count: int, - user_data: str, + user_data: Optional[str], security_group_names: List[str], **kwargs: Any ) -> Reservation: diff --git a/moto/ec2/models/security_groups.py b/moto/ec2/models/security_groups.py index 352067fd5..75c9c1e49 100644 --- a/moto/ec2/models/security_groups.py +++ b/moto/ec2/models/security_groups.py @@ -2,7 +2,7 @@ import copy import itertools import json from collections import defaultdict - +from typing import Optional from moto.core import CloudFormationModel from moto.core.utils import aws_api_matches from ..exceptions import ( @@ -543,7 +543,7 @@ class SecurityGroupBackend: return self._delete_security_group(None, group.id) raise InvalidSecurityGroupNotFoundError(name) - def get_security_group_from_id(self, group_id): + def get_security_group_from_id(self, group_id: str) -> Optional[SecurityGroup]: # 2 levels of chaining necessary since it's a complex structure all_groups = itertools.chain.from_iterable( [x.copy().values() for x in self.groups.copy().values()] diff --git a/moto/ec2/models/subnets.py b/moto/ec2/models/subnets.py index f25f7db75..82f7b72be 100644 --- a/moto/ec2/models/subnets.py +++ b/moto/ec2/models/subnets.py @@ -229,7 +229,7 @@ class SubnetBackend: # maps availability zone to dict of (subnet_id, subnet) self.subnets = defaultdict(dict) - def get_subnet(self, subnet_id): + def get_subnet(self, subnet_id: str) -> Subnet: for subnets in self.subnets.values(): if subnet_id in subnets: return subnets[subnet_id] diff --git a/moto/ecs/models.py b/moto/ecs/models.py index bd6fcf975..dfd6d921d 100644 --- a/moto/ecs/models.py +++ b/moto/ecs/models.py @@ -1,7 +1,7 @@ import re from copy import copy from datetime import datetime - +from typing import Any import pytz from moto import settings @@ -850,7 +850,9 @@ class EC2ContainerServiceBackend(BaseBackend): else: raise Exception("{0} is not a task_definition".format(task_definition_name)) - def create_cluster(self, cluster_name, tags=None, cluster_settings=None): + def create_cluster( + self, cluster_name: str, tags: Any = None, cluster_settings: Any = None + ) -> Cluster: """ The following parameters are not yet implemented: configuration, capacityProviders, defaultCapacityProviderStrategy """ @@ -926,7 +928,7 @@ class EC2ContainerServiceBackend(BaseBackend): return list_clusters, failures - def delete_cluster(self, cluster_str): + def delete_cluster(self, cluster_str: str) -> Cluster: cluster = self._get_cluster(cluster_str) return self.clusters.pop(cluster.name) diff --git a/moto/iam/models.py b/moto/iam/models.py index 27fd89ec0..2f9d3cbf6 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -10,7 +10,7 @@ from cryptography import x509 from cryptography.hazmat.backends import default_backend from jinja2 import Template -from typing import Mapping +from typing import List, Mapping from urllib import parse from moto.core.exceptions import RESTError from moto.core import DEFAULT_ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel @@ -2188,7 +2188,7 @@ class IAMBackend(BaseBackend): raise IAMNotFoundException("Instance profile {0} not found".format(profile_arn)) - def get_instance_profiles(self): + def get_instance_profiles(self) -> List[InstanceProfile]: return self.instance_profiles.values() def get_instance_profiles_for_role(self, role_name): diff --git a/moto/logs/models.py b/moto/logs/models.py index e5ce29d60..7518d3a99 100644 --- a/moto/logs/models.py +++ b/moto/logs/models.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta - +from typing import Any, Dict, List, Tuple, Optional from moto.core import BaseBackend, BaseModel from moto.core import CloudFormationModel from moto.core.utils import unix_time_millis, BackendDict @@ -627,7 +627,9 @@ class LogsBackend(BaseBackend): ) return self.groups[log_group_name] - def ensure_log_group(self, log_group_name, tags): + def ensure_log_group( + self, log_group_name: str, tags: Optional[Dict[str, str]] + ) -> None: if log_group_name in self.groups: return self.groups[log_group_name] = LogGroup( @@ -653,7 +655,7 @@ class LogsBackend(BaseBackend): return groups - def create_log_stream(self, log_group_name, log_stream_name): + def create_log_stream(self, log_group_name: str, log_stream_name: str) -> LogStream: if log_group_name not in self.groups: raise ResourceNotFoundException() log_group = self.groups[log_group_name] @@ -702,7 +704,12 @@ class LogsBackend(BaseBackend): order_by=order_by, ) - def put_log_events(self, log_group_name, log_stream_name, log_events): + def put_log_events( + self, + log_group_name: str, + log_stream_name: str, + log_events: List[Dict[str, Any]], + ) -> Tuple[str, Dict[str, Any]]: """ The SequenceToken-parameter is not yet implemented """ diff --git a/moto/moto_api/_internal/managed_state_model.py b/moto/moto_api/_internal/managed_state_model.py index fe43f7afe..1cf71ab08 100644 --- a/moto/moto_api/_internal/managed_state_model.py +++ b/moto/moto_api/_internal/managed_state_model.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta from moto.moto_api import state_manager +from typing import List, Tuple class ManagedState: @@ -7,7 +8,7 @@ class ManagedState: Subclass this class to configure state-transitions """ - def __init__(self, model_name, transitions): + def __init__(self, model_name: str, transitions: List[Tuple[str, str]]): # Indicate the possible transitions for this model # Example: [(initializing,queued), (queued, starting), (starting, ready)] self._transitions = transitions @@ -23,7 +24,7 @@ class ManagedState: # Name of this model. This will be used in the API self.model_name = model_name - def advance(self): + def advance(self) -> None: self._tick += 1 @property diff --git a/moto/moto_api/_internal/state_manager.py b/moto/moto_api/_internal/state_manager.py index 5c5361a08..f31a0bbe9 100644 --- a/moto/moto_api/_internal/state_manager.py +++ b/moto/moto_api/_internal/state_manager.py @@ -1,3 +1,6 @@ +from typing import Any, Dict + + DEFAULT_TRANSITION = {"progression": "immediate"} @@ -6,7 +9,9 @@ class StateManager: self._default_transitions = dict() self._transitions = dict() - def register_default_transition(self, model_name, transition): + def register_default_transition( + self, model_name: str, transition: Dict[str, Any] + ) -> None: """ Register the default transition for a specific model. This should only be called by Moto backends - use the `set_transition` method to override this default transition in your own tests. diff --git a/moto/utilities/tagging_service.py b/moto/utilities/tagging_service.py index a5b56aa60..fa5e128dc 100644 --- a/moto/utilities/tagging_service.py +++ b/moto/utilities/tagging_service.py @@ -106,7 +106,7 @@ class TaggingService: result[tag[self.key_name]] = None return result - def validate_tags(self, tags, limit=0): + def validate_tags(self, tags: List[Dict[str, str]], limit: int = 0) -> str: """Returns error message if tags in 'tags' list of dicts are invalid. The validation does not include a check for duplicate keys. diff --git a/setup.cfg b/setup.cfg index e42e449bf..96cf4a622 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/budgets +files= moto/a*,moto/b* show_column_numbers=True show_error_codes = True disable_error_code=abstract diff --git a/tests/test_batch/test_batch_compute_envs.py b/tests/test_batch/test_batch_compute_envs.py index 7ef1a2ff8..aef85092e 100644 --- a/tests/test_batch/test_batch_compute_envs.py +++ b/tests/test_batch/test_batch_compute_envs.py @@ -323,6 +323,42 @@ def test_update_unmanaged_compute_environment_state(): our_envs[0]["state"].should.equal("DISABLED") +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_update_iam_role(): + ec2_client, iam_client, _, _, batch_client = _get_clients() + _, _, _, iam_arn = _setup(ec2_client, iam_client) + iam_arn2 = iam_client.create_role(RoleName="r", AssumeRolePolicyDocument="sp")[ + "Role" + ]["Arn"] + + compute_name = str(uuid4()) + batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, + ) + + batch_client.update_compute_environment( + computeEnvironment=compute_name, serviceRole=iam_arn2 + ) + + all_envs = batch_client.describe_compute_environments()["computeEnvironments"] + our_envs = [e for e in all_envs if e["computeEnvironmentName"] == compute_name] + our_envs.should.have.length_of(1) + our_envs[0]["serviceRole"].should.equal(iam_arn2) + + with pytest.raises(ClientError) as exc: + batch_client.update_compute_environment( + computeEnvironment=compute_name, serviceRole="unknown" + ) + err = exc.value.response["Error"] + err["Code"].should.equal("InvalidParameterValue") + + @pytest.mark.parametrize("compute_env_type", ["FARGATE", "FARGATE_SPOT"]) @mock_ec2 @mock_ecs