moto/moto/batch/models.py
2019-11-22 09:39:07 -05:00

1278 lines
43 KiB
Python

from __future__ import unicode_literals
import boto3
import re
import requests.adapters
from itertools import cycle
import six
import datetime
import time
import uuid
import logging
import docker
import functools
import threading
import dateutil.parser
from moto.core import BaseBackend, BaseModel
from moto.iam import iam_backends
from moto.ec2 import ec2_backends
from moto.ecs import ecs_backends
from moto.logs import logs_backends
from .exceptions import InvalidParameterValueException, InternalFailure, ClientException
from .utils import (
make_arn_for_compute_env,
make_arn_for_job_queue,
make_arn_for_task_def,
lowercase_first_key,
)
from moto.ec2.exceptions import InvalidSubnetIdError
from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES
from moto.iam.exceptions import IAMNotFoundException
# Reference to HTTPAdapter.send captured at import time. Job.__init__ compares
# the current attribute against this one to detect that it has been replaced.
_orig_adapter_send = requests.adapters.HTTPAdapter.send
logger = logging.getLogger(__name__)
# Account id passed to the make_arn_for_* helpers when building ARNs.
DEFAULT_ACCOUNT_ID = 123456789012
# Pattern a compute environment name must match in create_compute_environment.
COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(
r"^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$"
)
def datetime2int(date):
    """
    Convert a datetime into whole seconds since the epoch.

    The conversion goes through ``time.mktime``, so the value is interpreted
    as local time and any sub-second part is dropped.

    :param date: Date to convert
    :type date: datetime.datetime
    :return: Seconds since the epoch
    :rtype: int
    """
    local_struct = date.timetuple()
    seconds = time.mktime(local_struct)
    return int(seconds)
class ComputeEnvironment(BaseModel):
    """
    Model of a single Batch compute environment held by the backend.
    """

    def __init__(
        self,
        compute_environment_name,
        _type,
        state,
        compute_resources,
        service_role,
        region_name,
    ):
        """
        :param compute_environment_name: Compute environment name
        :param _type: Environment type, stored as ``env_type``
        :param state: Environment state
        :param compute_resources: computeResources dictionary
        :param service_role: Service role, stored as given
        :param region_name: Region name used when building the ARN
        """
        self.name = compute_environment_name
        self.env_type = _type
        self.state = state
        self.compute_resources = compute_resources
        self.service_role = service_role

        self.arn = make_arn_for_compute_env(
            DEFAULT_ACCOUNT_ID, compute_environment_name, region_name
        )

        # Nothing is attached until add_instance / set_ecs are called.
        self.instances = []
        self.ecs_arn = self.ecs_name = None

    def add_instance(self, instance):
        """
        Record an instance as belonging to this environment.
        """
        self.instances.append(instance)

    def set_ecs(self, arn, name):
        """
        Remember the ARN and name of the ECS cluster tied to this environment.
        """
        self.ecs_arn, self.ecs_name = arn, name

    @property
    def physical_resource_id(self):
        """
        The ARN doubles as the physical resource id.
        """
        return self.arn

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name
    ):
        """
        Create a compute environment in the region's backend from a
        CloudFormation resource and return the stored model.

        :param resource_name: Used as the compute environment name
        :param cloudformation_json: Resource JSON holding a "Properties" dict
        :param region_name: Region whose backend receives the environment
        :return: The environment as looked up by ARN in the backend
        """
        backend = batch_backends[region_name]
        props = cloudformation_json["Properties"]

        env_type = props["Type"]
        env_state = props.get("State", "ENABLED")
        # CloudFormation capitalises the first letter of keys; normalise them.
        resources = lowercase_first_key(props["ComputeResources"])
        role = props["ServiceRole"]

        # create_compute_environment returns (name, arn)
        _, env_arn = backend.create_compute_environment(
            resource_name, env_type, env_state, resources, role
        )
        return backend.get_compute_environment_by_arn(env_arn)
class JobQueue(BaseModel):
    """
    Model of a single Batch job queue held by the backend.
    """

    def __init__(
        self, name, priority, state, environments, env_order_json, region_name
    ):
        """
        :param name: Job queue name
        :type name: str
        :param priority: Job queue priority
        :type priority: int
        :param state: Either ENABLED or DISABLED
        :type state: str
        :param environments: Compute Environments
        :type environments: list of ComputeEnvironment
        :param env_order_json: Compute Environments JSON for use when describing
        :type env_order_json: list of dict
        :param region_name: Region name
        :type region_name: str
        """
        self.arn = make_arn_for_job_queue(DEFAULT_ACCOUNT_ID, name, region_name)

        self.name = name
        self.priority = priority
        self.state = state
        self.environments = environments
        self.env_order_json = env_order_json

        self.status = "VALID"
        # Job instances append themselves here when they are constructed.
        self.jobs = []

    def describe(self):
        """
        Build the dictionary used to describe this queue.

        :return: Queue description
        :rtype: dict
        """
        return {
            "jobQueueName": self.name,
            "jobQueueArn": self.arn,
            "state": self.state,
            "status": self.status,
            "priority": self.priority,
            "computeEnvironmentOrder": self.env_order_json,
        }

    @property
    def physical_resource_id(self):
        """
        The ARN doubles as the physical resource id.
        """
        return self.arn

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name
    ):
        """
        Create a job queue in the region's backend from a CloudFormation
        resource and return the stored model.

        :param resource_name: Used as the queue name
        :param cloudformation_json: Resource JSON holding a "Properties" dict
        :param region_name: Region whose backend receives the queue
        :return: The queue as looked up by ARN in the backend
        """
        backend = batch_backends[region_name]
        props = cloudformation_json["Properties"]

        # CloudFormation uses differently-cased keys (e.g. InstanceRole vs
        # instanceRole), so each order entry gets its keys normalised.
        env_order = []
        for entry in props["ComputeEnvironmentOrder"]:
            env_order.append(lowercase_first_key(entry))

        # create_job_queue returns (name, arn)
        _, queue_arn = backend.create_job_queue(
            queue_name=resource_name,
            priority=props["Priority"],
            state=props.get("State", "ENABLED"),
            compute_env_order=env_order,
        )
        return backend.get_job_queue_by_arn(queue_arn)
class JobDefinition(BaseModel):
    """
    Model of one revision of a Batch job definition.
    """

    def __init__(
        self,
        name,
        parameters,
        _type,
        container_properties,
        region_name,
        revision=0,
        retry_strategy=0,
    ):
        """
        :param name: Job definition name
        :param parameters: Parameter map; None is treated as an empty dict
        :param _type: Definition type, only "container" passes validation
        :param container_properties: containerProperties dictionary
        :param region_name: Region name used when building the ARN
        :param revision: Revision of the previous definition; it is
            incremented by one before the ARN is built
        :param retry_strategy: Number of retry attempts
        :raises ClientException: if validation fails
        """
        self.name = name
        self.retries = retry_strategy
        self.type = _type
        self.revision = revision
        self._region = region_name
        self.container_properties = container_properties
        self.arn = None
        self.status = "INACTIVE"

        if parameters is None:
            parameters = {}
        self.parameters = parameters

        self._validate()
        self._update_arn()

    def _update_arn(self):
        """
        Bump the revision and rebuild the ARN from the new revision.
        """
        self.revision += 1
        self.arn = make_arn_for_task_def(
            DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region
        )

    def _validate(self):
        """
        Check type, parameters and the required containerProperties entries.

        :raises ClientException: on the first failed check
        """
        if self.type not in ("container",):
            raise ClientException('type must be one of "container"')

        if not isinstance(self.parameters, dict):
            raise ClientException("parameters must be a string to string map")

        # A missing containerProperties is reported like a missing image
        # instead of failing with a TypeError on the membership test.
        container = self.container_properties
        if container is None:
            container = {}

        if "image" not in container:
            raise ClientException("containerProperties must contain image")

        if "memory" not in container:
            raise ClientException("containerProperties must contain memory")
        if container["memory"] < 4:
            raise ClientException("container memory limit must be greater than 4")

        if "vcpus" not in container:
            raise ClientException("containerProperties must contain vcpus")
        if container["vcpus"] < 1:
            raise ClientException("container vcpus limit must be greater than 0")

    def update(self, parameters, _type, container_properties, retry_strategy):
        """
        Build the next revision of this definition.

        Any argument passed as None keeps the value from this revision.

        :return: A new JobDefinition whose revision is this one plus one
        :rtype: JobDefinition
        """
        if parameters is None:
            parameters = self.parameters

        if _type is None:
            _type = self.type

        if container_properties is None:
            container_properties = self.container_properties

        if retry_strategy is None:
            retry_strategy = self.retries

        return JobDefinition(
            self.name,
            parameters,
            _type,
            container_properties,
            region_name=self._region,
            revision=self.revision,
            retry_strategy=retry_strategy,
        )

    def describe(self):
        """
        Build the dictionary used to describe this definition.

        containerProperties is included when set; retryStrategy only when
        the retry count is greater than zero.

        :rtype: dict
        """
        result = {
            "jobDefinitionArn": self.arn,
            "jobDefinitionName": self.name,
            "parameters": self.parameters,
            "revision": self.revision,
            "status": self.status,
            "type": self.type,
        }
        if self.container_properties is not None:
            result["containerProperties"] = self.container_properties
        if self.retries is not None and self.retries > 0:
            result["retryStrategy"] = {"attempts": self.retries}

        return result

    @property
    def physical_resource_id(self):
        """
        The ARN doubles as the physical resource id.
        """
        return self.arn

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name
    ):
        """
        Register a job definition in the region's backend from a
        CloudFormation resource and return the stored model.

        :param resource_name: Used as the job definition name
        :param cloudformation_json: Resource JSON holding a "Properties" dict
        :param region_name: Region whose backend receives the definition
        :return: The definition as looked up by ARN in the backend
        """
        backend = batch_backends[region_name]
        properties = cloudformation_json["Properties"]

        # RetryStrategy may be left out of the template; register_job_definition
        # accepts None for it, so only normalise the keys when it is present.
        retry_strategy = properties.get("RetryStrategy")
        if retry_strategy is not None:
            retry_strategy = lowercase_first_key(retry_strategy)

        res = backend.register_job_definition(
            def_name=resource_name,
            parameters=lowercase_first_key(properties.get("Parameters", {})),
            _type="container",
            retry_strategy=retry_strategy,
            container_properties=lowercase_first_key(properties["ContainerProperties"]),
        )
        # register_job_definition returns (name, arn, revision)
        arn = res[1]

        return backend.get_job_definition_by_arn(arn)
class Job(threading.Thread, BaseModel):
    """
    A submitted Batch job, run as a docker container on its own thread.
    """

    def __init__(self, name, job_def, job_queue, log_backend):
        """
        Docker Job

        :param name: Job Name
        :param job_def: Job definition
        :type: job_def: JobDefinition
        :param job_queue: Job Queue
        :param log_backend: Log backend
        :type log_backend: moto.logs.models.LogsBackend
        """
        threading.Thread.__init__(self)

        self.job_name = name
        self.job_id = str(uuid.uuid4())
        self.job_definition = job_def
        self.job_queue = job_queue
        self.job_state = "SUBMITTED"  # One of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED
        self.job_queue.jobs.append(self)
        self.job_started_at = datetime.datetime(1970, 1, 1)
        self.job_stopped_at = datetime.datetime(1970, 1, 1)
        self.job_stopped = False
        self.job_stopped_reason = None
        # Only assigned once logs are about to be published; describe() must
        # cope with jobs that stopped before getting that far.
        self.log_stream_name = None

        self.stop = False

        self.daemon = True
        self.name = "MOTO-BATCH-" + self.job_id

        self.docker_client = docker.from_env()
        self._log_backend = log_backend

        # Unfortunately mocking replaces this method w/o fallback enabled, so we
        # need to replace it if we detect it's been mocked
        if requests.adapters.HTTPAdapter.send != _orig_adapter_send:
            _orig_get_adapter = self.docker_client.api.get_adapter

            def replace_adapter_send(*args, **kwargs):
                adapter = _orig_get_adapter(*args, **kwargs)

                if isinstance(adapter, requests.adapters.HTTPAdapter):
                    adapter.send = functools.partial(_orig_adapter_send, adapter)
                return adapter

            self.docker_client.api.get_adapter = replace_adapter_send

    def describe(self):
        """
        Build the dictionary used to describe this job.

        The container section and stoppedAt are only present once the job
        has stopped; logStreamName is only present once a stream name exists.

        :rtype: dict
        """
        result = {
            "jobDefinition": self.job_definition.arn,
            "jobId": self.job_id,
            "jobName": self.job_name,
            "jobQueue": self.job_queue.arn,
            "startedAt": datetime2int(self.job_started_at),
            "status": self.job_state,
            "dependsOn": [],
        }
        if self.job_stopped:
            result["stoppedAt"] = datetime2int(self.job_stopped_at)
            container = {
                "command": [
                    '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"'
                ],
                "privileged": False,
                "readonlyRootFilesystem": False,
                "ulimits": {},
                "vcpus": 1,
                "volumes": "",
            }
            # A job that failed before publishing logs has no stream name.
            if self.log_stream_name is not None:
                container["logStreamName"] = self.log_stream_name
            result["container"] = container
        if self.job_stopped_reason is not None:
            result["statusReason"] = self.job_stopped_reason
        return result

    @staticmethod
    def _fetch_log_lines(container, since, stdout, stderr):
        """
        Read one stream of timestamped container logs since the given time
        and split it into lines.
        """
        return (
            container.logs(
                stdout=stdout,
                stderr=stderr,
                timestamps=True,
                since=datetime2int(since),
            )
            .decode()
            .split("\n")
        )

    def _collect_logs(self, container, since, logs_stdout, logs_stderr):
        """
        Append the stderr and stdout lines produced since the given time.
        """
        logs_stderr.extend(
            self._fetch_log_lines(container, since, stdout=False, stderr=True)
        )
        logs_stdout.extend(
            self._fetch_log_lines(container, since, stdout=True, stderr=False)
        )

    def _publish_logs(self, raw_lines):
        """
        Convert "<timestamp> <message>" lines into log events and put them
        into the log backend under /aws/batch/job.
        """
        logs = []
        for raw in raw_lines:
            if len(raw) == 0:
                continue
            date, line = raw.split(" ", 1)
            date = dateutil.parser.parse(date)
            date = int(date.timestamp())
            logs.append({"timestamp": date, "message": line.strip()})

        log_group = "/aws/batch/job"
        stream_name = "{0}/default/{1}".format(self.job_definition.name, self.job_id)
        self.log_stream_name = stream_name
        self._log_backend.ensure_log_group(log_group, None)
        self._log_backend.create_log_stream(log_group, stream_name)
        self._log_backend.put_log_events(log_group, stream_name, logs, None)

    def run(self):
        """
        Run the container.

        Logic is as follows:
        Generate container info (eventually from task definition)
        Start container
        Loop whilst not asked to stop and the container is running.
          Get all logs from container between the last time I checked and now.
        Convert logs into cloudwatch format
        Put logs into cloudwatch

        :return:
        """
        try:
            self.job_state = "PENDING"
            time.sleep(1)

            image = "alpine:latest"
            cmd = '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"'
            name = "{0}-{1}".format(self.job_name, self.job_id)

            self.job_state = "RUNNABLE"
            # TODO setup ecs container instance
            time.sleep(1)

            self.job_state = "STARTING"
            container = self.docker_client.containers.run(
                image, cmd, detach=True, name=name
            )
            self.job_state = "RUNNING"
            self.job_started_at = datetime.datetime.now()
            try:
                # Log collection
                logs_stdout = []
                logs_stderr = []
                container.reload()

                # Dodgy hack, we can only check docker logs once a second, but we want to loop more
                # so we can stop if asked to in a quick manner, should all go away if we go async
                # There also be some dodgyness when sending an integer to docker logs and some
                # events seem to be duplicated.
                now = datetime.datetime.now()
                i = 1
                while container.status == "running" and not self.stop:
                    time.sleep(0.15)
                    if i % 10 == 0:
                        self._collect_logs(container, now, logs_stdout, logs_stderr)
                        now = datetime.datetime.now()
                        container.reload()
                    i += 1

                # Container should be stopped by this point... unless asked to stop
                if container.status == "running":
                    container.kill()

                self.job_stopped_at = datetime.datetime.now()
                # Get final logs
                self._collect_logs(container, now, logs_stdout, logs_stderr)

                self.job_state = "SUCCEEDED" if not self.stop else "FAILED"

                # Process logs and send to cloudwatch
                self._publish_logs(logs_stdout + logs_stderr)

            except Exception as err:
                logger.error(
                    "Failed to run AWS Batch container %s. Error %s", self.name, err
                )
                self.job_state = "FAILED"
                container.kill()
            finally:
                container.remove()
        except Exception as err:
            logger.error(
                "Failed to run AWS Batch container %s. Error %s", self.name, err
            )
            self.job_state = "FAILED"

        self.job_stopped = True
        self.job_stopped_at = datetime.datetime.now()

    def terminate(self, reason):
        """
        Ask the running thread to stop; the first reason given is kept.

        :param reason: Recorded as the job's stopped reason
        """
        if not self.stop:
            self.stop = True
            self.job_stopped_reason = reason
class BatchBackend(BaseBackend):
    """
    In-memory AWS Batch backend for a single region.
    """

    def __init__(self, region_name=None):
        super(BatchBackend, self).__init__()
        self.region_name = region_name

        # Compute environments, job queues and job definitions are keyed by
        # ARN; jobs are keyed by job id.
        self._compute_environments = {}
        self._job_queues = {}
        self._job_definitions = {}
        self._jobs = {}

    @property
    def iam_backend(self):
        """
        :return: IAM Backend
        :rtype: moto.iam.models.IAMBackend
        """
        return iam_backends["global"]

    @property
    def ec2_backend(self):
        """
        :return: EC2 Backend
        :rtype: moto.ec2.models.EC2Backend
        """
        return ec2_backends[self.region_name]

    @property
    def ecs_backend(self):
        """
        :return: ECS Backend
        :rtype: moto.ecs.models.EC2ContainerServiceBackend
        """
        return ecs_backends[self.region_name]

    @property
    def logs_backend(self):
        """
        :return: Logs Backend
        :rtype: moto.logs.models.LogsBackend
        """
        return logs_backends[self.region_name]

    def reset(self):
        """
        Ask unfinished jobs to stop, then reinitialise the backend for the
        same region.
        """
        region_name = self.region_name

        for job in self._jobs.values():
            if job.job_state not in ("FAILED", "SUCCEEDED"):
                job.stop = True
                # Try to join
                job.join(0.2)

        self.__dict__ = {}
        self.__init__(region_name)

    def get_compute_environment_by_arn(self, arn):
        """
        :return: Compute environment stored under the ARN, or None
        """
        return self._compute_environments.get(arn)

    def get_compute_environment_by_name(self, name):
        """
        :return: First compute environment with the given name, or None
        """
        for comp_env in self._compute_environments.values():
            if comp_env.name == name:
                return comp_env
        return None

    def get_compute_environment(self, identifier):
        """
        Get compute environment by name or ARN

        :param identifier: Name or ARN
        :type identifier: str
        :return: Compute Environment or None
        :rtype: ComputeEnvironment or None
        """
        env = self.get_compute_environment_by_arn(identifier)
        if env is None:
            env = self.get_compute_environment_by_name(identifier)
        return env

    def get_job_queue_by_arn(self, arn):
        """
        :return: Job queue stored under the ARN, or None
        """
        return self._job_queues.get(arn)

    def get_job_queue_by_name(self, name):
        """
        :return: First job queue with the given name, or None
        """
        for queue in self._job_queues.values():
            if queue.name == name:
                return queue
        return None

    def get_job_queue(self, identifier):
        """
        Get job queue by name or ARN

        :param identifier: Name or ARN
        :type identifier: str
        :return: Job Queue or None
        :rtype: JobQueue or None
        """
        queue = self.get_job_queue_by_arn(identifier)
        if queue is None:
            queue = self.get_job_queue_by_name(identifier)
        return queue

    def get_job_definition_by_arn(self, arn):
        """
        :return: Job definition stored under the ARN, or None
        """
        return self._job_definitions.get(arn)

    def get_job_definition_by_name(self, name):
        """
        :return: The highest revision registered under the name, or None
        """
        latest_revision = -1
        latest_job = None
        for job_def in self._job_definitions.values():
            if job_def.name == name and job_def.revision > latest_revision:
                latest_job = job_def
                latest_revision = job_def.revision
        return latest_job

    def get_job_definition_by_name_revision(self, name, revision):
        """
        Get a specific revision of a job definition.

        :param name: Job definition name
        :param revision: Revision as an int or a numeric string, e.g. the
            part after the colon in "name:3"
        :return: Job definition, or None if there is no match or the
            revision is not a number
        """
        # Stored revisions are ints, while callers pass the text that
        # followed the colon; normalise before comparing.
        try:
            revision = int(revision)
        except (TypeError, ValueError):
            return None

        for job_def in self._job_definitions.values():
            if job_def.name == name and job_def.revision == revision:
                return job_def
        return None

    def get_job_definition(self, identifier):
        """
        Get job definitions by name or ARN

        :param identifier: Name, name:revision or ARN
        :type identifier: str
        :return: Job definition or None
        :rtype: JobDefinition or None
        """
        job_def = self.get_job_definition_by_arn(identifier)
        if job_def is None:
            if ":" in identifier:
                job_def = self.get_job_definition_by_name_revision(
                    *identifier.split(":", 1)
                )
            else:
                job_def = self.get_job_definition_by_name(identifier)
        return job_def

    def get_job_definitions(self, identifier):
        """
        Get job definitions by name or ARN

        :param identifier: Name or ARN
        :type identifier: str
        :return: The definition stored under the ARN, otherwise every
            revision with that name; empty if nothing matches
        :rtype: list of JobDefinition
        """
        by_arn = self.get_job_definition_by_arn(identifier)
        if by_arn is not None:
            return [by_arn]

        return [
            job_def
            for job_def in self._job_definitions.values()
            if job_def.name == identifier
        ]

    def get_job_by_id(self, identifier):
        """
        Get job by id

        :param identifier: Job ID
        :type identifier: str
        :return: Job, or None if the id is unknown
        :rtype: Job
        """
        return self._jobs.get(identifier)

    def describe_compute_environments(
        self, environments=None, max_results=None, next_token=None
    ):
        """
        Describe compute environments, optionally filtered by name or ARN.

        max_results and next_token are accepted but not used.

        :param environments: Names and/or ARNs to include; None or empty
            means all
        :rtype: list of dict
        """
        envs = set()
        if environments is not None:
            envs = set(environments)

        result = []
        for arn, environment in self._compute_environments.items():
            # Filter shortcut
            if len(envs) > 0 and arn not in envs and environment.name not in envs:
                continue

            json_part = {
                "computeEnvironmentArn": arn,
                "computeEnvironmentName": environment.name,
                "ecsClusterArn": environment.ecs_arn,
                "serviceRole": environment.service_role,
                "state": environment.state,
                "type": environment.env_type,
                "status": "VALID",
            }
            if environment.env_type == "MANAGED":
                json_part["computeResources"] = environment.compute_resources

            result.append(json_part)

        return result

    def create_compute_environment(
        self, compute_environment_name, _type, state, compute_resources, service_role
    ):
        """
        Validate and store a compute environment, start instances for MANAGED
        environments and create an ECS cluster for it.

        :return: Tuple of Name, ARN
        :rtype: tuple of str
        :raises InvalidParameterValueException: if validation fails
        """
        # Validate
        if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None:
            raise InvalidParameterValueException(
                "Compute environment name does not match ^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$"
            )

        if self.get_compute_environment_by_name(compute_environment_name) is not None:
            raise InvalidParameterValueException(
                "A compute environment already exists with the name {0}".format(
                    compute_environment_name
                )
            )

        # Look for IAM role
        try:
            self.iam_backend.get_role_by_arn(service_role)
        except IAMNotFoundException:
            raise InvalidParameterValueException(
                "Could not find IAM role {0}".format(service_role)
            )

        if _type not in ("MANAGED", "UNMANAGED"):
            raise InvalidParameterValueException(
                "type {0} must be one of MANAGED | UNMANAGED".format(_type)
            )

        if state is not None and state not in ("ENABLED", "DISABLED"):
            raise InvalidParameterValueException(
                "state {0} must be one of ENABLED | DISABLED".format(state)
            )

        if compute_resources is None and _type == "MANAGED":
            raise InvalidParameterValueException(
                "computeResources must be specified when creating a MANAGED environment"
            )
        elif compute_resources is not None:
            self._validate_compute_resources(compute_resources)

        # By here, all values except SPOT ones have been validated
        new_comp_env = ComputeEnvironment(
            compute_environment_name,
            _type,
            state,
            compute_resources,
            service_role,
            region_name=self.region_name,
        )
        self._compute_environments[new_comp_env.arn] = new_comp_env

        # Ok by this point, everything is legit, so if its Managed then start some instances
        if _type == "MANAGED":
            cpus = int(
                compute_resources.get("desiredvCpus", compute_resources["minvCpus"])
            )
            instance_types = compute_resources["instanceTypes"]
            needed_instance_types = self.find_min_instances_to_meet_vcpus(
                instance_types, cpus
            )
            # Create instances

            # Will loop over and over so we get decent subnet coverage
            subnet_cycle = cycle(compute_resources["subnets"])

            for instance_type in needed_instance_types:
                reservation = self.ec2_backend.add_instances(
                    image_id="ami-ecs-optimised",  # Todo import AMIs
                    count=1,
                    user_data=None,
                    security_group_names=[],
                    instance_type=instance_type,
                    region_name=self.region_name,
                    subnet_id=six.next(subnet_cycle),
                    key_name=compute_resources.get("ec2KeyPair", "AWS_OWNED"),
                    security_group_ids=compute_resources["securityGroupIds"],
                )

                new_comp_env.add_instance(reservation.instances[0])

        # Create ECS cluster
        # Should be of format P2OnDemand_Batch_UUID
        cluster_name = "OnDemand_Batch_" + str(uuid.uuid4())
        ecs_cluster = self.ecs_backend.create_cluster(cluster_name)
        new_comp_env.set_ecs(ecs_cluster.arn, cluster_name)

        return compute_environment_name, new_comp_env.arn

    def _validate_compute_resources(self, cr):
        """
        Checks contents of sub dictionary for managed clusters

        :param cr: computeResources
        :type cr: dict
        :raises InvalidParameterValueException: on the first failed check
        :raises InternalFailure: for SPOT resources, which are not supported
        """
        for param in (
            "instanceRole",
            "maxvCpus",
            "minvCpus",
            "instanceTypes",
            "securityGroupIds",
            "subnets",
            "type",
        ):
            if param not in cr:
                raise InvalidParameterValueException(
                    "computeResources must contain {0}".format(param)
                )

        # The role lookup signals a missing role by raising (see the service
        # role check in create_compute_environment); report both that and a
        # None result as an invalid parameter.
        try:
            instance_role = self.iam_backend.get_role_by_arn(cr["instanceRole"])
        except IAMNotFoundException:
            instance_role = None
        if instance_role is None:
            raise InvalidParameterValueException(
                "could not find instanceRole {0}".format(cr["instanceRole"])
            )

        if cr["maxvCpus"] < 0:
            raise InvalidParameterValueException("maxVCpus must be positive")
        if cr["minvCpus"] < 0:
            raise InvalidParameterValueException("minVCpus must be positive")
        if cr["maxvCpus"] < cr["minvCpus"]:
            raise InvalidParameterValueException(
                "maxVCpus must be greater than minvCpus"
            )

        if len(cr["instanceTypes"]) == 0:
            raise InvalidParameterValueException(
                "At least 1 instance type must be provided"
            )
        for instance_type in cr["instanceTypes"]:
            if instance_type == "optimal":
                pass  # Optimal should pick from latest of current gen
            elif instance_type not in EC2_INSTANCE_TYPES:
                raise InvalidParameterValueException(
                    "Instance type {0} does not exist".format(instance_type)
                )

        for sec_id in cr["securityGroupIds"]:
            if self.ec2_backend.get_security_group_from_id(sec_id) is None:
                raise InvalidParameterValueException(
                    "security group {0} does not exist".format(sec_id)
                )
        if len(cr["securityGroupIds"]) == 0:
            raise InvalidParameterValueException(
                "At least 1 security group must be provided"
            )

        for subnet_id in cr["subnets"]:
            try:
                self.ec2_backend.get_subnet(subnet_id)
            except InvalidSubnetIdError:
                raise InvalidParameterValueException(
                    "subnet {0} does not exist".format(subnet_id)
                )
        if len(cr["subnets"]) == 0:
            raise InvalidParameterValueException("At least 1 subnet must be provided")

        if cr["type"] not in ("EC2", "SPOT"):
            raise InvalidParameterValueException(
                "computeResources.type must be either EC2 | SPOT"
            )

        if cr["type"] == "SPOT":
            raise InternalFailure("SPOT NOT SUPPORTED YET")

    @staticmethod
    def find_min_instances_to_meet_vcpus(instance_types, target):
        """
        Finds the minimum needed instances to meet a vcpu target

        Works from the largest instance type down: a type is used while it
        fits in the remaining target, otherwise the next smaller type is
        tried; the smallest type is added until the target is covered.

        :param instance_types: Instance types, like ['t2.medium', 't2.small']
        :type instance_types: list of str
        :param target: VCPU target
        :type target: float
        :return: List of instance types
        :rtype: list of str
        """
        # vcpus = [ (vcpus, instance_type), (vcpus, instance_type), ... ]
        instance_vcpus = []
        instances = []

        for instance_type in instance_types:
            if instance_type == "optimal":
                instance_type = "m4.4xlarge"

            instance_vcpus.append(
                (EC2_INSTANCE_TYPES[instance_type]["vcpus"], instance_type)
            )

        instance_vcpus = sorted(instance_vcpus, key=lambda item: item[0], reverse=True)

        # Nothing to choose from, so nothing can be started.
        if not instance_vcpus:
            return instances

        # Walk down the sorted list with an index rather than popping the head.
        position = 0
        last = len(instance_vcpus) - 1
        while target > 0:
            current_vcpu, current_instance = instance_vcpus[position]

            if position < last and current_vcpu > target:
                # Too big and there is a smaller type left: try that one
                position += 1
            else:
                # Either it fits, or it is the smallest type available
                target -= current_vcpu
                instances.append(current_instance)

        return instances

    def delete_compute_environment(self, compute_environment_name):
        """
        Remove a compute environment, its ECS cluster and, for MANAGED
        environments, its instances. Unknown environments are ignored.

        :param compute_environment_name: Name or ARN
        :raises InvalidParameterValueException: if no identifier is given
        """
        if compute_environment_name is None:
            raise InvalidParameterValueException("Missing computeEnvironment parameter")

        compute_env = self.get_compute_environment(compute_environment_name)

        if compute_env is not None:
            # Pop ComputeEnvironment
            self._compute_environments.pop(compute_env.arn)

            # Delete ECS cluster
            self.ecs_backend.delete_cluster(compute_env.ecs_name)

            if compute_env.env_type == "MANAGED":
                # Delete compute environment
                instance_ids = [instance.id for instance in compute_env.instances]
                self.ec2_backend.terminate_instances(instance_ids)

    def update_compute_environment(
        self, compute_environment_name, state, compute_resources, service_role
    ):
        """
        Update the service role and/or state of a compute environment.

        All supplied values are validated before any of them is applied.
        compute_resources is accepted but not acted upon yet.

        :param compute_environment_name: Name or ARN
        :return: Tuple of Name, ARN
        :rtype: tuple of str
        :raises ClientException: if the environment does not exist
        :raises InvalidParameterValueException: for an unknown role or state
        """
        # Validate
        compute_env = self.get_compute_environment(compute_environment_name)
        if compute_env is None:
            raise ClientException(
                "Compute environment {0} does not exist".format(
                    compute_environment_name
                )
            )

        # Look for IAM role
        if service_role is not None:
            try:
                self.iam_backend.get_role_by_arn(service_role)
            except IAMNotFoundException:
                raise InvalidParameterValueException(
                    "Could not find IAM role {0}".format(service_role)
                )

        if state is not None and state not in ("ENABLED", "DISABLED"):
            raise InvalidParameterValueException(
                "state {0} must be one of ENABLED | DISABLED".format(state)
            )

        # Apply. The environment keeps the role ARN, the same value the
        # constructor stores and describe_compute_environments returns.
        if service_role is not None:
            compute_env.service_role = service_role
        if state is not None:
            compute_env.state = state

        if compute_resources is not None:
            # TODO Implement resizing of instances based on changing vCpus
            # compute_resources CAN contain desiredvCpus, maxvCpus, minvCpus, and can contain none of them.
            pass

        return compute_env.name, compute_env.arn

    def _resolve_compute_env_order(self, compute_env_order):
        """
        Turn a computeEnvironmentOrder list into compute environment objects
        sorted by their "order" value.

        :param compute_env_order: Compute environment list
        :type compute_env_order: list of dict
        :rtype: list of ComputeEnvironment
        :raises ClientException: if an entry is malformed or names an
            environment ARN that does not exist
        """
        try:
            # orders and extracts computeEnvironment names
            ordered_compute_environments = [
                item["computeEnvironment"]
                for item in sorted(compute_env_order, key=lambda x: x["order"])
            ]
            env_objects = []
            # Check each ARN exists, then make a list of compute env's
            for arn in ordered_compute_environments:
                env = self.get_compute_environment_by_arn(arn)
                if env is None:
                    raise ClientException(
                        "Compute environment {0} does not exist".format(arn)
                    )
                env_objects.append(env)
        except ClientException:
            # Keep the specific message rather than reporting it as malformed
            raise
        except Exception:
            raise ClientException("computeEnvironmentOrder is malformed")

        return env_objects

    def create_job_queue(self, queue_name, priority, state, compute_env_order):
        """
        Create a job queue

        :param queue_name: Queue name
        :type queue_name: str
        :param priority: Queue priority
        :type priority: int
        :param state: Queue state
        :type state: string
        :param compute_env_order: Compute environment list
        :type compute_env_order: list of dict
        :return: Tuple of Name, ARN
        :rtype: tuple of str
        :raises ClientException: if validation fails
        """
        for variable, var_name in (
            (queue_name, "jobQueueName"),
            (priority, "priority"),
            (state, "state"),
            (compute_env_order, "computeEnvironmentOrder"),
        ):
            if variable is None:
                raise ClientException("{0} must be provided".format(var_name))

        if state not in ("ENABLED", "DISABLED"):
            raise ClientException(
                "state {0} must be one of ENABLED | DISABLED".format(state)
            )
        if self.get_job_queue_by_name(queue_name) is not None:
            raise ClientException("Job queue {0} already exists".format(queue_name))

        if len(compute_env_order) == 0:
            raise ClientException("At least 1 compute environment must be provided")
        env_objects = self._resolve_compute_env_order(compute_env_order)

        # Create new Job Queue
        queue = JobQueue(
            queue_name,
            priority,
            state,
            env_objects,
            compute_env_order,
            self.region_name,
        )
        self._job_queues[queue.arn] = queue

        return queue_name, queue.arn

    def describe_job_queues(self, job_queues=None, max_results=None, next_token=None):
        """
        Describe job queues, optionally filtered by name or ARN.

        max_results and next_token are accepted but not used.

        :param job_queues: Names and/or ARNs to include; None or empty
            means all
        :rtype: list of dict
        """
        wanted = set()
        if job_queues is not None:
            wanted = set(job_queues)

        result = []
        for arn, job_queue in self._job_queues.items():
            # Filter shortcut
            if len(wanted) > 0 and arn not in wanted and job_queue.name not in wanted:
                continue

            result.append(job_queue.describe())

        return result

    def update_job_queue(self, queue_name, priority, state, compute_env_order):
        """
        Update a job queue

        All supplied values are validated before any of them is applied.

        :param queue_name: Queue name or ARN
        :type queue_name: str
        :param priority: Queue priority
        :type priority: int
        :param state: Queue state
        :type state: string
        :param compute_env_order: Compute environment list
        :type compute_env_order: list of dict
        :return: Tuple of Name, ARN
        :rtype: tuple of str
        :raises ClientException: if validation fails
        """
        if queue_name is None:
            raise ClientException("jobQueueName must be provided")

        job_queue = self.get_job_queue(queue_name)
        if job_queue is None:
            raise ClientException("Job queue {0} does not exist".format(queue_name))

        if state is not None and state not in ("ENABLED", "DISABLED"):
            raise ClientException(
                "state {0} must be one of ENABLED | DISABLED".format(state)
            )

        env_objects = None
        if compute_env_order is not None:
            if len(compute_env_order) == 0:
                raise ClientException("At least 1 compute environment must be provided")
            env_objects = self._resolve_compute_env_order(compute_env_order)

        # Everything checked out, apply the changes
        if state is not None:
            job_queue.state = state

        if compute_env_order is not None:
            job_queue.env_order_json = compute_env_order
            job_queue.environments = env_objects

        if priority is not None:
            job_queue.priority = priority

        # Return the stored name even when the queue was addressed by ARN
        return job_queue.name, job_queue.arn

    def delete_job_queue(self, queue_name):
        """
        Remove a job queue; unknown queues are ignored.

        :param queue_name: Queue name or ARN
        """
        job_queue = self.get_job_queue(queue_name)

        if job_queue is not None:
            del self._job_queues[job_queue.arn]

    def register_job_definition(
        self, def_name, parameters, _type, retry_strategy, container_properties
    ):
        """
        Register a job definition, or the next revision of an existing one.

        :param def_name: Job definition name
        :param parameters: Parameter map
        :param _type: Definition type
        :param retry_strategy: None, or a dict with an "attempts" entry
        :param container_properties: containerProperties dictionary
        :return: Tuple of Name, ARN, Revision
        :raises ClientException: if validation fails
        """
        if def_name is None:
            raise ClientException("jobDefinitionName must be provided")

        job_def = self.get_job_definition_by_name(def_name)
        if retry_strategy is not None:
            try:
                retry_strategy = retry_strategy["attempts"]
            except (KeyError, TypeError):
                raise ClientException("retryStrategy is malformed")

        if job_def is None:
            job_def = JobDefinition(
                def_name,
                parameters,
                _type,
                container_properties,
                region_name=self.region_name,
                retry_strategy=retry_strategy,
            )
        else:
            # Make new jobdef
            job_def = job_def.update(
                parameters, _type, container_properties, retry_strategy
            )

        self._job_definitions[job_def.arn] = job_def

        return def_name, job_def.arn, job_def.revision

    def deregister_job_definition(self, def_name):
        """
        Remove a job definition given by ARN or name:revision; anything that
        does not resolve to a definition is ignored.

        :param def_name: ARN or name:revision
        """
        job_def = self.get_job_definition_by_arn(def_name)
        if job_def is None and ":" in def_name:
            name, revision = def_name.split(":", 1)
            job_def = self.get_job_definition_by_name_revision(name, revision)

        if job_def is not None:
            del self._job_definitions[job_def.arn]

    def describe_job_definitions(
        self,
        job_def_name=None,
        job_def_list=None,
        status=None,
        max_results=None,
        next_token=None,
    ):
        """
        List job definitions by name, by a list of identifiers, or all of
        them, optionally filtered by status.

        max_results and next_token are accepted but not used.

        :rtype: list of JobDefinition
        """
        # As a job name can reference multiple revisions, we get a list of them
        if job_def_name is not None:
            jobs = self.get_job_definitions(job_def_name)
        elif job_def_list is not None:
            jobs = []
            for identifier in job_def_list:
                jobs.extend(self.get_job_definitions(identifier))
        else:
            jobs = list(self._job_definitions.values())

        # Got all the job defs were after, filter then by status
        if status is not None:
            return [job for job in jobs if job.status == status]
        return jobs

    def submit_job(
        self,
        job_name,
        job_def_id,
        job_queue,
        parameters=None,
        retries=None,
        depends_on=None,
        container_overrides=None,
    ):
        """
        Create a job for the given definition and queue and start its thread.

        :param job_name: Job name
        :param job_def_id: Job definition name, name:revision or ARN
        :param job_queue: Job queue name or ARN
        :return: Tuple of Name, Job ID
        :raises ClientException: if the definition or queue does not exist
        """
        # TODO parameters, retries (which is a dict raw from request), job dependencies and container overrides are ignored for now

        # Look for job definition
        job_def = self.get_job_definition(job_def_id)
        if job_def is None:
            raise ClientException(
                "Job definition {0} does not exist".format(job_def_id)
            )

        queue = self.get_job_queue(job_queue)
        if queue is None:
            raise ClientException("Job queue {0} does not exist".format(job_queue))

        job = Job(job_name, job_def, queue, log_backend=self.logs_backend)
        self._jobs[job.job_id] = job

        # Here comes the fun
        job.start()

        return job_name, job.job_id

    def describe_jobs(self, jobs):
        """
        Describe jobs, optionally filtered by job id.

        :param jobs: Job ids to include; None or empty means all
        :rtype: list of dict
        """
        job_filter = set()
        if jobs is not None:
            job_filter = set(jobs)

        result = []
        for key, job in self._jobs.items():
            if len(job_filter) > 0 and key not in job_filter:
                continue

            result.append(job.describe())

        return result

    def list_jobs(self, job_queue, job_status=None, max_results=None, next_token=None):
        """
        List the jobs of a queue, optionally filtered by status.

        max_results and next_token are accepted but not used.

        :param job_queue: Job queue name or ARN
        :param job_status: Status to filter on, or None for all
        :rtype: list of Job
        :raises ClientException: for an unknown queue or status
        """
        # Keep the caller's identifier intact so the error can report it
        queue = self.get_job_queue(job_queue)
        if queue is None:
            raise ClientException("Job queue {0} does not exist".format(job_queue))

        if job_status is not None and job_status not in (
            "SUBMITTED",
            "PENDING",
            "RUNNABLE",
            "STARTING",
            "RUNNING",
            "SUCCEEDED",
            "FAILED",
        ):
            raise ClientException(
                "Job status is not one of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED"
            )

        return [
            job
            for job in queue.jobs
            if job_status is None or job.job_state == job_status
        ]

    def terminate_job(self, job_id, reason):
        """
        Ask a job to stop, recording the reason.

        :param job_id: Job ID
        :param reason: Reason recorded on the job
        :raises ClientException: if an argument is missing or the job is
            unknown
        """
        if job_id is None:
            raise ClientException("Job ID does not exist")
        if reason is None:
            raise ClientException("Reason does not exist")

        job = self.get_job_by_id(job_id)
        if job is None:
            raise ClientException("Job not found")

        job.terminate(reason)
# Region names boto3 reports for the "batch" service.
available_regions = boto3.session.Session().get_available_regions("batch")
# One BatchBackend per region, keyed by region name; the cloudformation
# constructors above look their backend up here.
batch_backends = {
region: BatchBackend(region_name=region) for region in available_regions
}