diff --git a/moto/batch/exceptions.py b/moto/batch/exceptions.py index e598ee7af..cd6031a95 100644 --- a/moto/batch/exceptions.py +++ b/moto/batch/exceptions.py @@ -1,3 +1,32 @@ from __future__ import unicode_literals -from moto.core.exceptions import RESTError +import json + +class AWSError(Exception): + CODE = None + STATUS = 400 + + def __init__(self, message, code=None, status=None): + self.message = message + self.code = code if code is not None else self.CODE + self.status = status if status is not None else self.STATUS + + def response(self): + return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) + + +class InvalidRequestException(AWSError): + CODE = 'InvalidRequestException' + + +class InvalidParameterValueException(AWSError): + CODE = 'InvalidParameterValue' + + +class ValidationError(AWSError): + CODE = 'ValidationError' + + +class InternalFailure(AWSError): + CODE = 'InternalFailure' + STATUS = 500 diff --git a/moto/batch/models.py b/moto/batch/models.py index a54b30c32..c7def48d1 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -1,6 +1,28 @@ from __future__ import unicode_literals import boto3 +import re from moto.core import BaseBackend, BaseModel +from moto.iam import iam_backends +from moto.ec2 import ec2_backends + +from .exceptions import InvalidParameterValueException, InternalFailure +from .utils import make_arn_for_compute_env +from moto.ec2.exceptions import InvalidSubnetIdError +from moto.iam.exceptions import IAMNotFoundException + + +DEFAULT_ACCOUNT_ID = 123456789012 +COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(r'^[A-Za-z0-9_]{1,128}$') + + +class ComputeEnvironment(BaseModel): + def __init__(self, compute_environment_name, _type, state, compute_resources, service_role, region_name): + self.compute_environment_name = compute_environment_name + self.type = _type + self.state = state + self.compute_resources = compute_resources + self.service_role = service_role + self.arn = make_arn_for_compute_env(DEFAULT_ACCOUNT_ID, compute_environment_name, region_name) class BatchBackend(BaseBackend): @@ -8,16 +30,125 @@ class BatchBackend(BaseBackend): super(BatchBackend, self).__init__() self.region_name = region_name + self._compute_environments = {} + + @property + def iam_backend(self): + """ + :return: IAM Backend + :rtype: moto.iam.models.IAMBackend + """ + return iam_backends['global'] + + @property + def ec2_backend(self): + """ + :return: EC2 Backend + :rtype: moto.ec2.models.EC2Backend + """ + return ec2_backends[self.region_name] + def reset(self): region_name = self.region_name self.__dict__ = {} self.__init__(region_name) - def create_compute_environment(self, compute_environment_name, type, state, compute_resources, service_role): - # implement here - return compute_environment_name, compute_environment_arn - # add methods from here + def get_compute_environment(self, arn): + return self._compute_environments.get(arn) + + def get_compute_environment_by_name(self, name): + for comp_env in self._compute_environments.values(): + if comp_env.name == name: + return comp_env + return None + + def create_compute_environment(self, compute_environment_name, _type, state, compute_resources, service_role): + # Validate + if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None: + raise InvalidParameterValueException('Compute environment name does not match ^[A-Za-z0-9_]{1,128}$') + + if self.get_compute_environment_by_name(compute_environment_name) is not None: + raise InvalidParameterValueException('A compute environment already exists with the name {0}'.format(compute_environment_name)) + + # Look for IAM role + try: + self.iam_backend.get_role_by_arn(service_role) + except IAMNotFoundException: + raise InvalidParameterValueException('Could not find IAM role {0}'.format(service_role)) + + if _type not in ('MANAGED', 'UNMANAGED'): + raise InvalidParameterValueException('type {0} must be one of MANAGED | UNMANAGED'.format(service_role)) + + if state is not None and state not in ('ENABLED', 'DISABLED'): + raise InvalidParameterValueException('state {0} must be one of ENABLED | DISABLED'.format(state)) + + if compute_resources is None and _type == 'MANAGED': + raise InvalidParameterValueException('computeResources must be specified when creating a MANAGED environment'.format(state)) + elif compute_resources is not None: + self._validate_compute_resources(compute_resources) + + # By here, all values except SPOT ones have been validated + new_comp_env = ComputeEnvironment( + compute_environment_name, _type, state, + compute_resources, service_role, + region_name=self.region_name + ) + self._compute_environments[new_comp_env.arn] = new_comp_env + + # TODO scale out if MANAGED and we have compute instance types + + return compute_environment_name, new_comp_env.arn + + def _validate_compute_resources(self, cr): + if 'instanceRole' not in cr: + raise InvalidParameterValueException('computeResources must contain instanceRole') + elif self.iam_backend.get_role_by_arn(cr['instanceRole']) is None: + raise InvalidParameterValueException('could not find instanceRole {0}'.format(cr['instanceRole'])) + + # TODO move the not in checks to a loop, or create a json schema validator class + if 'maxvCpus' not in cr: + raise InvalidParameterValueException('computeResources must contain maxVCpus') + if 'minvCpus' not in cr: + raise InvalidParameterValueException('computeResources must contain minVCpus') + if cr['maxvCpus'] < 0: + raise InvalidParameterValueException('maxVCpus must be positive') + if cr['minvCpus'] < 0: + raise InvalidParameterValueException('minVCpus must be positive') + if cr['maxvCpus'] < cr['minvCpus']: + raise InvalidParameterValueException('maxVCpus must be greater than minvCpus') + + # TODO check instance types when that logic exists + if 'instanceTypes' not in cr: + raise InvalidParameterValueException('computeResources must contain instanceTypes') + if len(cr['instanceTypes']) == 0: + raise InvalidParameterValueException('At least 1 instance type must be provided') + + if 'securityGroupIds' not in cr: + raise InvalidParameterValueException('computeResources must contain securityGroupIds') + for sec_id in cr['securityGroupIds']: + if self.ec2_backend.get_security_group_from_id(sec_id) is None: + raise InvalidParameterValueException('security group {0} does not exist'.format(sec_id)) + if len(cr['securityGroupIds']) == 0: + raise InvalidParameterValueException('At least 1 security group must be provided') + + if 'subnets' not in cr: + raise InvalidParameterValueException('computeResources must contain subnets') + for subnet_id in cr['subnets']: + try: + self.ec2_backend.get_subnet(subnet_id) + except InvalidSubnetIdError: + raise InvalidParameterValueException('subnet {0} does not exist'.format(subnet_id)) + if len(cr['subnets']) == 0: + raise InvalidParameterValueException('At least 1 subnet must be provided') + + if 'type' not in cr: + raise InvalidParameterValueException('computeResources must contain type') + if cr['type'] not in ('EC2', 'SPOT'): + raise InvalidParameterValueException('computeResources.type must be either EC2 | SPOT') + + if cr['type'] == 'SPOT': + raise InternalFailure('SPOT NOT SUPPORTED YET') available_regions = boto3.session.Session().get_available_regions("batch") -batch_backends = {region: BatchBackend for region in available_regions} +batch_backends = {region: BatchBackend(region_name=region) for region in available_regions} diff --git a/moto/batch/responses.py b/moto/batch/responses.py index d91af8a77..0368906f0 100644 --- a/moto/batch/responses.py +++ b/moto/batch/responses.py @@ -1,14 +1,58 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse from .models import batch_backends +from six.moves.urllib.parse import urlsplit + +from .exceptions import AWSError + +import json class BatchResponse(BaseResponse): + def _error(self, code, message): + return json.dumps({'__type': code, 'message': message}), dict(status=400) + @property def batch_backend(self): return batch_backends[self.region] - # add methods from here + @property + def json(self): + if not hasattr(self, '_json'): + self._json = json.loads(self.body) + return self._json + def _get_param(self, param_name, if_none=None): + val = self.json.get(param_name) + if val is not None: + return val + return if_none -# add teampltes from here + def _get_action(self): + # Return element after the /v1/* + return urlsplit(self.uri).path.lstrip('/').split('/')[1] + + # CreateComputeEnvironment + def createcomputeenvironment(self): + compute_env_name = self._get_param('computeEnvironmentName') + compute_resource = self._get_param('computeResources') + service_role = self._get_param('serviceRole') + state = self._get_param('state') + _type = self._get_param('type') + + try: + name, arn = self.batch_backend.create_compute_environment( + compute_environment_name=compute_env_name, + _type=_type, state=state, + compute_resources=compute_resource, + service_role=service_role + ) + except AWSError as err: + return err.response() + + result = { + 'computeEnvironmentArn': arn, + 'computeEnvironmentName': name + } + + return json.dumps(result) diff --git a/moto/batch/urls.py b/moto/batch/urls.py index 27cd9fc51..93f8a2f23 100644 --- a/moto/batch/urls.py +++ b/moto/batch/urls.py @@ -6,5 +6,5 @@ url_bases = [ ] url_paths = { - '{0}/$': BatchResponse.dispatch, + '{0}/v1/createcomputeenvironment': BatchResponse.dispatch, } diff --git a/moto/batch/utils.py b/moto/batch/utils.py index 33e474d61..d323a9bf7 100644 --- a/moto/batch/utils.py +++ b/moto/batch/utils.py @@ -1,6 +1,5 @@ from __future__ import unicode_literals -import uuid -def make_arn_for_topic(account_id, name, region_name): - return "arn:aws:sns:{0}:{1}:{2}".format(region_name, account_id, name) +def make_arn_for_compute_env(account_id, name, region_name): + return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(region_name, account_id, name) diff --git a/moto/iam/models.py b/moto/iam/models.py index a7e584284..34efb1a22 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -534,6 +534,12 @@ class IAMBackend(BaseBackend): return role raise IAMNotFoundException("Role {0} not found".format(role_name)) + def get_role_by_arn(self, arn): + for role in self.get_roles(): + if role.arn == arn: + return role + raise IAMNotFoundException("Role {0} not found".format(arn)) + def delete_role(self, role_name): for role in self.get_roles(): if role.name == role_name: diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py index eafd32eae..3aae48e1e 100644 --- a/tests/test_batch/test_batch.py +++ b/tests/test_batch/test_batch.py @@ -2,10 +2,87 @@ from __future__ import unicode_literals import boto3 import sure # noqa -from moto import mock_batch +from moto import mock_batch, mock_iam, mock_ec2 +DEFAULT_REGION = 'eu-central-1' + + +def _get_clients(): + return boto3.client('ec2', region_name=DEFAULT_REGION), \ + boto3.client('iam', region_name=DEFAULT_REGION), \ + boto3.client('batch', region_name=DEFAULT_REGION) + + +def _setup(ec2_client, iam_client): + """ + Do prerequisite setup + :return: VPC ID, Subnet ID, Security group ID, IAM Role ARN + :rtype: tuple + """ + resp = ec2_client.create_vpc(CidrBlock='172.30.0.0/24') + vpc_id = resp['Vpc']['VpcId'] + resp = ec2_client.create_subnet( + AvailabilityZone='eu-central-1a', + CidrBlock='172.30.0.0/25', + VpcId=vpc_id + ) + subnet_id = resp['Subnet']['SubnetId'] + resp = ec2_client.create_security_group( + Description='test_sg_desc', + GroupName='test_sg', + VpcId=vpc_id + ) + sg_id = resp['GroupId'] + + resp = iam_client.create_role( + RoleName='TestRole', + AssumeRolePolicyDocument='some_policy' + ) + iam_arn = resp['Role']['Arn'] + + return vpc_id, subnet_id, sg_id, iam_arn + + +# Yes, yes it talks to all the things +@mock_ec2 +@mock_iam @mock_batch -def test_list(): - # do test - pass \ No newline at end of file +def test_create_compute_environment(): + ec2_client, iam_client, batch_client = _get_clients() + vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) + + compute_name = 'test_compute_env' + resp = batch_client.create_compute_environment( + computeEnvironmentName=compute_name, + type='MANAGED', + state='ENABLED', + computeResources={ + 'type': 'EC2', + 'minvCpus': 123, + 'maxvCpus': 123, + 'desiredvCpus': 123, + 'instanceTypes': [ + 'some_instance_type', + ], + 'imageId': 'some_image_id', + 'subnets': [ + subnet_id, + ], + 'securityGroupIds': [ + sg_id, + ], + 'ec2KeyPair': 'string', + 'instanceRole': iam_arn, + 'tags': { + 'string': 'string' + }, + 'bidPercentage': 123, + 'spotIamFleetRole': 'string' + }, + serviceRole=iam_arn + ) + resp.should.contain('computeEnvironmentArn') + resp['computeEnvironmentName'].should.equal(compute_name) + +# TODO create 1000s of tests to test complex option combinations of create environment