diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index b4e062acb..f23117bb6 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -404,7 +404,7 @@ ## batch
-66% implemented +79% implemented - [X] cancel_job - [X] create_compute_environment @@ -421,12 +421,12 @@ - [ ] describe_scheduling_policies - [X] list_jobs - [ ] list_scheduling_policies -- [ ] list_tags_for_resource +- [X] list_tags_for_resource - [X] register_job_definition - [X] submit_job -- [ ] tag_resource +- [X] tag_resource - [X] terminate_job -- [ ] untag_resource +- [X] untag_resource - [X] update_compute_environment - [X] update_job_queue - [ ] update_scheduling_policy diff --git a/docs/docs/services/batch.rst b/docs/docs/services/batch.rst index 037f54a52..26fa1b98b 100644 --- a/docs/docs/services/batch.rst +++ b/docs/docs/services/batch.rst @@ -55,12 +55,12 @@ batch - [ ] describe_scheduling_policies - [X] list_jobs - [ ] list_scheduling_policies -- [ ] list_tags_for_resource +- [X] list_tags_for_resource - [X] register_job_definition - [X] submit_job -- [ ] tag_resource +- [X] tag_resource - [X] terminate_job -- [ ] untag_resource +- [X] untag_resource - [X] update_compute_environment - [X] update_job_queue diff --git a/moto/batch/models.py b/moto/batch/models.py index b194973b0..60d8ac2d3 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -14,6 +14,7 @@ from moto.iam import iam_backends from moto.ec2 import ec2_backends from moto.ecs import ecs_backends from moto.logs import logs_backends +from moto.utilities.tagging_service import TaggingService from .exceptions import InvalidParameterValueException, ClientException, ValidationError from .utils import ( @@ -24,12 +25,12 @@ from .utils import ( ) from moto.ec2.exceptions import InvalidSubnetIdError from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES +from moto.ec2.models import INSTANCE_FAMILIES as EC2_INSTANCE_FAMILIES from moto.iam.exceptions import IAMNotFoundException from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID from moto.core.utils import unix_time_millis, BackendDict from moto.utilities.docker_utilities import DockerModel from moto import settings -from ..utilities.tagging_service import TaggingService logger = logging.getLogger(__name__) COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile( @@ -114,7 +115,15 @@ class ComputeEnvironment(CloudFormationModel): class JobQueue(CloudFormationModel): def __init__( - self, name, priority, state, environments, env_order_json, region_name + self, + name, + priority, + state, + environments, + env_order_json, + region_name, + backend, + tags=None, ): """ :param name: Job queue name @@ -137,6 +146,10 @@ class JobQueue(CloudFormationModel): self.env_order_json = env_order_json self.arn = make_arn_for_job_queue(DEFAULT_ACCOUNT_ID, name, region_name) self.status = "VALID" + self.backend = backend + + if tags: + backend.tag_resource(self.arn, tags) self.jobs = [] @@ -148,6 +161,7 @@ class JobQueue(CloudFormationModel): "priority": self.priority, "state": self.state, "status": self.status, + "tags": self.backend.list_tags_for_resource(self.arn), } return result @@ -202,29 +216,39 @@ class JobDefinition(CloudFormationModel): revision=0, retry_strategy=0, timeout=None, + backend=None, + platform_capabilities=None, + propagate_tags=None, ): self.name = name - self.retries = retry_strategy + self.retry_strategy = retry_strategy self.type = _type self.revision = revision self._region = region_name self.container_properties = container_properties self.arn = None self.status = "ACTIVE" - self.tagger = TaggingService() self.parameters = parameters or {} self.timeout = timeout + self.backend = backend + self.platform_capabilities = platform_capabilities + self.propagate_tags = propagate_tags + + if "resourceRequirements" not in self.container_properties: + self.container_properties["resourceRequirements"] = [] + if "secrets" not in self.container_properties: + self.container_properties["secrets"] = [] self._validate() self._update_arn() tags = self._format_tags(tags or {}) # Validate the tags before proceeding. - errmsg = self.tagger.validate_tags(tags) + errmsg = self.backend.tagger.validate_tags(tags) if errmsg: raise ValidationError(errmsg) - self.tagger.tag_resource(self.arn, tags) + self.backend.tagger.tag_resource(self.arn, tags) def _format_tags(self, tags): return [{"Key": k, "Value": v} for k, v in tags.items()] @@ -298,20 +322,24 @@ class JobDefinition(CloudFormationModel): if vcpus <= 0: raise ClientException("container vcpus limit must be greater than 0") + def deregister(self): + self.status = "INACTIVE" + def update( self, parameters, _type, container_properties, retry_strategy, tags, timeout ): - if parameters is None: - parameters = self.parameters + if self.status != "INACTIVE": + if parameters is None: + parameters = self.parameters - if _type is None: - _type = self.type + if _type is None: + _type = self.type - if container_properties is None: - container_properties = self.container_properties + if container_properties is None: + container_properties = self.container_properties - if retry_strategy is None: - retry_strategy = self.retries + if retry_strategy is None: + retry_strategy = self.retry_strategy return JobDefinition( self.name, @@ -323,6 +351,9 @@ class JobDefinition(CloudFormationModel): retry_strategy=retry_strategy, tags=tags, timeout=timeout, + backend=self.backend, + platform_capabilities=self.platform_capabilities, + propagate_tags=self.propagate_tags, ) def describe(self): @@ -333,12 +364,13 @@ class JobDefinition(CloudFormationModel): "revision": self.revision, "status": self.status, "type": self.type, - "tags": self.tagger.get_tag_dict_for_resource(self.arn), + "tags": self.backend.tagger.get_tag_dict_for_resource(self.arn), + "platformCapabilities": self.platform_capabilities, + "retryStrategy": self.retry_strategy, + "propagateTags": self.propagate_tags, } if self.container_properties is not None: result["containerProperties"] = self.container_properties - if self.retries is not None and self.retries > 0: - result["retryStrategy"] = {"attempts": self.retries} if self.timeout: result["timeout"] = self.timeout @@ -371,6 +403,8 @@ class JobDefinition(CloudFormationModel): retry_strategy=lowercase_first_key(properties["RetryStrategy"]), container_properties=lowercase_first_key(properties["ContainerProperties"]), timeout=lowercase_first_key(properties.get("timeout", {})), + platform_capabilities=None, + propagate_tags=None, ) arn = res[1] @@ -427,6 +461,9 @@ class Job(threading.Thread, BaseModel, DockerModel): self._log_backend = log_backend self.log_stream_name = None + self.attempts = [] + self.latest_attempt = None + def describe_short(self): result = { "jobId": self.job_id, @@ -469,6 +506,7 @@ class Job(threading.Thread, BaseModel, DockerModel): result["container"]["logStreamName"] = self.log_stream_name if self.timeout: result["timeout"] = self.timeout + result["attempts"] = self.attempts return result def _get_container_property(self, p, default): @@ -556,6 +594,7 @@ class Job(threading.Thread, BaseModel, DockerModel): # TODO setup ecs container instance self.job_started_at = datetime.datetime.now() + self._start_attempt() # add host.docker.internal host on linux to emulate Mac + Windows behavior # for communication with other mock AWS services running on localhost @@ -695,6 +734,27 @@ class Job(threading.Thread, BaseModel, DockerModel): self.job_stopped = True self.job_stopped_at = datetime.datetime.now() self.job_state = "SUCCEEDED" if success else "FAILED" + self._stop_attempt() + + def _start_attempt(self): + self.latest_attempt = { + "container": { + "containerInstanceArn": "TBD", + "logStreamName": self.log_stream_name, + "networkInterfaces": [], + "taskArn": self.job_definition.arn, + } + } + self.latest_attempt["startedAt"] = datetime2int_milliseconds( + self.job_started_at + ) + self.attempts.append(self.latest_attempt) + + def _stop_attempt(self): + self.latest_attempt["container"]["logStreamName"] = self.log_stream_name + self.latest_attempt["stoppedAt"] = datetime2int_milliseconds( + self.job_stopped_at + ) def terminate(self, reason): if not self.stop: @@ -732,6 +792,7 @@ class BatchBackend(BaseBackend): def __init__(self, region_name=None): super().__init__() self.region_name = region_name + self.tagger = TaggingService() self._compute_environments = {} self._job_queues = {} @@ -1054,7 +1115,10 @@ class BatchBackend(BaseBackend): for instance_type in cr["instanceTypes"]: if instance_type == "optimal": pass # Optimal should pick from latest of current gen - elif instance_type not in EC2_INSTANCE_TYPES: + elif ( + instance_type not in EC2_INSTANCE_TYPES + and instance_type not in EC2_INSTANCE_FAMILIES + ): raise InvalidParameterValueException( "Instance type {0} does not exist".format(instance_type) ) @@ -1104,6 +1168,12 @@ class BatchBackend(BaseBackend): if instance_type == "optimal": instance_type = "m4.4xlarge" + if "." not in instance_type: + # instance_type can be a family of instance types (c2, t3, etc) + # We'll just use the first instance_type in this family + instance_type = [ + i for i in EC2_INSTANCE_TYPES.keys() if i.startswith(instance_type) + ][0] instance_vcpus.append( ( EC2_INSTANCE_TYPES[instance_type]["VCpuInfo"]["DefaultVCpus"], @@ -1190,7 +1260,9 @@ class BatchBackend(BaseBackend): return compute_env.name, compute_env.arn - def create_job_queue(self, queue_name, priority, state, compute_env_order): + def create_job_queue( + self, queue_name, priority, state, compute_env_order, tags=None + ): """ Create a job queue @@ -1249,6 +1321,8 @@ class BatchBackend(BaseBackend): env_objects, compute_env_order, self.region_name, + backend=self, + tags=tags, ) self._job_queues[queue.arn] = queue @@ -1343,16 +1417,17 @@ class BatchBackend(BaseBackend): retry_strategy, container_properties, timeout, + platform_capabilities, + propagate_tags, ): if def_name is None: raise ClientException("jobDefinitionName must be provided") job_def = self.get_job_definition_by_name(def_name) - if retry_strategy is not None: - try: - retry_strategy = retry_strategy["attempts"] - except Exception: - raise ClientException("retryStrategy is malformed") + if retry_strategy is not None and "evaluateOnExit" in retry_strategy: + for strat in retry_strategy["evaluateOnExit"]: + if "action" in strat: + strat["action"] = strat["action"].lower() if not tags: tags = {} if job_def is None: @@ -1365,6 +1440,9 @@ class BatchBackend(BaseBackend): region_name=self.region_name, retry_strategy=retry_strategy, timeout=timeout, + backend=self, + platform_capabilities=platform_capabilities, + propagate_tags=propagate_tags, ) else: # Make new jobdef @@ -1383,7 +1461,7 @@ class BatchBackend(BaseBackend): job_def = self.get_job_definition_by_name_revision(name, revision) if job_def is not None: - del self._job_definitions[job_def.arn] + self._job_definitions[job_def.arn].deregister() def describe_job_definitions( self, @@ -1516,5 +1594,15 @@ class BatchBackend(BaseBackend): job.terminate(reason) + def tag_resource(self, resource_arn, tags): + tags = self.tagger.convert_dict_to_tags_input(tags or {}) + self.tagger.tag_resource(resource_arn, tags) + + def list_tags_for_resource(self, resource_arn): + return self.tagger.get_tag_dict_for_resource(resource_arn) + + def untag_resource(self, resource_arn, tag_keys): + self.tagger.untag_resource_using_names(resource_arn, tag_keys) + batch_backends = BackendDict(BatchBackend, "batch") diff --git a/moto/batch/responses.py b/moto/batch/responses.py index 1f86598d7..92b36b3d7 100644 --- a/moto/batch/responses.py +++ b/moto/batch/responses.py @@ -1,6 +1,6 @@ from moto.core.responses import BaseResponse from .models import batch_backends -from urllib.parse import urlsplit +from urllib.parse import urlsplit, unquote from .exceptions import AWSError @@ -114,6 +114,7 @@ class BatchResponse(BaseResponse): queue_name = self._get_param("jobQueueName") priority = self._get_param("priority") state = self._get_param("state") + tags = self._get_param("tags") try: name, arn = self.batch_backend.create_job_queue( @@ -121,6 +122,7 @@ class BatchResponse(BaseResponse): priority=priority, state=state, compute_env_order=compute_env_order, + tags=tags, ) except AWSError as err: return err.response() @@ -180,6 +182,8 @@ class BatchResponse(BaseResponse): retry_strategy = self._get_param("retryStrategy") _type = self._get_param("type") timeout = self._get_param("timeout") + platform_capabilities = self._get_param("platformCapabilities") + propagate_tags = self._get_param("propagateTags") try: name, arn, revision = self.batch_backend.register_job_definition( def_name=def_name, @@ -189,6 +193,8 @@ class BatchResponse(BaseResponse): retry_strategy=retry_strategy, container_properties=container_properties, timeout=timeout, + platform_capabilities=platform_capabilities, + propagate_tags=propagate_tags, ) except AWSError as err: return err.response() @@ -298,3 +304,16 @@ class BatchResponse(BaseResponse): self.batch_backend.cancel_job(job_id, reason) return "" + + def tags(self): + resource_arn = unquote(self.path).split("/v1/tags/")[-1] + tags = self._get_param("tags") + if self.method == "POST": + self.batch_backend.tag_resource(resource_arn, tags) + return "" + if self.method == "GET": + tags = self.batch_backend.list_tags_for_resource(resource_arn) + return json.dumps({"tags": tags}) + if self.method == "DELETE": + tag_keys = self.querystring.get("tagKeys") + self.batch_backend.untag_resource(resource_arn, tag_keys) diff --git a/moto/batch/urls.py b/moto/batch/urls.py index fee031e5a..41f44ce7e 100644 --- a/moto/batch/urls.py +++ b/moto/batch/urls.py @@ -19,4 +19,6 @@ url_paths = { "{0}/v1/listjobs": BatchResponse.dispatch, "{0}/v1/terminatejob": BatchResponse.dispatch, "{0}/v1/canceljob": BatchResponse.dispatch, + "{0}/v1/tags/(?P[^/]+)/(?P[^/]+)/?$": BatchResponse.dispatch, + "{0}/v1/tags/(?P[^/]+)/?$": BatchResponse.dispatch, } diff --git a/moto/ec2/models.py b/moto/ec2/models.py index eac8f95f5..e91b0f6ac 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -192,6 +192,7 @@ from .utils import ( ) INSTANCE_TYPES = load_resource(__name__, "resources/instance_types.json") +INSTANCE_FAMILIES = list(set([i.split(".")[0] for i in INSTANCE_TYPES.keys()])) root = pathlib.Path(__file__).parent offerings_path = "resources/instance_type_offerings" diff --git a/tests/terraform-tests.success.txt b/tests/terraform-tests.success.txt index fdb32be77..6c0b11bb8 100644 --- a/tests/terraform-tests.success.txt +++ b/tests/terraform-tests.success.txt @@ -9,6 +9,8 @@ TestAccAWSAPIGatewayV2VpcLink TestAccAWSAppsyncApiKey TestAccAWSAppsyncGraphqlApi TestAccAWSAvailabilityZones +TestAccAWSBatchJobDefinition +TestAccAWSBatchJobQueue TestAccAWSBillingServiceAccount TestAccAWSCallerIdentity TestAccAWSCloudTrail diff --git a/tests/test_batch/test_batch_compute_envs.py b/tests/test_batch/test_batch_compute_envs.py index d3a852626..7ef1a2ff8 100644 --- a/tests/test_batch/test_batch_compute_envs.py +++ b/tests/test_batch/test_batch_compute_envs.py @@ -1,6 +1,7 @@ from . import _get_clients, _setup import pytest import sure # noqa # pylint: disable=unused-import +from botocore.exceptions import ClientError from moto import mock_batch, mock_iam, mock_ec2, mock_ecs, settings from uuid import uuid4 @@ -55,6 +56,93 @@ def test_create_managed_compute_environment(): all_clusters.should.contain(our_env["ecsClusterArn"]) +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_create_managed_compute_environment_with_instance_family(): + """ + The InstanceType parameter can have multiple values: + instance_type t2.small + instance_family t2 <-- What we're testing here + 'optimal' + unknown value + """ + ec2_client, iam_client, _, _, batch_client = _get_clients() + _, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) + + compute_name = str(uuid4()) + batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="MANAGED", + state="ENABLED", + computeResources={ + "type": "EC2", + "minvCpus": 5, + "maxvCpus": 10, + "desiredvCpus": 5, + "instanceTypes": ["t2"], + "imageId": "some_image_id", + "subnets": [subnet_id], + "securityGroupIds": [sg_id], + "ec2KeyPair": "string", + "instanceRole": iam_arn.replace("role", "instance-profile"), + "tags": {"string": "string"}, + "bidPercentage": 123, + "spotIamFleetRole": "string", + }, + serviceRole=iam_arn, + ) + + our_env = batch_client.describe_compute_environments( + computeEnvironments=[compute_name] + )["computeEnvironments"][0] + our_env["computeResources"]["instanceTypes"].should.equal(["t2"]) + + +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_create_managed_compute_environment_with_unknown_instance_type(): + """ + The InstanceType parameter can have multiple values: + instance_type t2.small + instance_family t2 + 'optimal' + unknown value <-- What we're testing here + """ + ec2_client, iam_client, _, _, batch_client = _get_clients() + _, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) + + compute_name = str(uuid4()) + with pytest.raises(ClientError) as exc: + batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="MANAGED", + state="ENABLED", + computeResources={ + "type": "EC2", + "minvCpus": 5, + "maxvCpus": 10, + "desiredvCpus": 5, + "instanceTypes": ["unknown"], + "imageId": "some_image_id", + "subnets": [subnet_id], + "securityGroupIds": [sg_id], + "ec2KeyPair": "string", + "instanceRole": iam_arn.replace("role", "instance-profile"), + "tags": {"string": "string"}, + "bidPercentage": 123, + "spotIamFleetRole": "string", + }, + serviceRole=iam_arn, + ) + err = exc.value.response["Error"] + err["Code"].should.equal("InvalidParameterValue") + err["Message"].should.equal("Instance type unknown does not exist") + + @mock_ec2 @mock_ecs @mock_iam diff --git a/tests/test_batch/test_batch_jobs.py b/tests/test_batch/test_batch_jobs.py index 7dd976e71..3272cbe39 100644 --- a/tests/test_batch/test_batch_jobs.py +++ b/tests/test_batch/test_batch_jobs.py @@ -127,15 +127,28 @@ def test_submit_job(): # Test that describe_jobs() returns timestamps in milliseconds # github.com/spulec/moto/issues/4364 - resp = batch_client.describe_jobs(jobs=[job_id]) - created_at = resp["jobs"][0]["createdAt"] - started_at = resp["jobs"][0]["startedAt"] - stopped_at = resp["jobs"][0]["stoppedAt"] + job = batch_client.describe_jobs(jobs=[job_id])["jobs"][0] + created_at = job["createdAt"] + started_at = job["startedAt"] + stopped_at = job["stoppedAt"] created_at.should.be.greater_than(start_time_milliseconds) started_at.should.be.greater_than(start_time_milliseconds) stopped_at.should.be.greater_than(start_time_milliseconds) + # Verify we track attempts + job.should.have.key("attempts").length_of(1) + attempt = job["attempts"][0] + attempt.should.have.key("container") + attempt["container"].should.have.key("containerInstanceArn") + attempt["container"].should.have.key("logStreamName").equals( + job["container"]["logStreamName"] + ) + attempt["container"].should.have.key("networkInterfaces") + attempt["container"].should.have.key("taskArn") + attempt.should.have.key("startedAt").equals(started_at) + attempt.should.have.key("stoppedAt").equals(stopped_at) + @mock_logs @mock_ec2 diff --git a/tests/test_batch/test_batch_tags_job_definition.py b/tests/test_batch/test_batch_tags_job_definition.py new file mode 100644 index 000000000..b82d58c67 --- /dev/null +++ b/tests/test_batch/test_batch_tags_job_definition.py @@ -0,0 +1,67 @@ +from . import _get_clients + +import sure # noqa # pylint: disable=unused-import +from moto import mock_batch +from uuid import uuid4 + +container_properties = { + "image": "busybox", + "command": ["sleep", "1"], + "memory": 128, + "vcpus": 1, +} + + +@mock_batch +def test_list_tags_with_job_definition(): + _, _, _, _, batch_client = _get_clients() + + definition_name = str(uuid4())[0:6] + + job_def_arn = batch_client.register_job_definition( + jobDefinitionName=definition_name, + type="container", + containerProperties=container_properties, + tags={"foo": "123", "bar": "456"}, + )["jobDefinitionArn"] + + my_queue = batch_client.list_tags_for_resource(resourceArn=job_def_arn) + my_queue.should.have.key("tags").equals({"foo": "123", "bar": "456"}) + + +@mock_batch +def test_tag_job_definition(): + _, _, _, _, batch_client = _get_clients() + + definition_name = str(uuid4())[0:6] + + job_def_arn = batch_client.register_job_definition( + jobDefinitionName=definition_name, + type="container", + containerProperties=container_properties, + )["jobDefinitionArn"] + + batch_client.tag_resource(resourceArn=job_def_arn, tags={"k1": "v1", "k2": "v2"}) + + my_queue = batch_client.list_tags_for_resource(resourceArn=job_def_arn) + my_queue.should.have.key("tags").equals({"k1": "v1", "k2": "v2"}) + + +@mock_batch +def test_untag_job_queue(): + _, _, _, _, batch_client = _get_clients() + + definition_name = str(uuid4())[0:6] + + job_def_arn = batch_client.register_job_definition( + jobDefinitionName=definition_name, + type="container", + containerProperties=container_properties, + tags={"k1": "v1", "k2": "v2"}, + )["jobDefinitionArn"] + + batch_client.tag_resource(resourceArn=job_def_arn, tags={"k3": "v3"}) + batch_client.untag_resource(resourceArn=job_def_arn, tagKeys=["k2"]) + + my_queue = batch_client.list_tags_for_resource(resourceArn=job_def_arn) + my_queue.should.have.key("tags").equals({"k1": "v1", "k3": "v3"}) diff --git a/tests/test_batch/test_batch_tags_job_queue.py b/tests/test_batch/test_batch_tags_job_queue.py new file mode 100644 index 000000000..d720ec6fd --- /dev/null +++ b/tests/test_batch/test_batch_tags_job_queue.py @@ -0,0 +1,137 @@ +from . import _get_clients, _setup + +import sure # noqa # pylint: disable=unused-import +from moto import mock_batch, mock_iam, mock_ec2, mock_ecs +from uuid import uuid4 + + +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_create_job_queue_with_tags(): + ec2_client, iam_client, _, _, batch_client = _get_clients() + _, _, _, iam_arn = _setup(ec2_client, iam_client) + + compute_name = str(uuid4()) + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, + ) + arn = resp["computeEnvironmentArn"] + + jq_name = str(uuid4())[0:6] + resp = batch_client.create_job_queue( + jobQueueName=jq_name, + state="ENABLED", + priority=123, + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], + tags={"k1": "v1", "k2": "v2"}, + ) + resp.should.contain("jobQueueArn") + resp.should.contain("jobQueueName") + queue_arn = resp["jobQueueArn"] + + my_queue = batch_client.describe_job_queues(jobQueues=[queue_arn])["jobQueues"][0] + my_queue.should.have.key("tags").equals({"k1": "v1", "k2": "v2"}) + + +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_list_tags(): + ec2_client, iam_client, _, _, batch_client = _get_clients() + _, _, _, iam_arn = _setup(ec2_client, iam_client) + + compute_name = str(uuid4()) + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, + ) + arn = resp["computeEnvironmentArn"] + + jq_name = str(uuid4())[0:6] + resp = batch_client.create_job_queue( + jobQueueName=jq_name, + state="ENABLED", + priority=123, + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], + tags={"k1": "v1", "k2": "v2"}, + ) + resp.should.contain("jobQueueArn") + resp.should.contain("jobQueueName") + queue_arn = resp["jobQueueArn"] + + my_queue = batch_client.list_tags_for_resource(resourceArn=queue_arn) + my_queue.should.have.key("tags").equals({"k1": "v1", "k2": "v2"}) + + +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_tag_job_queue(): + ec2_client, iam_client, _, _, batch_client = _get_clients() + _, _, _, iam_arn = _setup(ec2_client, iam_client) + + compute_name = str(uuid4()) + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, + ) + arn = resp["computeEnvironmentArn"] + + jq_name = str(uuid4())[0:6] + resp = batch_client.create_job_queue( + jobQueueName=jq_name, + state="ENABLED", + priority=123, + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], + ) + queue_arn = resp["jobQueueArn"] + + batch_client.tag_resource(resourceArn=queue_arn, tags={"k1": "v1", "k2": "v2"}) + + my_queue = batch_client.list_tags_for_resource(resourceArn=queue_arn) + my_queue.should.have.key("tags").equals({"k1": "v1", "k2": "v2"}) + + +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_untag_job_queue(): + ec2_client, iam_client, _, _, batch_client = _get_clients() + _, _, _, iam_arn = _setup(ec2_client, iam_client) + + compute_name = str(uuid4()) + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, + ) + arn = resp["computeEnvironmentArn"] + + jq_name = str(uuid4())[0:6] + resp = batch_client.create_job_queue( + jobQueueName=jq_name, + state="ENABLED", + priority=123, + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], + tags={"k1": "v1", "k2": "v2"}, + ) + queue_arn = resp["jobQueueArn"] + + batch_client.tag_resource(resourceArn=queue_arn, tags={"k3": "v3"}) + batch_client.untag_resource(resourceArn=queue_arn, tagKeys=["k2"]) + + my_queue = batch_client.list_tags_for_resource(resourceArn=queue_arn) + my_queue.should.have.key("tags").equals({"k1": "v1", "k3": "v3"}) diff --git a/tests/test_batch/test_batch_task_definition.py b/tests/test_batch/test_batch_task_definition.py index 929054840..2225227bd 100644 --- a/tests/test_batch/test_batch_task_definition.py +++ b/tests/test_batch/test_batch_task_definition.py @@ -1,19 +1,15 @@ -from . import _get_clients, _setup +from . import _get_clients import random import pytest import sure # noqa # pylint: disable=unused-import -from moto import mock_batch, mock_iam, mock_ec2, mock_ecs +from moto import mock_batch from uuid import uuid4 -@mock_ec2 -@mock_ecs -@mock_iam @mock_batch @pytest.mark.parametrize("use_resource_reqs", [True, False]) def test_register_task_definition(use_resource_reqs): - ec2_client, iam_client, _, _, batch_client = _get_clients() - _setup(ec2_client, iam_client) + _, _, _, _, batch_client = _get_clients() resp = register_job_def(batch_client, use_resource_reqs=use_resource_reqs) @@ -26,34 +22,86 @@ def test_register_task_definition(use_resource_reqs): ) -@mock_ec2 -@mock_ecs -@mock_iam @mock_batch -def test_register_task_definition_with_tags(): - ec2_client, iam_client, _, _, batch_client = _get_clients() - _setup(ec2_client, iam_client) +@pytest.mark.parametrize("propagate_tags", [None, True, False]) +def test_register_task_definition_with_tags(propagate_tags): + _, _, _, _, batch_client = _get_clients() - resp = register_job_def_with_tags(batch_client) + job_def_name = str(uuid4())[0:8] + register_job_def_with_tags(batch_client, job_def_name, propagate_tags) - resp.should.contain("jobDefinitionArn") - resp.should.contain("jobDefinitionName") - resp.should.contain("revision") + resp = batch_client.describe_job_definitions(jobDefinitionName=job_def_name) + job_def = resp["jobDefinitions"][0] + if propagate_tags is None: + job_def.shouldnt.have.key("propagateTags") + else: + job_def.should.have.key("propagateTags").equals(propagate_tags) - assert resp["jobDefinitionArn"].endswith( - "{0}:{1}".format(resp["jobDefinitionName"], resp["revision"]) + +@mock_batch +@pytest.mark.parametrize("platform_capability", ["EC2", "FARGATE"]) +def test_register_task_definition_with_platform_capability(platform_capability): + _, _, _, _, batch_client = _get_clients() + + def_name = str(uuid4())[0:6] + batch_client.register_job_definition( + jobDefinitionName=def_name, + type="container", + containerProperties={ + "image": "busybox", + "vcpus": 1, + "memory": 4, + "command": ["exit", "0"], + }, + platformCapabilities=[platform_capability], + ) + + resp = batch_client.describe_job_definitions(jobDefinitionName=def_name) + resp["jobDefinitions"][0].should.have.key("platformCapabilities").equals( + [platform_capability] + ) + + +@mock_batch +def test_register_task_definition_with_retry_strategies(): + _, _, _, _, batch_client = _get_clients() + + def_name = str(uuid4())[0:6] + batch_client.register_job_definition( + jobDefinitionName=def_name, + type="container", + containerProperties={ + "image": "busybox", + "vcpus": 1, + "memory": 4, + "command": ["exit", "0"], + }, + retryStrategy={ + "attempts": 4, + "evaluateOnExit": [ + {"onStatusReason": "osr", "action": "RETRY"}, + {"onStatusReason": "osr2", "action": "Exit"}, + ], + }, + ) + + resp = batch_client.describe_job_definitions(jobDefinitionName=def_name) + resp["jobDefinitions"][0].should.have.key("retryStrategy").equals( + { + "attempts": 4, + "evaluateOnExit": [ + {"onStatusReason": "osr", "action": "retry"}, + {"onStatusReason": "osr2", "action": "exit"}, + ], + } ) -@mock_ec2 -@mock_ecs -@mock_iam @mock_batch @pytest.mark.parametrize("use_resource_reqs", [True, False]) def test_reregister_task_definition(use_resource_reqs): # Reregistering task with the same name bumps the revision number - ec2_client, iam_client, _, _, batch_client = _get_clients() - _setup(ec2_client, iam_client) + _, _, _, _, batch_client = _get_clients() job_def_name = str(uuid4())[0:6] resp1 = register_job_def( @@ -94,14 +142,63 @@ def test_reregister_task_definition(use_resource_reqs): resp4["jobDefinitionArn"].should_not.equal(resp3["jobDefinitionArn"]) -@mock_ec2 -@mock_ecs -@mock_iam +@mock_batch +def test_reregister_task_definition_should_not_reuse_parameters_from_inactive_definition(): + # Reregistering task with the same name bumps the revision number + _, _, _, _, batch_client = _get_clients() + + job_def_name = str(uuid4())[0:6] + # Register job definition with parameters + resp = batch_client.register_job_definition( + jobDefinitionName=job_def_name, + type="container", + containerProperties={ + "image": "busybox", + "vcpus": 1, + "memory": 48, + "command": ["sleep", "0"], + }, + parameters={"param1": "val1"}, + ) + job_def_arn = resp["jobDefinitionArn"] + + definitions = batch_client.describe_job_definitions(jobDefinitionName=job_def_name)[ + "jobDefinitions" + ] + definitions.should.have.length_of(1) + + definitions[0].should.have.key("parameters").equals({"param1": "val1"}) + + # Deactivate the definition + batch_client.deregister_job_definition(jobDefinition=job_def_arn) + + # Second job definition does not provide any parameters + batch_client.register_job_definition( + jobDefinitionName=job_def_name, + type="container", + containerProperties={ + "image": "busybox", + "vcpus": 1, + "memory": 96, + "command": ["sleep", "0"], + }, + ) + + definitions = batch_client.describe_job_definitions(jobDefinitionName=job_def_name)[ + "jobDefinitions" + ] + definitions.should.have.length_of(2) + + # Only the inactive definition should have the parameters + actual = [(d["revision"], d["status"], d.get("parameters")) for d in definitions] + actual.should.contain((1, "INACTIVE", {"param1": "val1"})) + actual.should.contain((2, "ACTIVE", {})) + + @mock_batch @pytest.mark.parametrize("use_resource_reqs", [True, False]) def test_delete_task_definition(use_resource_reqs): - ec2_client, iam_client, _, _, batch_client = _get_clients() - _setup(ec2_client, iam_client) + _, _, _, _, batch_client = _get_clients() resp = register_job_def( batch_client, definition_name=str(uuid4()), use_resource_reqs=use_resource_reqs @@ -111,17 +208,21 @@ def test_delete_task_definition(use_resource_reqs): batch_client.deregister_job_definition(jobDefinition=resp["jobDefinitionArn"]) all_defs = batch_client.describe_job_definitions()["jobDefinitions"] - [jobdef["jobDefinitionName"] for jobdef in all_defs].shouldnt.contain(name) + [jobdef["jobDefinitionName"] for jobdef in all_defs].should.contain(name) + + definitions = batch_client.describe_job_definitions(jobDefinitionName=name)[ + "jobDefinitions" + ] + definitions.should.have.length_of(1) + + definitions[0].should.have.key("revision").equals(1) + definitions[0].should.have.key("status").equals("INACTIVE") -@mock_ec2 -@mock_ecs -@mock_iam @mock_batch @pytest.mark.parametrize("use_resource_reqs", [True, False]) def test_delete_task_definition_by_name(use_resource_reqs): - ec2_client, iam_client, _, _, batch_client = _get_clients() - _setup(ec2_client, iam_client) + _, _, _, _, batch_client = _get_clients() resp = register_job_def( batch_client, definition_name=str(uuid4()), use_resource_reqs=use_resource_reqs @@ -131,17 +232,31 @@ def test_delete_task_definition_by_name(use_resource_reqs): batch_client.deregister_job_definition(jobDefinition=f"{name}:{resp['revision']}") all_defs = batch_client.describe_job_definitions()["jobDefinitions"] - [jobdef["jobDefinitionName"] for jobdef in all_defs].shouldnt.contain(name) + # We should still see our job definition as INACTIVE, as it is kept for 180 days + [jobdef["jobDefinitionName"] for jobdef in all_defs].should.contain(name) + + # Registering the job definition again should up the revision number + register_job_def( + batch_client, definition_name=name, use_resource_reqs=use_resource_reqs + ) + + definitions = batch_client.describe_job_definitions(jobDefinitionName=name)[ + "jobDefinitions" + ] + definitions.should.have.length_of(2) + + revision_status = [ + {"revision": d["revision"], "status": d["status"]} for d in definitions + ] + + revision_status.should.contain({"revision": 1, "status": "INACTIVE"}) + revision_status.should.contain({"revision": 2, "status": "ACTIVE"}) -@mock_ec2 -@mock_ecs -@mock_iam @mock_batch @pytest.mark.parametrize("use_resource_reqs", [True, False]) def test_describe_task_definition(use_resource_reqs): - ec2_client, iam_client, _, _, batch_client = _get_clients() - _setup(ec2_client, iam_client) + _, _, _, _, batch_client = _get_clients() sleep_def_name = f"sleep10_{str(uuid4())[0:6]}" other_name = str(uuid4())[0:6] @@ -183,6 +298,8 @@ def test_describe_task_definition(use_resource_reqs): for job_definition in resp["jobDefinitions"]: job_definition["status"].should.equal("ACTIVE") + job_definition.shouldnt.have.key("platformCapabilities") + job_definition.shouldnt.have.key("retryStrategy") def register_job_def(batch_client, definition_name="sleep10", use_resource_reqs=True): @@ -212,7 +329,10 @@ def register_job_def(batch_client, definition_name="sleep10", use_resource_reqs= ) -def register_job_def_with_tags(batch_client, definition_name="sleep10"): +def register_job_def_with_tags( + batch_client, definition_name="sleep10", propagate_tags=False +): + kwargs = {} if propagate_tags is None else {"propagateTags": propagate_tags} return batch_client.register_job_definition( jobDefinitionName=definition_name, type="container", @@ -223,4 +343,5 @@ def register_job_def_with_tags(batch_client, definition_name="sleep10"): "command": ["sleep", "10"], }, tags={"foo": "123", "bar": "456",}, + **kwargs, )