Batch: Using enum for job status (#6789)
This commit is contained in:
parent
a8ab8011ed
commit
121ad974b8
@ -24,6 +24,7 @@ from .utils import (
|
||||
make_arn_for_job,
|
||||
make_arn_for_task_def,
|
||||
lowercase_first_key,
|
||||
JobStatus,
|
||||
)
|
||||
from moto.ec2.exceptions import InvalidSubnetIdError
|
||||
from moto.ec2.models.instance_types import INSTANCE_TYPES as EC2_INSTANCE_TYPES
|
||||
@ -487,12 +488,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
||||
ManagedState.__init__(
|
||||
self,
|
||||
"batch::job",
|
||||
[
|
||||
("SUBMITTED", "PENDING"),
|
||||
("PENDING", "RUNNABLE"),
|
||||
("RUNNABLE", "STARTING"),
|
||||
("STARTING", "RUNNING"),
|
||||
],
|
||||
JobStatus.status_transitions(),
|
||||
)
|
||||
|
||||
self.job_name = name
|
||||
@ -539,8 +535,9 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
||||
}
|
||||
if self.job_stopped_reason is not None:
|
||||
result["statusReason"] = self.job_stopped_reason
|
||||
if result["status"] not in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING"]:
|
||||
result["startedAt"] = datetime2int_milliseconds(self.job_started_at)
|
||||
if self.status is not None:
|
||||
if JobStatus.is_job_already_started(self.status):
|
||||
result["startedAt"] = datetime2int_milliseconds(self.job_started_at)
|
||||
if self.job_stopped:
|
||||
result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at)
|
||||
if self.exit_code is not None:
|
||||
@ -638,7 +635,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
||||
containers: List[docker.models.containers.Container] = []
|
||||
|
||||
self.advance()
|
||||
while self.status == "SUBMITTED":
|
||||
while self.status == JobStatus.SUBMITTED:
|
||||
# Wait until we've moved onto state 'PENDING'
|
||||
sleep(0.5)
|
||||
|
||||
@ -730,7 +727,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
||||
)
|
||||
|
||||
self.advance()
|
||||
while self.status == "PENDING":
|
||||
while self.status == JobStatus.PENDING:
|
||||
# Wait until the state is no longer pending, but 'RUNNABLE'
|
||||
sleep(0.5)
|
||||
# TODO setup ecs container instance
|
||||
@ -765,12 +762,12 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
||||
|
||||
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
|
||||
self.advance()
|
||||
while self.status == "RUNNABLE":
|
||||
while self.status == JobStatus.RUNNABLE:
|
||||
# Wait until the state is no longer runnable, but 'STARTING'
|
||||
sleep(0.5)
|
||||
|
||||
self.advance()
|
||||
while self.status == "STARTING":
|
||||
while self.status == JobStatus.STARTING:
|
||||
# Wait until the state is no longer runnable, but 'RUNNING'
|
||||
sleep(0.5)
|
||||
|
||||
@ -898,7 +895,7 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
||||
# The describe-method needs them immediately when status is set
|
||||
self.job_stopped = True
|
||||
self.job_stopped_at = datetime.datetime.now()
|
||||
self.status = "SUCCEEDED" if success else "FAILED"
|
||||
self.status = JobStatus.SUCCEEDED if success else JobStatus.FAILED
|
||||
self._stop_attempt()
|
||||
|
||||
def _start_attempt(self) -> None:
|
||||
@ -934,9 +931,9 @@ class Job(threading.Thread, BaseModel, DockerModel, ManagedState):
|
||||
for dependent_id in dependent_ids:
|
||||
if dependent_id in self.all_jobs:
|
||||
dependent_job = self.all_jobs[dependent_id]
|
||||
if dependent_job.status == "SUCCEEDED":
|
||||
if dependent_job.status == JobStatus.SUCCEEDED:
|
||||
successful_dependencies.add(dependent_id)
|
||||
if dependent_job.status == "FAILED":
|
||||
if dependent_job.status == JobStatus.FAILED:
|
||||
logger.error(
|
||||
f"Terminating job {self.name} due to failed dependency {dependent_job.name}"
|
||||
)
|
||||
@ -1038,7 +1035,7 @@ class BatchBackend(BaseBackend):
|
||||
|
||||
def reset(self) -> None:
|
||||
for job in self._jobs.values():
|
||||
if job.status not in ("FAILED", "SUCCEEDED"):
|
||||
if job.status not in (JobStatus.FAILED, JobStatus.SUCCEEDED):
|
||||
job.stop = True
|
||||
# Try to join
|
||||
job.join(0.2)
|
||||
@ -1776,15 +1773,7 @@ class BatchBackend(BaseBackend):
|
||||
if job_queue is None:
|
||||
raise ClientException(f"Job queue {job_queue_name} does not exist")
|
||||
|
||||
if job_status is not None and job_status not in (
|
||||
"SUBMITTED",
|
||||
"PENDING",
|
||||
"RUNNABLE",
|
||||
"STARTING",
|
||||
"RUNNING",
|
||||
"SUCCEEDED",
|
||||
"FAILED",
|
||||
):
|
||||
if job_status is not None and job_status not in JobStatus.job_statuses():
|
||||
raise ClientException(
|
||||
"Job status is not one of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED"
|
||||
)
|
||||
@ -1821,7 +1810,9 @@ class BatchBackend(BaseBackend):
|
||||
|
||||
job = self.get_job_by_id(job_id)
|
||||
if job is not None:
|
||||
if job.status in ["SUBMITTED", "PENDING", "RUNNABLE"]:
|
||||
if job.status is None:
|
||||
return
|
||||
if JobStatus.is_job_before_starting(job.status):
|
||||
job.terminate(reason)
|
||||
# No-Op for jobs that have already started - user has to explicitly terminate those
|
||||
|
||||
|
@ -1,4 +1,7 @@
|
||||
from typing import Any, Dict
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from .exceptions import ValidationError
|
||||
|
||||
|
||||
def make_arn_for_compute_env(account_id: str, name: str, region_name: str) -> str:
|
||||
@ -34,3 +37,55 @@ def lowercase_first_key(some_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
new_dict[new_key] = value
|
||||
|
||||
return new_dict
|
||||
|
||||
|
||||
def validate_job_status(target_job_status: str, valid_job_statuses: List[str]) -> None:
|
||||
if target_job_status not in valid_job_statuses:
|
||||
raise ValidationError(
|
||||
(
|
||||
"1 validation error detected: Value at 'current_status' failed "
|
||||
"to satisfy constraint: Member must satisfy enum value set: {valid_statues}"
|
||||
).format(valid_statues=valid_job_statuses)
|
||||
)
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
|
||||
SUBMITTED = "SUBMITTED"
|
||||
PENDING = "PENDING"
|
||||
RUNNABLE = "RUNNABLE"
|
||||
STARTING = "STARTING"
|
||||
RUNNING = "RUNNING"
|
||||
SUCCEEDED = "SUCCEEDED"
|
||||
FAILED = "FAILED"
|
||||
|
||||
@classmethod
|
||||
def job_statuses(self) -> List[str]:
|
||||
return sorted([item.value for item in JobStatus])
|
||||
|
||||
@classmethod
|
||||
def is_job_already_started(self, current_status: str) -> bool:
|
||||
validate_job_status(current_status, JobStatus.job_statuses())
|
||||
return current_status not in [
|
||||
JobStatus.SUBMITTED,
|
||||
JobStatus.PENDING,
|
||||
JobStatus.RUNNABLE,
|
||||
JobStatus.STARTING,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def is_job_before_starting(self, current_status: str) -> bool:
|
||||
validate_job_status(current_status, JobStatus.job_statuses())
|
||||
return current_status in [
|
||||
JobStatus.SUBMITTED,
|
||||
JobStatus.PENDING,
|
||||
JobStatus.RUNNABLE,
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def status_transitions(self) -> List[Tuple[Optional[str], str]]:
|
||||
return [
|
||||
(JobStatus.SUBMITTED.value, JobStatus.PENDING.value),
|
||||
(JobStatus.PENDING.value, JobStatus.RUNNABLE.value),
|
||||
(JobStatus.RUNNABLE.value, JobStatus.STARTING),
|
||||
(JobStatus.STARTING.value, JobStatus.RUNNING.value),
|
||||
]
|
||||
|
70
tests/test_batch/test_utils.py
Normal file
70
tests/test_batch/test_utils.py
Normal file
@ -0,0 +1,70 @@
|
||||
import pytest
|
||||
|
||||
from moto.batch.exceptions import ValidationError
|
||||
|
||||
from moto.batch.utils import JobStatus
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"job_already_started, job_status",
|
||||
[
|
||||
(None, "InvalidJobStatus"),
|
||||
(False, "SUBMITTED"),
|
||||
(False, "PENDING"),
|
||||
(False, "RUNNABLE"),
|
||||
(False, "STARTING"),
|
||||
(True, "RUNNING"),
|
||||
(True, "SUCCEEDED"),
|
||||
(True, "FAILDED"),
|
||||
],
|
||||
)
|
||||
def test_JobStatus_is_job_already_sarted(job_already_started, job_status):
|
||||
if job_status not in JobStatus.job_statuses():
|
||||
with pytest.raises(ValidationError) as e:
|
||||
_ = JobStatus.is_job_already_started("InvalidJobStatus")
|
||||
assert (
|
||||
e.value.message
|
||||
== "1 validation error detected: Value at 'current_status' failed to satisfy constraint: Member must satisfy enum value set: ['FAILED', 'PENDING', 'RUNNABLE', 'RUNNING', 'STARTING', 'SUBMITTED', 'SUCCEEDED']"
|
||||
)
|
||||
return
|
||||
|
||||
assert JobStatus.is_job_already_started(job_status) is job_already_started
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"job_before_starting, job_status",
|
||||
[
|
||||
(None, "InvalidJobStatus"),
|
||||
(True, "SUBMITTED"),
|
||||
(True, "PENDING"),
|
||||
(True, "RUNNABLE"),
|
||||
(False, "STARTING"),
|
||||
(False, "RUNNING"),
|
||||
(False, "SUCCEEDED"),
|
||||
(False, "FAILDED"),
|
||||
],
|
||||
)
|
||||
def test_JobStatus_is_job_before_starting(job_before_starting, job_status):
|
||||
if job_status not in JobStatus.job_statuses():
|
||||
with pytest.raises(ValidationError) as e:
|
||||
_ = JobStatus.is_job_before_starting("InvalidJobStatus")
|
||||
assert (
|
||||
e.value.message
|
||||
== "1 validation error detected: Value at 'current_status' failed to satisfy constraint: Member must satisfy enum value set: ['FAILED', 'PENDING', 'RUNNABLE', 'RUNNING', 'STARTING', 'SUBMITTED', 'SUCCEEDED']"
|
||||
)
|
||||
return
|
||||
|
||||
assert JobStatus.is_job_before_starting(job_status) is job_before_starting
|
||||
|
||||
|
||||
def test_JobStatus_status_transitions():
|
||||
for before_status, after_status in JobStatus.status_transitions():
|
||||
if before_status == JobStatus.SUBMITTED:
|
||||
assert after_status == JobStatus.PENDING
|
||||
elif before_status == JobStatus.PENDING:
|
||||
assert after_status == JobStatus.RUNNABLE
|
||||
elif before_status == JobStatus.RUNNABLE:
|
||||
assert after_status == JobStatus.STARTING
|
||||
else:
|
||||
assert before_status == JobStatus.STARTING
|
||||
assert after_status == JobStatus.RUNNING
|
Loading…
Reference in New Issue
Block a user