diff --git a/moto/batch/models.py b/moto/batch/models.py
index 05137296b..be8fca9d1 100644
--- a/moto/batch/models.py
+++ b/moto/batch/models.py
@@ -1,13 +1,22 @@
 from __future__ import unicode_literals
 import boto3
 import re
+import requests.adapters
 from itertools import cycle
 import six
+import datetime
+import time
 import uuid
+import logging
+import docker
+import functools
+import threading
+import dateutil.parser
 from moto.core import BaseBackend, BaseModel
 from moto.iam import iam_backends
 from moto.ec2 import ec2_backends
 from moto.ecs import ecs_backends
+from moto.logs import logs_backends
 
 from .exceptions import InvalidParameterValueException, InternalFailure, ClientException
 from .utils import make_arn_for_compute_env, make_arn_for_job_queue, make_arn_for_task_def
@@ -16,10 +25,16 @@ from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES
 from moto.iam.exceptions import IAMNotFoundException
 
 
+_orig_adapter_send = requests.adapters.HTTPAdapter.send
+logger = logging.getLogger(__name__)
 DEFAULT_ACCOUNT_ID = 123456789012
 COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(r'^[A-Za-z0-9_]{1,128}$')
 
 
+def datetime2int(date):
+    return int(time.mktime(date.timetuple()))
+
+
 class ComputeEnvironment(BaseModel):
     def __init__(self, compute_environment_name, _type, state, compute_resources, service_role, region_name):
         self.name = compute_environment_name
@@ -65,6 +80,8 @@ class JobQueue(BaseModel):
         self.arn = make_arn_for_job_queue(DEFAULT_ACCOUNT_ID, name, region_name)
         self.status = 'VALID'
 
+        self.jobs = []
+
     def describe(self):
         result = {
             'computeEnvironmentOrder': self.env_order_json,
@@ -156,6 +173,162 @@ class JobDefinition(BaseModel):
         return result
 
 
+class Job(threading.Thread, BaseModel):
+    def __init__(self, name, job_def, job_queue, log_backend):
+        """
+        Docker Job
+
+        :param name: Job Name
+        :param job_def: Job definition
+        :type job_def: JobDefinition
+        :param job_queue: Job Queue
+        :param log_backend: Log backend
+        :type log_backend: moto.logs.models.LogsBackend
+        """
+        threading.Thread.__init__(self)
+
+        self.job_name = name
+        self.job_id = str(uuid.uuid4())
+        self.job_definition = job_def
+        self.job_queue = job_queue
+        self.job_state = 'SUBMITTED'  # One of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED
+        self.job_queue.jobs.append(self)
+        self.job_started_at = datetime.datetime(1970, 1, 1)
+        self.job_stopped_at = datetime.datetime(1970, 1, 1)
+        self.job_stopped = False
+
+        self.stop = False
+
+        self.daemon = True
+        self.name = 'MOTO-BATCH-' + self.job_id
+
+        self.docker_client = docker.from_env()
+        self._log_backend = log_backend
+
+        # Unfortunately mocking replaces this method w/o fallback enabled, so we
+        # need to replace it if we detect it's been mocked
+        if requests.adapters.HTTPAdapter.send != _orig_adapter_send:
+            _orig_get_adapter = self.docker_client.api.get_adapter
+
+            def replace_adapter_send(*args, **kwargs):
+                adapter = _orig_get_adapter(*args, **kwargs)
+
+                if isinstance(adapter, requests.adapters.HTTPAdapter):
+                    adapter.send = functools.partial(_orig_adapter_send, adapter)
+                return adapter
+            self.docker_client.api.get_adapter = replace_adapter_send
+
+    def describe(self):
+        result = {
+            'jobDefinition': self.job_definition.arn,
+            'jobId': self.job_id,
+            'jobName': self.job_name,
+            'jobQueue': self.job_queue.arn,
+            'startedAt': datetime2int(self.job_started_at),
+            'status': self.job_state,
+            'dependsOn': []
+        }
+        if self.job_stopped:
+            result['stoppedAt'] = datetime2int(self.job_stopped_at)
+        return result
+
+    def run(self):
+        """
+        Run the container.
+
+        Logic is as follows:
+        Generate container info (eventually from task definition)
+        Start container
+        Loop whilst not asked to stop and the container is running.
+          Get all logs from container between the last time we checked and now.
+        Convert logs into cloudwatch format
+        Put logs into cloudwatch
+
+        :return:
+        """
+        try:
+            self.job_state = 'PENDING'
+            time.sleep(1)
+
+            image = 'alpine:latest'
+            cmd = '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"'
+            name = '{0}-{1}'.format(self.job_name, self.job_id)
+
+            self.job_state = 'RUNNABLE'
+            # TODO setup ecs container instance
+            time.sleep(1)
+
+            self.job_state = 'STARTING'
+            container = self.docker_client.containers.run(
+                image, cmd,
+                detach=True,
+                name=name
+            )
+            self.job_state = 'RUNNING'
+            self.job_started_at = datetime.datetime.now()
+            try:
+                # Log collection
+                logs_stdout = []
+                logs_stderr = []
+                container.reload()
+
+                # Dodgy hack: we can only check docker logs once a second, but we want
+                # to loop more often so we can stop quickly if asked to; this should
+                # all go away if we go async. There also seems to be some dodginess
+                # when sending an integer to docker logs, and some events seem to be duplicated.
+                now = datetime.datetime.now()
+                i = 1
+                while container.status == 'running' and not self.stop:
+                    time.sleep(0.15)
+                    if i % 10 == 0:
+                        logs_stderr.extend(container.logs(stdout=False, stderr=True, timestamps=True, since=datetime2int(now)).decode().split('\n'))
+                        logs_stdout.extend(container.logs(stdout=True, stderr=False, timestamps=True, since=datetime2int(now)).decode().split('\n'))
+                        now = datetime.datetime.now()
+                        container.reload()
+                    i += 1
+
+                # Container should be stopped by this point... unless asked to stop
+                if container.status == 'running':
+                    container.kill()
+
+                self.job_stopped_at = datetime.datetime.now()
+                # Get final logs
+                logs_stderr.extend(container.logs(stdout=False, stderr=True, timestamps=True, since=datetime2int(now)).decode().split('\n'))
+                logs_stdout.extend(container.logs(stdout=True, stderr=False, timestamps=True, since=datetime2int(now)).decode().split('\n'))
+
+                self.job_state = 'SUCCEEDED' if not self.stop else 'FAILED'
+
+                # Process logs
+                logs_stdout = [x for x in logs_stdout if len(x) > 0]
+                logs_stderr = [x for x in logs_stderr if len(x) > 0]
+                logs = []
+                for line in logs_stdout + logs_stderr:
+                    date, line = line.split(' ', 1)
+                    date = dateutil.parser.parse(date)
+                    date = int(date.timestamp())
+                    logs.append({'timestamp': date, 'message': line.strip()})
+
+                # Send to cloudwatch
+                log_group = '/aws/batch/job'
+                stream_name = '{0}/default/{1}'.format(self.job_definition.name, self.job_id)
+                self._log_backend.ensure_log_group(log_group, None)
+                self._log_backend.create_log_stream(log_group, stream_name)
+                self._log_backend.put_log_events(log_group, stream_name, logs, None)
+
+            except Exception as err:
+                logger.error('Failed to run AWS Batch container {0}. Error {1}'.format(self.name, err))
+                self.job_state = 'FAILED'
+                container.kill()
+            finally:
+                container.remove()
+        except Exception as err:
+            logger.error('Failed to run AWS Batch container {0}. Error {1}'.format(self.name, err))
+            self.job_state = 'FAILED'
+
+        self.job_stopped = True
+        self.job_stopped_at = datetime.datetime.now()
+
+
 class BatchBackend(BaseBackend):
     def __init__(self, region_name=None):
         super(BatchBackend, self).__init__()
@@ -164,6 +337,7 @@ class BatchBackend(BaseBackend):
         self._compute_environments = {}
         self._job_queues = {}
         self._job_definitions = {}
+        self._jobs = {}
 
     @property
     def iam_backend(self):
@@ -189,8 +363,23 @@ class BatchBackend(BaseBackend):
         """
         return ecs_backends[self.region_name]
 
+    @property
+    def logs_backend(self):
+        """
+        :return: Logs Backend
+        :rtype: moto.logs.models.LogsBackend
+        """
+        return logs_backends[self.region_name]
+
     def reset(self):
         region_name = self.region_name
+
+        for job in self._jobs.values():
+            if job.job_state not in ('FAILED', 'SUCCEEDED'):
+                job.stop = True
+                # Try to join
+                job.join(0.2)
+
         self.__dict__ = {}
         self.__init__(region_name)
 
@@ -691,6 +880,42 @@ class BatchBackend(BaseBackend):
             return [job for job in jobs if job.status == status]
         return jobs
 
+    def submit_job(self, job_name, job_def_id, job_queue, parameters=None, retries=None, depends_on=None, container_overrides=None):
+        # TODO parameters, retries (which is a dict raw from request), job dependencies and container overrides are ignored for now
+
+        # Look for job definition
+        job_def = self.get_job_definition_by_arn(job_def_id)
+        if job_def is None and ':' in job_def_id:
+            job_def = self.get_job_definition_by_name_revision(*job_def_id.split(':', 1))
+        if job_def is None:
+            raise ClientException('Job definition {0} does not exist'.format(job_def_id))
+
+        queue = self.get_job_queue(job_queue)
+        if queue is None:
+            raise ClientException('Job queue {0} does not exist'.format(job_queue))
+
+        job = Job(job_name, job_def, queue, log_backend=self.logs_backend)
+        self._jobs[job.job_id] = job
+
+        # Here comes the fun
+        job.start()
+
+        return job_name, job.job_id
+
+    def describe_jobs(self, jobs):
+        job_filter = set()
+        if jobs is not None:
+            job_filter = set(jobs)
+
+        result = []
+        for key, job in self._jobs.items():
+            if len(job_filter) > 0 and key not in job_filter:
+                continue
+
+            result.append(job.describe())
+
+        return result
+
 
 available_regions = boto3.session.Session().get_available_regions("batch")
 batch_backends = {region: BatchBackend(region_name=region) for region in available_regions}
diff --git a/moto/batch/responses.py b/moto/batch/responses.py
index 0d3900d1d..2bec7ddf1 100644
--- a/moto/batch/responses.py
+++ b/moto/batch/responses.py
@@ -226,3 +226,40 @@ class BatchResponse(BaseResponse):
         result = {'jobDefinitions': [job.describe() for job in job_defs]}
 
         return json.dumps(result)
+
+    # SubmitJob
+    def submitjob(self):
+        container_overrides = self._get_param('containerOverrides')
+        depends_on = self._get_param('dependsOn')
+        job_def = self._get_param('jobDefinition')
+        job_name = self._get_param('jobName')
+        job_queue = self._get_param('jobQueue')
+        parameters = self._get_param('parameters')
+        retries = self._get_param('retryStrategy')
+
+        try:
+            name, job_id = self.batch_backend.submit_job(
+                job_name, job_def, job_queue,
+                parameters=parameters,
+                retries=retries,
+                depends_on=depends_on,
+                container_overrides=container_overrides
+            )
+        except AWSError as err:
+            return err.response()
+
+        result = {
+            'jobId': job_id,
+            'jobName': name,
+        }
+
+        return json.dumps(result)
+
+    # DescribeJobs
+    def describejobs(self):
+        jobs = self._get_param('jobs')
+
+        try:
+            return json.dumps({'jobs': self.batch_backend.describe_jobs(jobs)})
+        except AWSError as err:
+            return err.response()
diff --git a/moto/batch/urls.py b/moto/batch/urls.py
index 3265bb535..924e55e6d 100644
--- a/moto/batch/urls.py
+++ b/moto/batch/urls.py
@@ -16,5 +16,7 @@ url_paths = {
     '{0}/v1/deletejobqueue': BatchResponse.dispatch,
     '{0}/v1/registerjobdefinition': BatchResponse.dispatch,
     '{0}/v1/deregisterjobdefinition': BatchResponse.dispatch,
-    '{0}/v1/describejobdefinitions': BatchResponse.dispatch
+    '{0}/v1/describejobdefinitions': BatchResponse.dispatch,
+    '{0}/v1/submitjob': BatchResponse.dispatch,
+    '{0}/v1/describejobs': BatchResponse.dispatch
 }
diff --git a/moto/logs/models.py b/moto/logs/models.py
index 14f511932..09dcb3645 100644
--- a/moto/logs/models.py
+++ b/moto/logs/models.py
@@ -22,6 +22,13 @@ class LogEvent:
             "timestamp": self.timestamp
         }
 
+    def to_response_dict(self):
+        return {
+            "ingestionTime": self.ingestionTime,
+            "message": self.message,
+            "timestamp": self.timestamp
+        }
+
 
 class LogStream:
     _log_ids = 0
@@ -41,7 +48,14 @@ class LogStream:
 
         self.__class__._log_ids += 1
 
+    def _update(self):
+        self.firstEventTimestamp = min([x.timestamp for x in self.events])
+        self.lastEventTimestamp = max([x.timestamp for x in self.events])
+
     def to_describe_dict(self):
+        # Compute start and end times
+        self._update()
+
         return {
             "arn": self.arn,
             "creationTime": self.creationTime,
@@ -79,7 +93,7 @@ class LogStream:
 
         if next_token is None:
             next_token = 0
-        events_page = events[next_token: next_token + limit]
+        events_page = [event.to_response_dict() for event in events[next_token: next_token + limit]]
         next_token += limit
         if next_token >= len(self.events):
             next_token = None
@@ -120,17 +134,17 @@ class LogGroup:
         del self.streams[log_stream_name]
 
     def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by):
-        log_streams = [stream.to_describe_dict() for name, stream in self.streams.items() if name.startswith(log_stream_name_prefix)]
+        log_streams = [(name, stream.to_describe_dict()) for name, stream in self.streams.items() if name.startswith(log_stream_name_prefix)]
 
-        def sorter(stream):
-            return stream.name if order_by == 'logStreamName' else stream.lastEventTimestamp
+        def sorter(item):
+            return item[0] if order_by == 'logStreamName' else item[1]['lastEventTimestamp']
 
         if next_token is None:
             next_token = 0
 
         log_streams = sorted(log_streams, key=sorter, reverse=descending)
         new_token = next_token + limit
-        log_streams_page = log_streams[next_token: new_token]
+        log_streams_page = [x[1] for x in log_streams[next_token: new_token]]
         if new_token >= len(log_streams):
             new_token = None
 
diff --git a/moto/logs/responses.py b/moto/logs/responses.py
index 4cb9caa6a..53b2390f4 100644
--- a/moto/logs/responses.py
+++ b/moto/logs/responses.py
@@ -47,7 +47,7 @@ class LogsResponse(BaseResponse):
 
     def describe_log_streams(self):
         log_group_name = self._get_param('logGroupName')
-        log_stream_name_prefix = self._get_param('logStreamNamePrefix')
+        log_stream_name_prefix = self._get_param('logStreamNamePrefix', '')
         descending = self._get_param('descending', False)
         limit = self._get_param('limit', 50)
         assert limit <= 50
@@ -83,7 +83,7 @@ class LogsResponse(BaseResponse):
         limit = self._get_param('limit', 10000)
         assert limit <= 10000
         next_token = self._get_param('nextToken')
-        start_from_head = self._get_param('startFromHead')
+        start_from_head = self._get_param('startFromHead', False)
         events, next_backward_token, next_foward_token = \
             self.logs_backend.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
 
diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py
index ebe710760..acbe75e94 100644
--- a/tests/test_batch/test_batch.py
+++ b/tests/test_batch/test_batch.py
@@ -1,11 +1,24 @@
 from __future__ import unicode_literals
+import time
+import datetime
 
 import boto3
 from botocore.exceptions import ClientError
 import sure  # noqa
-from moto import mock_batch, mock_iam, mock_ec2, mock_ecs
+from moto import mock_batch, mock_iam, mock_ec2, mock_ecs, mock_logs
+import functools
+import nose
 
 
+def expected_failure(test):
+    @functools.wraps(test)
+    def inner(*args, **kwargs):
+        try:
+            test(*args, **kwargs)
+        except Exception as err:
+            raise nose.SkipTest
+    return inner
+
 DEFAULT_REGION = 'eu-central-1'
 
 
@@ -13,6 +26,7 @@ def _get_clients():
     return boto3.client('ec2', region_name=DEFAULT_REGION), \
         boto3.client('iam', region_name=DEFAULT_REGION), \
         boto3.client('ecs', region_name=DEFAULT_REGION), \
+        boto3.client('logs', region_name=DEFAULT_REGION), \
         boto3.client('batch', region_name=DEFAULT_REGION)
 
 
@@ -52,7 +66,7 @@ def _setup(ec2_client, iam_client):
 @mock_iam
 @mock_batch
 def test_create_managed_compute_environment():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     compute_name = 'test_compute_env'
@@ -105,7 +119,7 @@ def test_create_managed_compute_environment():
 @mock_iam
 @mock_batch
 def test_create_unmanaged_compute_environment():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     compute_name = 'test_compute_env'
@@ -136,7 +150,7 @@ def test_create_unmanaged_compute_environment():
 @mock_iam
 @mock_batch
 def test_describe_compute_environment():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     compute_name = 'test_compute_env'
@@ -163,7 +177,7 @@ def test_describe_compute_environment():
 @mock_iam
 @mock_batch
 def test_delete_unmanaged_compute_environment():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
    compute_name = 'test_compute_env'
@@ -190,7 +204,7 @@ def test_delete_unmanaged_compute_environment():
 @mock_iam
 @mock_batch
 def test_delete_managed_compute_environment():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     compute_name = 'test_compute_env'
@@ -247,7 +261,7 @@ def test_delete_managed_compute_environment():
 @mock_iam
 @mock_batch
 def test_update_unmanaged_compute_environment_state():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     compute_name = 'test_compute_env'
@@ -273,7 +287,7 @@ def test_update_unmanaged_compute_environment_state():
 @mock_iam
 @mock_batch
 def test_create_job_queue():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     compute_name = 'test_compute_env'
@@ -315,7 +329,7 @@ def test_create_job_queue():
 @mock_iam
 @mock_batch
 def test_job_queue_bad_arn():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     compute_name = 'test_compute_env'
@@ -348,7 +362,7 @@ def test_job_queue_bad_arn():
 @mock_iam
 @mock_batch
 def test_update_job_queue():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     compute_name = 'test_compute_env'
@@ -389,7 +403,7 @@ def test_update_job_queue():
 @mock_iam
 @mock_batch
 def test_update_job_queue():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     compute_name = 'test_compute_env'
@@ -428,7 +442,7 @@ def test_update_job_queue():
 @mock_iam
 @mock_batch
 def test_register_task_definition():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     resp = batch_client.register_job_definition(
@@ -455,7 +469,7 @@ def test_register_task_definition():
 @mock_batch
 def test_reregister_task_definition():
     # Reregistering task with the same name bumps the revision number
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     resp1 = batch_client.register_job_definition(
@@ -496,7 +510,7 @@ def test_reregister_task_definition():
 @mock_iam
 @mock_batch
 def test_delete_task_definition():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
     resp = batch_client.register_job_definition(
@@ -521,10 +535,10 @@ def test_delete_task_definition():
 @mock_iam
 @mock_batch
 def test_describe_task_definition():
-    ec2_client, iam_client, ecs_client, batch_client = _get_clients()
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
     vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
 
-    resp = batch_client.register_job_definition(
+    batch_client.register_job_definition(
         jobDefinitionName='sleep10',
         type='container',
         containerProperties={
@@ -534,8 +548,7 @@ def test_describe_task_definition():
             'command': ['sleep', '10']
         }
     )
-    arn1 = resp['jobDefinitionArn']
-    resp = batch_client.register_job_definition(
+    batch_client.register_job_definition(
         jobDefinitionName='sleep10',
         type='container',
         containerProperties={
@@ -545,8 +558,7 @@ def test_describe_task_definition():
             'command': ['sleep', '10']
         }
     )
-    arn2 = resp['jobDefinitionArn']
-    resp = batch_client.register_job_definition(
+    batch_client.register_job_definition(
         jobDefinitionName='test1',
         type='container',
         containerProperties={
@@ -556,7 +568,6 @@ def test_describe_task_definition():
             'command': ['sleep', '10']
         }
     )
-    arn3 = resp['jobDefinitionArn']
 
     resp = batch_client.describe_job_definitions(
         jobDefinitionName='sleep10'
@@ -571,3 +582,76 @@ def test_describe_task_definition():
     )
 
     len(resp['jobDefinitions']).should.equal(3)
+
+# SLOW TEST
+@expected_failure
+@mock_logs
+@mock_ec2
+@mock_ecs
+@mock_iam
+@mock_batch
+def test_submit_job():
+    ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
+    vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
+
+    compute_name = 'test_compute_env'
+    resp = batch_client.create_compute_environment(
+        computeEnvironmentName=compute_name,
+        type='UNMANAGED',
+        state='ENABLED',
+        serviceRole=iam_arn
+    )
+    arn = resp['computeEnvironmentArn']
+
+    resp = batch_client.create_job_queue(
+        jobQueueName='test_job_queue',
+        state='ENABLED',
+        priority=123,
+        computeEnvironmentOrder=[
+            {
+                'order': 123,
+                'computeEnvironment': arn
+            },
+        ]
+    )
+    queue_arn = resp['jobQueueArn']
+
+    resp = batch_client.register_job_definition(
+        jobDefinitionName='sleep10',
+        type='container',
+        containerProperties={
+            'image': 'busybox',
+            'vcpus': 1,
+            'memory': 128,
+            'command': ['sleep', '10']
+        }
+    )
+    job_def_arn = resp['jobDefinitionArn']
+
+    resp = batch_client.submit_job(
+        jobName='test1',
+        jobQueue=queue_arn,
+        jobDefinition=job_def_arn
+    )
+    job_id = resp['jobId']
+
+    future = datetime.datetime.now() + datetime.timedelta(seconds=30)
+
+    while datetime.datetime.now() < future:
+        resp = batch_client.describe_jobs(jobs=[job_id])
+        print("{0}:{1} {2}".format(resp['jobs'][0]['jobName'], resp['jobs'][0]['jobId'], resp['jobs'][0]['status']))
+
+        if resp['jobs'][0]['status'] == 'FAILED':
+            raise RuntimeError('Batch job failed')
+        if resp['jobs'][0]['status'] == 'SUCCEEDED':
+            break
+        time.sleep(0.5)
+    else:
+        raise RuntimeError('Batch job timed out')
+
+    resp = logs_client.describe_log_streams(logGroupName='/aws/batch/job')
+    len(resp['logStreams']).should.equal(1)
+    ls_name = resp['logStreams'][0]['logStreamName']
+
+    resp = logs_client.get_log_events(logGroupName='/aws/batch/job', logStreamName=ls_name)
+    len(resp['events']).should.be.greater_than(5)
\ No newline at end of file
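
A note on the log conversion in Job.run above: docker's logs(timestamps=True) output is parsed line by line into the {'timestamp': ..., 'message': ...} dicts that LogsBackend.put_log_events consumes, which is also why LogEvent grows a to_response_dict for reading the events back. A minimal standalone sketch of that parsing, assuming Python 3 (datetime.timestamp()) and an invented sample line, not captured from a real container:

    import dateutil.parser

    # Hypothetical line as emitted by container.logs(timestamps=True):
    # an RFC3339 timestamp, a space, then the message.
    raw_line = '2017-10-02T10:00:00.000000000Z Hello World'

    # Split the timestamp off the message, as Job.run does with split(' ', 1).
    date, message = raw_line.split(' ', 1)

    # Parse the timestamp and reduce it to an integer for the log event dict.
    event = {
        'timestamp': int(dateutil.parser.parse(date).timestamp()),
        'message': message.strip(),
    }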