Batch: add SchedulingPolicy methods (#5877)

This commit is contained in:
Bert Blommers 2023-01-26 14:06:50 -01:00 committed by GitHub
parent 8dcf2d33ed
commit 2f8a356b3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 243 additions and 17 deletions

View File

@ -491,23 +491,23 @@
## batch ## batch
<details> <details>
<summary>79% implemented</summary> <summary>100% implemented</summary>
- [X] cancel_job - [X] cancel_job
- [X] create_compute_environment - [X] create_compute_environment
- [X] create_job_queue - [X] create_job_queue
- [ ] create_scheduling_policy - [X] create_scheduling_policy
- [X] delete_compute_environment - [X] delete_compute_environment
- [X] delete_job_queue - [X] delete_job_queue
- [ ] delete_scheduling_policy - [X] delete_scheduling_policy
- [X] deregister_job_definition - [X] deregister_job_definition
- [X] describe_compute_environments - [X] describe_compute_environments
- [X] describe_job_definitions - [X] describe_job_definitions
- [X] describe_job_queues - [X] describe_job_queues
- [X] describe_jobs - [X] describe_jobs
- [ ] describe_scheduling_policies - [X] describe_scheduling_policies
- [X] list_jobs - [X] list_jobs
- [ ] list_scheduling_policies - [X] list_scheduling_policies
- [X] list_tags_for_resource - [X] list_tags_for_resource
- [X] register_job_definition - [X] register_job_definition
- [X] submit_job - [X] submit_job
@ -516,7 +516,7 @@
- [X] untag_resource - [X] untag_resource
- [X] update_compute_environment - [X] update_compute_environment
- [X] update_job_queue - [X] update_job_queue
- [ ] update_scheduling_policy - [X] update_scheduling_policy
</details> </details>
## budgets ## budgets

View File

@ -30,10 +30,10 @@ batch
- [X] cancel_job - [X] cancel_job
- [X] create_compute_environment - [X] create_compute_environment
- [X] create_job_queue - [X] create_job_queue
- [ ] create_scheduling_policy - [X] create_scheduling_policy
- [X] delete_compute_environment - [X] delete_compute_environment
- [X] delete_job_queue - [X] delete_job_queue
- [ ] delete_scheduling_policy - [X] delete_scheduling_policy
- [X] deregister_job_definition - [X] deregister_job_definition
- [X] describe_compute_environments - [X] describe_compute_environments
@ -51,13 +51,17 @@ batch
- [X] describe_jobs - [X] describe_jobs
- [ ] describe_scheduling_policies - [X] describe_scheduling_policies
- [X] list_jobs - [X] list_jobs
Pagination is not yet implemented Pagination is not yet implemented
- [ ] list_scheduling_policies - [X] list_scheduling_policies
Pagination is not yet implemented
- [X] list_tags_for_resource - [X] list_tags_for_resource
- [X] register_job_definition - [X] register_job_definition
- [X] submit_job - [X] submit_job
@ -70,5 +74,5 @@ batch
- [X] untag_resource - [X] untag_resource
- [X] update_compute_environment - [X] update_compute_environment
- [X] update_job_queue - [X] update_job_queue
- [ ] update_scheduling_policy - [X] update_scheduling_policy

View File

@ -132,6 +132,7 @@ class JobQueue(CloudFormationModel):
state: str, state: str,
environments: List[ComputeEnvironment], environments: List[ComputeEnvironment],
env_order_json: List[Dict[str, Any]], env_order_json: List[Dict[str, Any]],
schedule_policy: Optional[str],
backend: "BatchBackend", backend: "BatchBackend",
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
): ):
@ -152,6 +153,7 @@ class JobQueue(CloudFormationModel):
self.state = state self.state = state
self.environments = environments self.environments = environments
self.env_order_json = env_order_json self.env_order_json = env_order_json
self.schedule_policy = schedule_policy
self.arn = make_arn_for_job_queue(backend.account_id, name, backend.region_name) self.arn = make_arn_for_job_queue(backend.account_id, name, backend.region_name)
self.status = "VALID" self.status = "VALID"
self.backend = backend self.backend = backend
@ -167,6 +169,7 @@ class JobQueue(CloudFormationModel):
"jobQueueArn": self.arn, "jobQueueArn": self.arn,
"jobQueueName": self.name, "jobQueueName": self.name,
"priority": self.priority, "priority": self.priority,
"schedulingPolicyArn": self.schedule_policy,
"state": self.state, "state": self.state,
"status": self.status, "status": self.status,
"tags": self.backend.list_tags_for_resource(self.arn), "tags": self.backend.list_tags_for_resource(self.arn),
@ -209,6 +212,7 @@ class JobQueue(CloudFormationModel):
priority=properties["Priority"], priority=properties["Priority"],
state=properties.get("State", "ENABLED"), state=properties.get("State", "ENABLED"),
compute_env_order=compute_envs, compute_env_order=compute_envs,
schedule_policy={},
) )
arn = queue[1] arn = queue[1]
@ -914,6 +918,35 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
return True return True
class SchedulingPolicy(BaseModel):
def __init__(
self,
account_id: str,
region: str,
name: str,
fairshare_policy: Dict[str, Any],
backend: "BatchBackend",
tags: Dict[str, str],
):
self.name = name
self.arn = f"arn:aws:batch:{region}:{account_id}:scheduling-policy/{name}"
self.fairshare_policy = {
"computeReservation": fairshare_policy.get("computeReservation") or 0,
"shareDecaySeconds": fairshare_policy.get("shareDecaySeconds") or 0,
"shareDistribution": fairshare_policy.get("shareDistribution") or [],
}
self.backend = backend
if tags:
backend.tag_resource(self.arn, tags)
def to_dict(self, create: bool = False) -> Dict[str, Any]:
resp: Dict[str, Any] = {"name": self.name, "arn": self.arn}
if not create:
resp["fairsharePolicy"] = self.fairshare_policy
resp["tags"] = self.backend.list_tags_for_resource(self.arn)
return resp
class BatchBackend(BaseBackend): class BatchBackend(BaseBackend):
""" """
Batch-jobs are executed inside a Docker-container. Everytime the `submit_job`-method is called, a new Docker container is started. Batch-jobs are executed inside a Docker-container. Everytime the `submit_job`-method is called, a new Docker container is started.
@ -931,6 +964,7 @@ class BatchBackend(BaseBackend):
self._job_queues: Dict[str, JobQueue] = {} self._job_queues: Dict[str, JobQueue] = {}
self._job_definitions: Dict[str, JobDefinition] = {} self._job_definitions: Dict[str, JobDefinition] = {}
self._jobs: Dict[str, Job] = {} self._jobs: Dict[str, Job] = {}
self._scheduling_policies: Dict[str, SchedulingPolicy] = {}
state_manager.register_default_transition( state_manager.register_default_transition(
"batch::job", transition={"progression": "manual", "times": 1} "batch::job", transition={"progression": "manual", "times": 1}
@ -1401,6 +1435,7 @@ class BatchBackend(BaseBackend):
self, self,
queue_name: str, queue_name: str,
priority: str, priority: str,
schedule_policy: Optional[str],
state: str, state: str,
compute_env_order: List[Dict[str, str]], compute_env_order: List[Dict[str, str]],
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
@ -1444,6 +1479,7 @@ class BatchBackend(BaseBackend):
state, state,
env_objects, env_objects,
compute_env_order, compute_env_order,
schedule_policy=schedule_policy,
backend=self, backend=self,
tags=tags, tags=tags,
) )
@ -1477,6 +1513,7 @@ class BatchBackend(BaseBackend):
priority: Optional[str], priority: Optional[str],
state: Optional[str], state: Optional[str],
compute_env_order: Optional[List[Dict[str, Any]]], compute_env_order: Optional[List[Dict[str, Any]]],
schedule_policy: Optional[str],
) -> Tuple[str, str]: ) -> Tuple[str, str]:
if queue_name is None: if queue_name is None:
raise ClientException("jobQueueName must be provided") raise ClientException("jobQueueName must be provided")
@ -1519,6 +1556,8 @@ class BatchBackend(BaseBackend):
if priority is not None: if priority is not None:
job_queue.priority = priority job_queue.priority = priority
if schedule_policy is not None:
job_queue.schedule_policy = schedule_policy
return queue_name, job_queue.arn return queue_name, job_queue.arn
@ -1768,5 +1807,31 @@ class BatchBackend(BaseBackend):
def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None: def untag_resource(self, resource_arn: str, tag_keys: List[str]) -> None:
self.tagger.untag_resource_using_names(resource_arn, tag_keys) self.tagger.untag_resource_using_names(resource_arn, tag_keys)
def create_scheduling_policy(
self, name: str, fairshare_policy: Dict[str, Any], tags: Dict[str, str]
) -> SchedulingPolicy:
policy = SchedulingPolicy(
self.account_id, self.region_name, name, fairshare_policy, self, tags
)
self._scheduling_policies[policy.arn] = policy
return self._scheduling_policies[policy.arn]
def describe_scheduling_policies(self, arns: List[str]) -> List[SchedulingPolicy]:
return [pol for arn, pol in self._scheduling_policies.items() if arn in arns]
def list_scheduling_policies(self) -> List[str]:
"""
Pagination is not yet implemented
"""
return list(self._scheduling_policies.keys())
def delete_scheduling_policy(self, arn: str) -> None:
self._scheduling_policies.pop(arn, None)
def update_scheduling_policy(
self, arn: str, fairshare_policy: Dict[str, Any]
) -> None:
self._scheduling_policies[arn].fairshare_policy = fairshare_policy
batch_backends = BackendDict(BatchBackend, "batch") batch_backends = BackendDict(BatchBackend, "batch")

View File

@ -86,6 +86,7 @@ class BatchResponse(BaseResponse):
def createjobqueue(self) -> str: def createjobqueue(self) -> str:
compute_env_order = self._get_param("computeEnvironmentOrder") compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueueName") queue_name = self._get_param("jobQueueName")
schedule_policy = self._get_param("schedulingPolicyArn")
priority = self._get_param("priority") priority = self._get_param("priority")
state = self._get_param("state") state = self._get_param("state")
tags = self._get_param("tags") tags = self._get_param("tags")
@ -93,6 +94,7 @@ class BatchResponse(BaseResponse):
name, arn = self.batch_backend.create_job_queue( name, arn = self.batch_backend.create_job_queue(
queue_name=queue_name, queue_name=queue_name,
priority=priority, priority=priority,
schedule_policy=schedule_policy,
state=state, state=state,
compute_env_order=compute_env_order, compute_env_order=compute_env_order,
tags=tags, tags=tags,
@ -117,6 +119,7 @@ class BatchResponse(BaseResponse):
def updatejobqueue(self) -> str: def updatejobqueue(self) -> str:
compute_env_order = self._get_param("computeEnvironmentOrder") compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueue") queue_name = self._get_param("jobQueue")
schedule_policy = self._get_param("schedulingPolicyArn")
priority = self._get_param("priority") priority = self._get_param("priority")
state = self._get_param("state") state = self._get_param("state")
@ -125,6 +128,7 @@ class BatchResponse(BaseResponse):
priority=priority, priority=priority,
state=state, state=state,
compute_env_order=compute_env_order, compute_env_order=compute_env_order,
schedule_policy=schedule_policy,
) )
result = {"jobQueueArn": arn, "jobQueueName": name} result = {"jobQueueArn": arn, "jobQueueName": name}
@ -271,3 +275,41 @@ class BatchResponse(BaseResponse):
tag_keys = self.querystring.get("tagKeys") tag_keys = self.querystring.get("tagKeys")
self.batch_backend.untag_resource(resource_arn, tag_keys) # type: ignore[arg-type] self.batch_backend.untag_resource(resource_arn, tag_keys) # type: ignore[arg-type]
return "" return ""
@amzn_request_id
def createschedulingpolicy(self) -> str:
body = json.loads(self.body)
name = body.get("name")
fairshare_policy = body.get("fairsharePolicy") or {}
tags = body.get("tags")
policy = self.batch_backend.create_scheduling_policy(
name, fairshare_policy, tags
)
return json.dumps(policy.to_dict(create=True))
@amzn_request_id
def describeschedulingpolicies(self) -> str:
body = json.loads(self.body)
arns = body.get("arns") or []
policies = self.batch_backend.describe_scheduling_policies(arns)
return json.dumps({"schedulingPolicies": [pol.to_dict() for pol in policies]})
@amzn_request_id
def listschedulingpolicies(self) -> str:
arns = self.batch_backend.list_scheduling_policies()
return json.dumps({"schedulingPolicies": [{"arn": arn} for arn in arns]})
@amzn_request_id
def deleteschedulingpolicy(self) -> str:
body = json.loads(self.body)
arn = body["arn"]
self.batch_backend.delete_scheduling_policy(arn)
return ""
@amzn_request_id
def updateschedulingpolicy(self) -> str:
body = json.loads(self.body)
arn = body.get("arn")
fairshare_policy = body.get("fairsharePolicy") or {}
self.batch_backend.update_scheduling_policy(arn, fairshare_policy)
return ""

View File

@ -14,6 +14,11 @@ url_paths = {
"{0}/v1/registerjobdefinition": BatchResponse.dispatch, "{0}/v1/registerjobdefinition": BatchResponse.dispatch,
"{0}/v1/deregisterjobdefinition": BatchResponse.dispatch, "{0}/v1/deregisterjobdefinition": BatchResponse.dispatch,
"{0}/v1/describejobdefinitions": BatchResponse.dispatch, "{0}/v1/describejobdefinitions": BatchResponse.dispatch,
"{0}/v1/createschedulingpolicy": BatchResponse.dispatch,
"{0}/v1/describeschedulingpolicies": BatchResponse.dispatch,
"{0}/v1/listschedulingpolicies": BatchResponse.dispatch,
"{0}/v1/deleteschedulingpolicy": BatchResponse.dispatch,
"{0}/v1/updateschedulingpolicy": BatchResponse.dispatch,
"{0}/v1/submitjob": BatchResponse.dispatch, "{0}/v1/submitjob": BatchResponse.dispatch,
"{0}/v1/describejobs": BatchResponse.dispatch, "{0}/v1/describejobs": BatchResponse.dispatch,
"{0}/v1/listjobs": BatchResponse.dispatch, "{0}/v1/listjobs": BatchResponse.dispatch,

View File

@ -53,12 +53,9 @@ autoscaling:
- TestAccAutoScalingLaunchConfiguration_encryptedRootBlockDevice - TestAccAutoScalingLaunchConfiguration_encryptedRootBlockDevice
batch: batch:
- TestAccBatchJobDefinition - TestAccBatchJobDefinition
- TestAccBatchJobQueue_basic - TestAccBatchJobQueue_
- TestAccBatchJobQueue_tags - TestAccBatchJobQueueDataSource_
- TestAccBatchJobQueue_disappears - TestAccBatchSchedulingPolicy
- TestAccBatchJobQueue_priority
- TestAccBatchJobQueue_state
- TestAccBatchJobQueue_ComputeEnvironments_externalOrderUpdate
ce: ce:
- TestAccCECostCategory - TestAccCECostCategory
cloudformation: cloudformation:

View File

@ -30,6 +30,7 @@ def test_create_job_queue():
state="ENABLED", state="ENABLED",
priority=123, priority=123,
computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}],
schedulingPolicyArn="policy_arn",
) )
resp.should.contain("jobQueueArn") resp.should.contain("jobQueueArn")
resp.should.contain("jobQueueName") resp.should.contain("jobQueueName")
@ -39,6 +40,7 @@ def test_create_job_queue():
our_queues = [q for q in all_queues if q["jobQueueName"] == jq_name] our_queues = [q for q in all_queues if q["jobQueueName"] == jq_name]
our_queues.should.have.length_of(1) our_queues.should.have.length_of(1)
our_queues[0]["jobQueueArn"].should.equal(queue_arn) our_queues[0]["jobQueueArn"].should.equal(queue_arn)
our_queues[0]["schedulingPolicyArn"].should.equal("policy_arn")
@mock_ec2 @mock_ec2

View File

@ -0,0 +1,97 @@
import boto3
from moto import mock_batch
from tests import DEFAULT_ACCOUNT_ID
@mock_batch
def test_create_scheduling_policy():
client = boto3.client("batch", "us-east-2")
resp = client.create_scheduling_policy(name="test")
resp.should.have.key("name").equals("test")
resp.should.have.key("arn").equals(
f"arn:aws:batch:us-east-2:{DEFAULT_ACCOUNT_ID}:scheduling-policy/test"
)
@mock_batch
def test_describe_default_scheduling_policy():
client = boto3.client("batch", "us-east-2")
arn = client.create_scheduling_policy(name="test")["arn"]
resp = client.describe_scheduling_policies(arns=[arn])
resp.should.have.key("schedulingPolicies").length_of(1)
policy = resp["schedulingPolicies"][0]
policy["arn"].should.equal(arn)
policy["fairsharePolicy"].should.equal(
{"computeReservation": 0, "shareDecaySeconds": 0, "shareDistribution": []}
)
policy["tags"].should.equal({})
@mock_batch
def test_describe_scheduling_policy():
client = boto3.client("batch", "us-east-2")
arn = client.create_scheduling_policy(
name="test",
fairsharePolicy={
"shareDecaySeconds": 1,
"computeReservation": 2,
"shareDistribution": [{"shareIdentifier": "A", "weightFactor": 0.1}],
},
)["arn"]
resp = client.list_scheduling_policies()
resp.should.have.key("schedulingPolicies")
arns = [a["arn"] for a in resp["schedulingPolicies"]]
arns.should.contain(arn)
resp = client.describe_scheduling_policies(arns=[arn])
resp.should.have.key("schedulingPolicies").length_of(1)
policy = resp["schedulingPolicies"][0]
policy["arn"].should.equal(arn)
policy["fairsharePolicy"].should.equal(
{
"computeReservation": 2,
"shareDecaySeconds": 1,
"shareDistribution": [{"shareIdentifier": "A", "weightFactor": 0.1}],
}
)
policy["tags"].should.equal({})
@mock_batch
def test_delete_scheduling_policy():
client = boto3.client("batch", "us-east-2")
arn = client.create_scheduling_policy(name="test")["arn"]
client.delete_scheduling_policy(arn=arn)
resp = client.describe_scheduling_policies(arns=[arn])
resp.should.have.key("schedulingPolicies").length_of(0)
@mock_batch
def test_update_scheduling_policy():
client = boto3.client("batch", "us-east-2")
arn = client.create_scheduling_policy(name="test")["arn"]
client.update_scheduling_policy(
arn=arn,
fairsharePolicy={
"computeReservation": 5,
"shareDecaySeconds": 10,
"shareDistribution": [],
},
)
resp = client.describe_scheduling_policies(arns=[arn])
resp.should.have.key("schedulingPolicies").length_of(1)
policy = resp["schedulingPolicies"][0]
policy["arn"].should.equal(arn)
policy["fairsharePolicy"].should.equal(
{"computeReservation": 5, "shareDecaySeconds": 10, "shareDistribution": []}
)

View File

@ -0,0 +1,14 @@
import boto3
from moto import mock_batch
@mock_batch
def test_create_with_tags():
client = boto3.client("batch", "us-east-2")
arn = client.create_scheduling_policy(name="test", tags={"key": "val"})["arn"]
resp = client.describe_scheduling_policies(arns=[arn])
policy = resp["schedulingPolicies"][0]
policy["tags"].should.equal({"key": "val"})