Sagemaker:describe_training_job() - describe unknown job (#5582)

This commit is contained in:
Bert Blommers 2022-10-19 21:53:02 +00:00 committed by GitHub
parent 62f93c7ed0
commit da9cf7bb3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 21 deletions

View File

@ -165,8 +165,8 @@ class FakeTrainingJob(BaseObject):
self.debug_rule_configurations = debug_rule_configurations self.debug_rule_configurations = debug_rule_configurations
self.tensor_board_output_config = tensor_board_output_config self.tensor_board_output_config = tensor_board_output_config
self.experiment_config = experiment_config self.experiment_config = experiment_config
self.training_job_arn = arn_formatter( self.training_job_arn = FakeTrainingJob.arn_formatter(
"training-job", training_job_name, account_id, region_name training_job_name, account_id, region_name
) )
self.creation_time = self.last_modified_time = datetime.now().strftime( self.creation_time = self.last_modified_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S" "%Y-%m-%d %H:%M:%S"
@ -219,6 +219,10 @@ class FakeTrainingJob(BaseObject):
def response_create(self): def response_create(self):
return {"TrainingJobArn": self.training_job_arn} return {"TrainingJobArn": self.training_job_arn}
@staticmethod
def arn_formatter(name, account_id, region_name):
return arn_formatter("training-job", name, account_id, region_name)
class FakeEndpoint(BaseObject, CloudFormationModel): class FakeEndpoint(BaseObject, CloudFormationModel):
def __init__( def __init__(
@ -1865,22 +1869,12 @@ class SageMakerModelBackend(BaseBackend):
return self.training_jobs[training_job_name].response_object return self.training_jobs[training_job_name].response_object
except KeyError: except KeyError:
message = "Could not find training job '{}'.".format( message = "Could not find training job '{}'.".format(
FakeTrainingJob.arn_formatter(training_job_name, self.region_name) FakeTrainingJob.arn_formatter(
training_job_name, self.account_id, self.region_name
)
) )
raise ValidationError(message=message) raise ValidationError(message=message)
def delete_training_job(self, training_job_name):
try:
del self.training_jobs[training_job_name]
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
)
raise ValidationError(message=message)
def _update_training_job_details(self, training_job_name, details_json):
self.training_jobs[training_job_name].update(details_json)
def list_training_jobs( def list_training_jobs(
self, self,
next_token, next_token,

View File

@ -256,12 +256,6 @@ class SageMakerResponse(BaseResponse):
response = self.sagemaker_backend.describe_training_job(training_job_name) response = self.sagemaker_backend.describe_training_job(training_job_name)
return json.dumps(response) return json.dumps(response)
@amzn_request_id
def delete_training_job(self):
training_job_name = self._get_param("TrainingJobName")
self.sagemaker_backend.delete_training_job(training_job_name)
return 200, {}, json.dumps("{}")
@amzn_request_id @amzn_request_id
def create_notebook_instance_lifecycle_config(self): def create_notebook_instance_lifecycle_config(self):
lifecycle_configuration = ( lifecycle_configuration = (

View File

@ -441,3 +441,15 @@ def test_delete_tags_from_training_job():
response = client.list_tags(ResourceArn=resource_arn) response = client.list_tags(ResourceArn=resource_arn)
assert response["Tags"] == [] assert response["Tags"] == []
@mock_sagemaker
def test_describe_unknown_training_job():
client = boto3.client("sagemaker", region_name="us-east-1")
with pytest.raises(ClientError) as exc:
client.describe_training_job(TrainingJobName="unknown")
err = exc.value.response["Error"]
err["Code"].should.equal("ValidationException")
err["Message"].should.equal(
f"Could not find training job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:training-job/unknown'."
)