diff --git a/moto/sagemaker/exceptions.py b/moto/sagemaker/exceptions.py index 0331fee89..43ccbf554 100644 --- a/moto/sagemaker/exceptions.py +++ b/moto/sagemaker/exceptions.py @@ -1,5 +1,5 @@ from __future__ import unicode_literals -from moto.core.exceptions import RESTError, JsonRESTError +from moto.core.exceptions import RESTError, JsonRESTError, AWSError ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %} {% block extra %}{{ model }}{% endblock %} @@ -32,3 +32,7 @@ class MissingModel(ModelError): class ValidationError(JsonRESTError): def __init__(self, message, **kwargs): super(ValidationError, self).__init__("ValidationException", message, **kwargs) + + +class AWSValidationException(AWSError): + TYPE = "ValidationException" diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index b1c51abec..b6d513b25 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -3,11 +3,10 @@ from __future__ import unicode_literals import os from boto3 import Session from datetime import datetime - from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel from moto.core.exceptions import RESTError from moto.sagemaker import validators -from .exceptions import MissingModel, ValidationError +from .exceptions import MissingModel, ValidationError, AWSValidationException class BaseObject(BaseModel): @@ -1243,6 +1242,90 @@ class SageMakerModelBackend(BaseBackend): except RESTError: return [] + def list_training_jobs( + self, + next_token, + max_results, + creation_time_after, + creation_time_before, + last_modified_time_after, + last_modified_time_before, + name_contains, + status_equals, + sort_by, + sort_order, + ): + if next_token: + try: + starting_index = int(next_token) + if starting_index > len(self.training_jobs): + raise ValueError # invalid next_token + except ValueError: + raise AWSValidationException('Invalid pagination token because "{0}".') + else: + starting_index = 0 + + if max_results: + end_index = max_results + starting_index + training_jobs_fetched = list(self.training_jobs.values())[ + starting_index:end_index + ] + if end_index >= len(self.training_jobs): + next_index = None + else: + next_index = end_index + else: + training_jobs_fetched = list(self.training_jobs.values()) + next_index = None + + if name_contains is not None: + training_jobs_fetched = filter( + lambda x: name_contains in x.training_job_name, training_jobs_fetched + ) + + if creation_time_after is not None: + training_jobs_fetched = filter( + lambda x: x.creation_time > creation_time_after, training_jobs_fetched + ) + + if creation_time_before is not None: + training_jobs_fetched = filter( + lambda x: x.creation_time < creation_time_before, training_jobs_fetched + ) + + if last_modified_time_after is not None: + training_jobs_fetched = filter( + lambda x: x.last_modified_time > last_modified_time_after, + training_jobs_fetched, + ) + + if last_modified_time_before is not None: + training_jobs_fetched = filter( + lambda x: x.last_modified_time < last_modified_time_before, + training_jobs_fetched, + ) + if status_equals is not None: + training_jobs_fetched = filter( + lambda x: x.training_job_status == status_equals, training_jobs_fetched + ) + + training_job_summaries = [ + { + "TrainingJobName": training_job_data.training_job_name, + "TrainingJobArn": training_job_data.training_job_arn, + "CreationTime": training_job_data.creation_time, + "TrainingEndTime": training_job_data.training_end_time, + "LastModifiedTime": training_job_data.last_modified_time, + "TrainingJobStatus": training_job_data.training_job_status, + } + for training_job_data in training_jobs_fetched + ] + + return { + "TrainingJobSummaries": training_job_summaries, + "NextToken": str(next_index) if next_index is not None else None, + } + sagemaker_backends = {} for region in Session().get_available_regions("sagemaker"): diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 27b60662d..ebeee0279 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals import json +from moto.sagemaker.exceptions import AWSValidationException from moto.core.exceptions import AWSError from moto.core.responses import BaseResponse @@ -8,6 +9,10 @@ from moto.core.utils import amzn_request_id from .models import sagemaker_backends +def format_enum_error(value, attribute, allowed): + return f"Value '{value}' at '{attribute}' failed to satisfy constraint: Member must satisfy enum value set: {allowed}" + + class SageMakerResponse(BaseResponse): @property def sagemaker_backend(self): @@ -274,3 +279,65 @@ class SageMakerResponse(BaseResponse): ) ) return 200, {}, json.dumps("{}") + + @amzn_request_id + def list_training_jobs(self): + max_results_range = range(1, 101) + allowed_sort_by = ["Name", "CreationTime", "Status"] + allowed_sort_order = ["Ascending", "Descending"] + allowed_status_equals = [ + "Completed", + "Stopped", + "InProgress", + "Stopping", + "Failed", + ] + + try: + max_results = self._get_int_param("MaxResults") + sort_by = self._get_param("SortBy", "CreationTime") + sort_order = self._get_param("SortOrder", "Ascending") + status_equals = self._get_param("StatusEquals") + next_token = self._get_param("NextToken") + errors = [] + if max_results and max_results not in max_results_range: + errors.append( + "Value '%s' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to %s".format( + max_results, max_results_range[-1] + ) + ) + + if sort_by not in allowed_sort_by: + errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by)) + if sort_order not in allowed_sort_order: + errors.append( + format_enum_error(sort_order, "sortOrder", allowed_sort_order) + ) + + if status_equals and status_equals not in allowed_status_equals: + errors.append( + format_enum_error( + status_equals, "statusEquals", allowed_status_equals + ) + ) + + if errors != []: + raise AWSValidationException( + f"{len(errors)} validation errors detected: {';'.join(errors)}" + ) + + response = self.sagemaker_backend.list_training_jobs( + next_token=next_token, + max_results=max_results, + creation_time_after=self._get_param("CreationTimeAfter"), + creation_time_before=self._get_param("CreationTimeBefore"), + last_modified_time_after=self._get_param("LastModifiedTimeAfter"), + last_modified_time_before=self._get_param("LastModifiedTimeBefore"), + name_contains=self._get_param("NameContains"), + status_equals=status_equals, + sort_by=sort_by, + sort_order=sort_order, + ) + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() diff --git a/requirements-dev.txt b/requirements-dev.txt index 171617185..cdab55017 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,4 +9,4 @@ click inflection==0.3.1 lxml packaging - +prompt_toolkit diff --git a/tests/test_sagemaker/test_sagemaker_models.py b/tests/test_sagemaker/test_sagemaker_models.py index 975811be4..32239d557 100644 --- a/tests/test_sagemaker/test_sagemaker_models.py +++ b/tests/test_sagemaker/test_sagemaker_models.py @@ -15,10 +15,8 @@ class MySageMakerModel(object): def __init__(self, name, arn, container=None, vpc_config=None): self.name = name self.arn = arn - self.container = container if container else {} - self.vpc_config = ( - vpc_config if vpc_config else {"sg-groups": ["sg-123"], "subnets": ["123"]} - ) + self.container = container or {} + self.vpc_config = vpc_config or {"sg-groups": ["sg-123"], "subnets": ["123"]} def save(self): client = boto3.client("sagemaker", region_name="us-east-1") diff --git a/tests/test_sagemaker/test_sagemaker_training.py b/tests/test_sagemaker/test_sagemaker_training.py index c7b631ae3..94ab8eead 100644 --- a/tests/test_sagemaker/test_sagemaker_training.py +++ b/tests/test_sagemaker/test_sagemaker_training.py @@ -1,8 +1,15 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals +from moto.core.exceptions import JsonRESTError +from re import M +from moto.core import responses +from os import O_DSYNC, scandir +import pytest import boto3 +from botocore.exceptions import ClientError import datetime +from botocore.configloader import raw_config_parse import sure # noqa from moto import mock_sagemaker @@ -12,34 +19,44 @@ FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) TEST_REGION_NAME = "us-east-1" -@mock_sagemaker -def test_create_training_job(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - training_job_name = "MyTrainingJob" - container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" - bucket = "my-bucket" - prefix = "sagemaker/DEMO-breast-cancer-prediction/" - - params = { - "RoleArn": FAKE_ROLE_ARN, - "TrainingJobName": training_job_name, - "AlgorithmSpecification": { - "TrainingImage": container, +class MyTrainingJobModel(object): + def __init__( + self, + training_job_name, + role_arn, + container=None, + bucket=None, + prefix=None, + algorithm_specification=None, + resource_config=None, + input_data_config=None, + output_data_config=None, + hyper_parameters=None, + stopping_condition=None, + ): + self.training_job_name = training_job_name + self.role_arn = role_arn + self.container = ( + container or "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" + ) + self.bucket = bucket or "my-bucket" + self.prefix = prefix or "sagemaker/DEMO-breast-cancer-prediction/" + self.algorithm_specification = algorithm_specification or { + "TrainingImage": self.container, "TrainingInputMode": "File", - }, - "ResourceConfig": { + } + self.resource_config = resource_config or { "InstanceCount": 1, "InstanceType": "ml.c4.2xlarge", "VolumeSizeInGB": 10, - }, - "InputDataConfig": [ + } + self.input_data_config = input_data_config or [ { "ChannelName": "train", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", - "S3Uri": "s3://{}/{}/train/".format(bucket, prefix), + "S3Uri": "s3://{}/{}/train/".format(self.bucket, self.prefix), "S3DataDistributionType": "ShardedByS3Key", } }, @@ -51,27 +68,115 @@ def test_create_training_job(): "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", - "S3Uri": "s3://{}/{}/validation/".format(bucket, prefix), + "S3Uri": "s3://{}/{}/validation/".format( + self.bucket, self.prefix + ), "S3DataDistributionType": "FullyReplicated", } }, "CompressionType": "None", "RecordWrapperType": "None", }, - ], - "OutputDataConfig": {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)}, - "HyperParameters": { + ] + self.output_data_config = output_data_config or { + "S3OutputPath": "s3://{}/{}/".format(self.bucket, self.prefix) + } + self.hyper_parameters = hyper_parameters or { "feature_dim": "30", "mini_batch_size": "100", "predictor_type": "regressor", "epochs": "10", "num_models": "32", "loss": "absolute_loss", - }, - "StoppingCondition": {"MaxRuntimeInSeconds": 60 * 60}, - } + } - resp = sagemaker.create_training_job(**params) + self.stopping_condition = stopping_condition or {"MaxRuntimeInSeconds": 60 * 60} + + def save(self): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + params = { + "RoleArn": self.role_arn, + "TrainingJobName": self.training_job_name, + "AlgorithmSpecification": self.algorithm_specification, + "ResourceConfig": self.resource_config, + "InputDataConfig": self.input_data_config, + "OutputDataConfig": self.output_data_config, + "HyperParameters": self.hyper_parameters, + "StoppingCondition": self.stopping_condition, + } + return sagemaker.create_training_job(**params) + + +@mock_sagemaker +def test_create_training_job(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + training_job_name = "MyTrainingJob" + role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) + container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" + bucket = "my-bucket" + prefix = "sagemaker/DEMO-breast-cancer-prediction/" + algorithm_specification = { + "TrainingImage": container, + "TrainingInputMode": "File", + } + resource_config = { + "InstanceCount": 1, + "InstanceType": "ml.c4.2xlarge", + "VolumeSizeInGB": 10, + } + input_data_config = [ + { + "ChannelName": "train", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://{}/{}/train/".format(bucket, prefix), + "S3DataDistributionType": "ShardedByS3Key", + } + }, + "CompressionType": "None", + "RecordWrapperType": "None", + }, + { + "ChannelName": "validation", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://{}/{}/validation/".format(bucket, prefix), + "S3DataDistributionType": "FullyReplicated", + } + }, + "CompressionType": "None", + "RecordWrapperType": "None", + }, + ] + output_data_config = {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)} + hyper_parameters = { + "feature_dim": "30", + "mini_batch_size": "100", + "predictor_type": "regressor", + "epochs": "10", + "num_models": "32", + "loss": "absolute_loss", + } + stopping_condition = {"MaxRuntimeInSeconds": 60 * 60} + + job = MyTrainingJobModel( + training_job_name, + role_arn, + container=container, + bucket=bucket, + prefix=prefix, + algorithm_specification=algorithm_specification, + resource_config=resource_config, + input_data_config=input_data_config, + output_data_config=output_data_config, + hyper_parameters=hyper_parameters, + stopping_condition=stopping_condition, + ) + resp = job.save() resp["TrainingJobArn"].should.match( r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name) ) @@ -82,29 +187,29 @@ def test_create_training_job(): r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name) ) assert resp["ModelArtifacts"]["S3ModelArtifacts"].startswith( - params["OutputDataConfig"]["S3OutputPath"] + output_data_config["S3OutputPath"] ) assert training_job_name in (resp["ModelArtifacts"]["S3ModelArtifacts"]) assert resp["ModelArtifacts"]["S3ModelArtifacts"].endswith("output/model.tar.gz") assert resp["TrainingJobStatus"] == "Completed" assert resp["SecondaryStatus"] == "Completed" - assert resp["HyperParameters"] == params["HyperParameters"] + assert resp["HyperParameters"] == hyper_parameters assert ( resp["AlgorithmSpecification"]["TrainingImage"] - == params["AlgorithmSpecification"]["TrainingImage"] + == algorithm_specification["TrainingImage"] ) assert ( resp["AlgorithmSpecification"]["TrainingInputMode"] - == params["AlgorithmSpecification"]["TrainingInputMode"] + == algorithm_specification["TrainingInputMode"] ) assert "MetricDefinitions" in resp["AlgorithmSpecification"] assert "Name" in resp["AlgorithmSpecification"]["MetricDefinitions"][0] assert "Regex" in resp["AlgorithmSpecification"]["MetricDefinitions"][0] - assert resp["RoleArn"] == FAKE_ROLE_ARN - assert resp["InputDataConfig"] == params["InputDataConfig"] - assert resp["OutputDataConfig"] == params["OutputDataConfig"] - assert resp["ResourceConfig"] == params["ResourceConfig"] - assert resp["StoppingCondition"] == params["StoppingCondition"] + assert resp["RoleArn"] == role_arn + assert resp["InputDataConfig"] == input_data_config + assert resp["OutputDataConfig"] == output_data_config + assert resp["ResourceConfig"] == resource_config + assert resp["StoppingCondition"] == stopping_condition assert isinstance(resp["CreationTime"], datetime.datetime) assert isinstance(resp["TrainingStartTime"], datetime.datetime) assert isinstance(resp["TrainingEndTime"], datetime.datetime) @@ -120,3 +225,183 @@ def test_create_training_job(): assert "Timestamp" in resp["FinalMetricDataList"][0] pass + + +@mock_sagemaker +def test_list_training_jobs(): + client = boto3.client("sagemaker", region_name="us-east-1") + name = "blah" + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" + test_training_job = MyTrainingJobModel(training_job_name=name, role_arn=arn) + test_training_job.save() + training_jobs = client.list_training_jobs() + assert len(training_jobs["TrainingJobSummaries"]).should.equal(1) + assert training_jobs["TrainingJobSummaries"][0]["TrainingJobName"].should.equal( + name + ) + + assert training_jobs["TrainingJobSummaries"][0]["TrainingJobArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(name) + ) + assert training_jobs.get("NextToken") is None + + +@mock_sagemaker +def test_list_training_jobs_multiple(): + client = boto3.client("sagemaker", region_name="us-east-1") + name_job_1 = "blah" + arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" + test_training_job_1 = MyTrainingJobModel( + training_job_name=name_job_1, role_arn=arn_job_1 + ) + test_training_job_1.save() + + name_job_2 = "blah2" + arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2" + test_training_job_2 = MyTrainingJobModel( + training_job_name=name_job_2, role_arn=arn_job_2 + ) + test_training_job_2.save() + training_jobs_limit = client.list_training_jobs(MaxResults=1) + assert len(training_jobs_limit["TrainingJobSummaries"]).should.equal(1) + + training_jobs = client.list_training_jobs() + assert len(training_jobs["TrainingJobSummaries"]).should.equal(2) + assert training_jobs.get("NextToken").should.be.none + + +@mock_sagemaker +def test_list_training_jobs_none(): + client = boto3.client("sagemaker", region_name="us-east-1") + training_jobs = client.list_training_jobs() + assert len(training_jobs["TrainingJobSummaries"]).should.equal(0) + + +@mock_sagemaker +def test_list_training_jobs_should_validate_input(): + client = boto3.client("sagemaker", region_name="us-east-1") + junk_status_equals = "blah" + with pytest.raises(ClientError) as ex: + client.list_training_jobs(StatusEquals=junk_status_equals) + expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']" + assert ex.value.response["Error"]["Code"] == "ValidationException" + assert ex.value.response["Error"]["Message"] == expected_error + + junk_next_token = "asdf" + with pytest.raises(ClientError) as ex: + client.list_training_jobs(NextToken=junk_next_token) + assert ex.value.response["Error"]["Code"] == "ValidationException" + assert ( + ex.value.response["Error"]["Message"] + == 'Invalid pagination token because "{0}".' + ) + + +@mock_sagemaker +def test_list_training_jobs_with_name_filters(): + client = boto3.client("sagemaker", region_name="us-east-1") + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyTrainingJobModel(training_job_name=name, role_arn=arn).save() + for i in range(5): + name = "vgg-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) + MyTrainingJobModel(training_job_name=name, role_arn=arn).save() + xgboost_training_jobs = client.list_training_jobs(NameContains="xgboost") + assert len(xgboost_training_jobs["TrainingJobSummaries"]).should.equal(5) + + training_jobs_with_2 = client.list_training_jobs(NameContains="2") + assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2) + + +@mock_sagemaker +def test_list_training_jobs_paginated(): + client = boto3.client("sagemaker", region_name="us-east-1") + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyTrainingJobModel(training_job_name=name, role_arn=arn).save() + xgboost_training_job_1 = client.list_training_jobs( + NameContains="xgboost", MaxResults=1 + ) + assert len(xgboost_training_job_1["TrainingJobSummaries"]).should.equal(1) + assert xgboost_training_job_1["TrainingJobSummaries"][0][ + "TrainingJobName" + ].should.equal("xgboost-0") + assert xgboost_training_job_1.get("NextToken").should_not.be.none + + xgboost_training_job_next = client.list_training_jobs( + NameContains="xgboost", + MaxResults=1, + NextToken=xgboost_training_job_1.get("NextToken"), + ) + assert len(xgboost_training_job_next["TrainingJobSummaries"]).should.equal(1) + assert xgboost_training_job_next["TrainingJobSummaries"][0][ + "TrainingJobName" + ].should.equal("xgboost-1") + assert xgboost_training_job_next.get("NextToken").should_not.be.none + + +@mock_sagemaker +def test_list_training_jobs_paginated_with_target_in_middle(): + client = boto3.client("sagemaker", region_name="us-east-1") + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyTrainingJobModel(training_job_name=name, role_arn=arn).save() + for i in range(5): + name = "vgg-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) + MyTrainingJobModel(training_job_name=name, role_arn=arn).save() + + vgg_training_job_1 = client.list_training_jobs(NameContains="vgg", MaxResults=1) + assert len(vgg_training_job_1["TrainingJobSummaries"]).should.equal(0) + assert vgg_training_job_1.get("NextToken").should_not.be.none + + vgg_training_job_6 = client.list_training_jobs(NameContains="vgg", MaxResults=6) + + assert len(vgg_training_job_6["TrainingJobSummaries"]).should.equal(1) + assert vgg_training_job_6["TrainingJobSummaries"][0][ + "TrainingJobName" + ].should.equal("vgg-0") + assert vgg_training_job_6.get("NextToken").should_not.be.none + + vgg_training_job_10 = client.list_training_jobs(NameContains="vgg", MaxResults=10) + + assert len(vgg_training_job_10["TrainingJobSummaries"]).should.equal(5) + assert vgg_training_job_10["TrainingJobSummaries"][-1][ + "TrainingJobName" + ].should.equal("vgg-4") + assert vgg_training_job_10.get("NextToken").should.be.none + + +@mock_sagemaker +def test_list_training_jobs_paginated_with_fragmented_targets(): + client = boto3.client("sagemaker", region_name="us-east-1") + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyTrainingJobModel(training_job_name=name, role_arn=arn).save() + for i in range(5): + name = "vgg-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) + MyTrainingJobModel(training_job_name=name, role_arn=arn).save() + + training_jobs_with_2 = client.list_training_jobs(NameContains="2", MaxResults=8) + assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2) + assert training_jobs_with_2.get("NextToken").should_not.be.none + + training_jobs_with_2_next = client.list_training_jobs( + NameContains="2", MaxResults=1, NextToken=training_jobs_with_2.get("NextToken"), + ) + assert len(training_jobs_with_2_next["TrainingJobSummaries"]).should.equal(0) + assert training_jobs_with_2_next.get("NextToken").should_not.be.none + + training_jobs_with_2_next_next = client.list_training_jobs( + NameContains="2", + MaxResults=1, + NextToken=training_jobs_with_2_next.get("NextToken"), + ) + assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]).should.equal(0) + assert training_jobs_with_2_next_next.get("NextToken").should.be.none