diff --git a/moto/batch/models.py b/moto/batch/models.py index 6eb02c39c..2129320e7 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -10,7 +10,7 @@ from moto.ec2 import ec2_backends from moto.ecs import ecs_backends from .exceptions import InvalidParameterValueException, InternalFailure, ClientException -from .utils import make_arn_for_compute_env, make_arn_for_job_queue +from .utils import make_arn_for_compute_env, make_arn_for_job_queue, make_arn_for_task_def from moto.ec2.exceptions import InvalidSubnetIdError from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES from moto.iam.exceptions import IAMNotFoundException @@ -78,6 +78,52 @@ class JobQueue(BaseModel): return result +class JobDefinition(BaseModel): + def __init__(self, name, parameters, _type, container_properties, region_name, revision=0, retry_strategy=0): + self.name = name + self.retries = retry_strategy + self.type = _type + self.revision = revision + self._region = region_name + self.container_properties = container_properties + self.arn = None + + self.parameters = {} + if parameters is not None: + if not isinstance(parameters, dict): + raise ClientException('parameters must be a string to string map') + self.parameters = parameters + + if _type not in ('container',): + raise ClientException('type must be one of "container"') + + self._update_arn() + + # For future use when containers arnt the only thing in batch + if _type != 'container': + raise NotImplementedError() + + self._validate_container_properties() + + def _update_arn(self): + self.revision += 1 + self.arn = make_arn_for_task_def(DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region) + + def _validate_container_properties(self): + if 'image' not in self.container_properties: + raise ClientException('containerProperties must contain image') + + if 'memory' not in self.container_properties: + raise ClientException('containerProperties must contain memory') + if self.container_properties['memory'] < 4: + raise ClientException('container memory limit must be greater than 4') + + if 'vcpus' not in self.container_properties: + raise ClientException('containerProperties must contain vcpus') + if self.container_properties['vcpus'] < 1: + raise ClientException('container vcpus limit must be greater than 0') + + class BatchBackend(BaseBackend): def __init__(self, region_name=None): super(BatchBackend, self).__init__() @@ -85,6 +131,7 @@ class BatchBackend(BaseBackend): self._compute_environments = {} self._job_queues = {} + self._job_definitions = {} @property def iam_backend(self): @@ -161,6 +208,29 @@ class BatchBackend(BaseBackend): env = self.get_job_queue_by_name(identifier) return env + def get_job_definition_by_arn(self, arn): + return self._job_definitions.get(arn) + + def get_job_definition_by_name(self, name): + for comp_env in self._job_definitions.values(): + if comp_env.name == name: + return comp_env + return None + + def get_job_definition(self, identifier): + """ + Get job queue by name or ARN + :param identifier: Name or ARN + :type identifier: str + + :return: Job Queue or None + :rtype: JobQueue or None + """ + env = self.get_job_definition_by_arn(identifier) + if env is None: + env = self.get_job_definition_by_name(identifier) + return env + def describe_compute_environments(self, environments=None, max_results=None, next_token=None): envs = set() if environments is not None: @@ -512,6 +582,24 @@ class BatchBackend(BaseBackend): if job_queue is not None: del self._job_queues[job_queue.arn] + def register_job_definition(self, def_name, parameters, _type, retry_strategy, container_properties): + if def_name is None: + raise ClientException('jobDefinitionName must be provided') + + if self.get_job_definition_by_name(def_name) is not None: + raise ClientException('A job definition called {0} already exists'.format(def_name)) + + if retry_strategy is not None: + try: + retry_strategy = retry_strategy['attempts'] + except Exception: + raise ClientException('retryStrategy is malformed') + + job_def = JobDefinition(def_name, parameters, _type, container_properties, region_name=self.region_name, retry_strategy=retry_strategy) + self._job_definitions[job_def.arn] = job_def + + return def_name, job_def.arn, job_def.revision + available_regions = boto3.session.Session().get_available_regions("batch") batch_backends = {region: BatchBackend(region_name=region) for region in available_regions} diff --git a/moto/batch/responses.py b/moto/batch/responses.py index 7c870382e..dec740221 100644 --- a/moto/batch/responses.py +++ b/moto/batch/responses.py @@ -178,3 +178,30 @@ class BatchResponse(BaseResponse): self.batch_backend.delete_job_queue(queue_name) return '' + + # RegisterJobDefinition + def registerjobdefinition(self): + container_properties = self._get_param('containerProperties') + def_name = self._get_param('jobDefinitionName') + parameters = self._get_param('parameters') + retry_strategy = self._get_param('retryStrategy') + _type = self._get_param('type') + + try: + name, arn, revision = self.batch_backend.register_job_definition( + def_name=def_name, + parameters=parameters, + _type=_type, + retry_strategy=retry_strategy, + container_properties=container_properties + ) + except AWSError as err: + return err.response() + + result = { + 'jobDefinitionArn': arn, + 'jobDefinitionName': name, + 'revision': revision + } + + return json.dumps(result) diff --git a/moto/batch/urls.py b/moto/batch/urls.py index bc186bd29..cd5ccb00c 100644 --- a/moto/batch/urls.py +++ b/moto/batch/urls.py @@ -13,5 +13,6 @@ url_paths = { '{0}/v1/createjobqueue': BatchResponse.dispatch, '{0}/v1/describejobqueues': BatchResponse.dispatch, '{0}/v1/updatejobqueue': BatchResponse.dispatch, - '{0}/v1/deletejobqueue': BatchResponse.dispatch + '{0}/v1/deletejobqueue': BatchResponse.dispatch, + '{0}/v1/registerjobdefinition': BatchResponse.dispatch } diff --git a/moto/batch/utils.py b/moto/batch/utils.py index 68c6a3581..6cdd381f7 100644 --- a/moto/batch/utils.py +++ b/moto/batch/utils.py @@ -7,3 +7,7 @@ def make_arn_for_compute_env(account_id, name, region_name): def make_arn_for_job_queue(account_id, name, region_name): return "arn:aws:batch:{0}:{1}:job-queue/{2}".format(region_name, account_id, name) + + +def make_arn_for_task_def(account_id, name, revision, region_name): + return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(region_name, account_id, name, revision) diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py index e7c4cf629..6eba45d27 100644 --- a/tests/test_batch/test_batch.py +++ b/tests/test_batch/test_batch.py @@ -421,3 +421,29 @@ def test_update_job_queue(): resp = batch_client.describe_job_queues() resp.should.contain('jobQueues') len(resp['jobQueues']).should.equal(0) + + +@mock_ec2 +@mock_ecs +@mock_iam +@mock_batch +def test_register_task_definition(): + ec2_client, iam_client, ecs_client, batch_client = _get_clients() + vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) + + resp = batch_client.register_job_definition( + jobDefinitionName='sleep10', + type='container', + containerProperties={ + 'image': 'busybox', + 'vcpus': 1, + 'memory': 128, + 'command': ['sleep', '10'] + } + ) + + resp.should.contain('jobDefinitionArn') + resp.should.contain('jobDefinitionName') + resp.should.contain('revision') + + assert resp['jobDefinitionArn'].endswith('{0}:{1}'.format(resp['jobDefinitionName'], resp['revision']))