diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 0f2778c9c..9325c3255 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -165,8 +165,8 @@ class FakeTrainingJob(BaseObject): self.debug_rule_configurations = debug_rule_configurations self.tensor_board_output_config = tensor_board_output_config self.experiment_config = experiment_config - self.training_job_arn = arn_formatter( - "training-job", training_job_name, account_id, region_name + self.training_job_arn = FakeTrainingJob.arn_formatter( + training_job_name, account_id, region_name ) self.creation_time = self.last_modified_time = datetime.now().strftime( "%Y-%m-%d %H:%M:%S" @@ -219,6 +219,10 @@ class FakeTrainingJob(BaseObject): def response_create(self): return {"TrainingJobArn": self.training_job_arn} + @staticmethod + def arn_formatter(name, account_id, region_name): + return arn_formatter("training-job", name, account_id, region_name) + class FakeEndpoint(BaseObject, CloudFormationModel): def __init__( @@ -1865,22 +1869,12 @@ class SageMakerModelBackend(BaseBackend): return self.training_jobs[training_job_name].response_object except KeyError: message = "Could not find training job '{}'.".format( - FakeTrainingJob.arn_formatter(training_job_name, self.region_name) + FakeTrainingJob.arn_formatter( + training_job_name, self.account_id, self.region_name + ) ) raise ValidationError(message=message) - def delete_training_job(self, training_job_name): - try: - del self.training_jobs[training_job_name] - except KeyError: - message = "Could not find endpoint configuration '{}'.".format( - FakeTrainingJob.arn_formatter(training_job_name, self.region_name) - ) - raise ValidationError(message=message) - - def _update_training_job_details(self, training_job_name, details_json): - self.training_jobs[training_job_name].update(details_json) - def list_training_jobs( self, next_token, diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index a2bcbdff5..bf97830e0 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -256,12 +256,6 @@ class SageMakerResponse(BaseResponse): response = self.sagemaker_backend.describe_training_job(training_job_name) return json.dumps(response) - @amzn_request_id - def delete_training_job(self): - training_job_name = self._get_param("TrainingJobName") - self.sagemaker_backend.delete_training_job(training_job_name) - return 200, {}, json.dumps("{}") - @amzn_request_id def create_notebook_instance_lifecycle_config(self): lifecycle_configuration = ( diff --git a/tests/test_sagemaker/test_sagemaker_training.py b/tests/test_sagemaker/test_sagemaker_training.py index 93bc8a708..cdf656289 100644 --- a/tests/test_sagemaker/test_sagemaker_training.py +++ b/tests/test_sagemaker/test_sagemaker_training.py @@ -441,3 +441,15 @@ def test_delete_tags_from_training_job(): response = client.list_tags(ResourceArn=resource_arn) assert response["Tags"] == [] + + +@mock_sagemaker +def test_describe_unknown_training_job(): + client = boto3.client("sagemaker", region_name="us-east-1") + with pytest.raises(ClientError) as exc: + client.describe_training_job(TrainingJobName="unknown") + err = exc.value.response["Error"] + err["Code"].should.equal("ValidationException") + err["Message"].should.equal( + f"Could not find training job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:training-job/unknown'." + )