diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py
index 3da7da937..a999630e5 100644
--- a/moto/sagemaker/models.py
+++ b/moto/sagemaker/models.py
@@ -39,6 +39,63 @@ class BaseObject(BaseModel):
         return self.gen_response_object()
 
 
+class FakeProcessingJob(BaseObject):
+    def __init__(
+        self,
+        app_specification,
+        experiment_config,
+        network_config,
+        processing_inputs,
+        processing_job_name,
+        processing_output_config,
+        processing_resources,
+        region_name,
+        role_arn,
+        stopping_condition,
+    ):
+        self.processing_job_name = processing_job_name
+        self.processing_job_arn = FakeProcessingJob.arn_formatter(
+            processing_job_name, region_name
+        )
+
+        now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.creation_time = now_string
+        self.last_modified_time = now_string
+        self.processing_end_time = now_string
+
+        self.role_arn = role_arn
+        self.app_specification = app_specification
+        self.experiment_config = experiment_config
+        self.network_config = network_config
+        self.processing_inputs = processing_inputs
+        self.processing_job_status = "Completed"
+        self.processing_output_config = processing_output_config
+        self.processing_resources = processing_resources
+        self.stopping_condition = stopping_condition
+
+    @property
+    def response_object(self):
+        response_object = self.gen_response_object()
+        return {
+            k: v for k, v in response_object.items() if v is not None and v != [None]
+        }
+
+    @property
+    def response_create(self):
+        return {"ProcessingJobArn": self.processing_job_arn}
+
+    @staticmethod
+    def arn_formatter(processing_job_name, region_name):
+        return (
+            "arn:aws:sagemaker:"
+            + region_name
+            + ":"
+            + str(ACCOUNT_ID)
+            + ":processing-job/"
+            + processing_job_name
+        )
+
+
 class FakeTrainingJob(BaseObject):
     def __init__(
         self,
@@ -898,6 +955,7 @@ class SageMakerModelBackend(BaseBackend):
         self.endpoint_configs = {}
         self.endpoints = {}
         self.experiments = {}
+        self.processing_jobs = {}
         self.trials = {}
         self.trial_components = {}
         self.training_jobs = {}
@@ -1672,6 +1730,129 @@ class SageMakerModelBackend(BaseBackend):
         except RESTError:
             return []
 
+    def create_processing_job(
+        self,
+        app_specification,
+        experiment_config,
+        network_config,
+        processing_inputs,
+        processing_job_name,
+        processing_output_config,
+        processing_resources,
+        role_arn,
+        stopping_condition,
+    ):
+        processing_job = FakeProcessingJob(
+            app_specification=app_specification,
+            experiment_config=experiment_config,
+            network_config=network_config,
+            processing_inputs=processing_inputs,
+            processing_job_name=processing_job_name,
+            processing_output_config=processing_output_config,
+            processing_resources=processing_resources,
+            region_name=self.region_name,
+            role_arn=role_arn,
+            stopping_condition=stopping_condition,
+        )
+        self.processing_jobs[processing_job_name] = processing_job
+        return processing_job
+
+    def describe_processing_job(self, processing_job_name):
+        try:
+            return self.processing_jobs[processing_job_name].response_object
+        except KeyError:
+            message = "Could not find processing job '{}'.".format(
+                FakeProcessingJob.arn_formatter(processing_job_name, self.region_name)
+            )
+            raise ValidationError(message=message)
+
+    def list_processing_jobs(
+        self,
+        next_token,
+        max_results,
+        creation_time_after,
+        creation_time_before,
+        last_modified_time_after,
+        last_modified_time_before,
+        name_contains,
+        status_equals,
+        sort_by,
+        sort_order,
+    ):
+        if next_token:
+            try:
+                starting_index = int(next_token)
+                if starting_index > len(self.processing_jobs):
+                    raise ValueError  # invalid next_token
+            except ValueError:
+                raise AWSValidationException('Invalid pagination token because "{0}".')
+        else:
+            starting_index = 0
+
+        if max_results:
+            end_index = max_results + starting_index
+            processing_jobs_fetched = list(self.processing_jobs.values())[
+                starting_index:end_index
+            ]
+            if end_index >= len(self.processing_jobs):
+                next_index = None
+            else:
+                next_index = end_index
+        else:
+            processing_jobs_fetched = list(self.processing_jobs.values())
+            next_index = None
+
+        if name_contains is not None:
+            processing_jobs_fetched = filter(
+                lambda x: name_contains in x.processing_job_name,
+                processing_jobs_fetched,
+            )
+
+        if creation_time_after is not None:
+            processing_jobs_fetched = filter(
+                lambda x: x.creation_time > creation_time_after, processing_jobs_fetched
+            )
+
+        if creation_time_before is not None:
+            processing_jobs_fetched = filter(
+                lambda x: x.creation_time < creation_time_before,
+                processing_jobs_fetched,
+            )
+
+        if last_modified_time_after is not None:
+            processing_jobs_fetched = filter(
+                lambda x: x.last_modified_time > last_modified_time_after,
+                processing_jobs_fetched,
+            )
+
+        if last_modified_time_before is not None:
+            processing_jobs_fetched = filter(
+                lambda x: x.last_modified_time < last_modified_time_before,
+                processing_jobs_fetched,
+            )
+        if status_equals is not None:
+            processing_jobs_fetched = filter(
+                lambda x: x.processing_job_status == status_equals,
+                processing_jobs_fetched,
+            )
+
+        processing_job_summaries = [
+            {
+                "ProcessingJobName": processing_job_data.processing_job_name,
+                "ProcessingJobArn": processing_job_data.processing_job_arn,
+                "CreationTime": processing_job_data.creation_time,
+                "ProcessingEndTime": processing_job_data.processing_end_time,
+                "LastModifiedTime": processing_job_data.last_modified_time,
+                "ProcessingJobStatus": processing_job_data.processing_job_status,
+            }
+            for processing_job_data in processing_jobs_fetched
+        ]
+
+        return {
+            "ProcessingJobSummaries": processing_job_summaries,
+            "NextToken": str(next_index) if next_index is not None else None,
+        }
+
     def create_training_job(
         self,
         training_job_name,
diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py
index 8ba597618..46d92b077 100644
--- a/moto/sagemaker/responses.py
+++ b/moto/sagemaker/responses.py
@@ -225,6 +225,33 @@ class SageMakerResponse(BaseResponse):
         self.sagemaker_backend.delete_endpoint(endpoint_name)
         return 200, {}, json.dumps("{}")
 
+    @amzn_request_id
+    def create_processing_job(self):
+        try:
+            processing_job = self.sagemaker_backend.create_processing_job(
+                app_specification=self._get_param("AppSpecification"),
+                experiment_config=self._get_param("ExperimentConfig"),
+                network_config=self._get_param("NetworkConfig"),
+                processing_inputs=self._get_param("ProcessingInputs"),
+                processing_job_name=self._get_param("ProcessingJobName"),
+                processing_output_config=self._get_param("ProcessingOutputConfig"),
+                processing_resources=self._get_param("ProcessingResources"),
+                role_arn=self._get_param("RoleArn"),
+                stopping_condition=self._get_param("StoppingCondition"),
+            )
+            response = {
+                "ProcessingJobArn": processing_job.processing_job_arn,
+            }
+            return 200, {}, json.dumps(response)
+        except AWSError as err:
+            return err.response()
+
+    @amzn_request_id
+    def describe_processing_job(self):
+        processing_job_name = self._get_param("ProcessingJobName")
+        response = self.sagemaker_backend.describe_processing_job(processing_job_name)
+        return json.dumps(response)
+
     @amzn_request_id
     def create_training_job(self):
         try:
@@ -420,6 +447,68 @@ class SageMakerResponse(BaseResponse):
         response
= self.sagemaker_backend.list_associations(self.request_params) return 200, {}, json.dumps(response) + @amzn_request_id + def list_processing_jobs(self): + max_results_range = range(1, 101) + allowed_sort_by = ["Name", "CreationTime", "Status"] + allowed_sort_order = ["Ascending", "Descending"] + allowed_status_equals = [ + "Completed", + "Stopped", + "InProgress", + "Stopping", + "Failed", + ] + + try: + max_results = self._get_int_param("MaxResults") + sort_by = self._get_param("SortBy", "CreationTime") + sort_order = self._get_param("SortOrder", "Ascending") + status_equals = self._get_param("StatusEquals") + next_token = self._get_param("NextToken") + errors = [] + if max_results and max_results not in max_results_range: + errors.append( + "Value '{0}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {1}".format( + max_results, max_results_range[-1] + ) + ) + + if sort_by not in allowed_sort_by: + errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by)) + if sort_order not in allowed_sort_order: + errors.append( + format_enum_error(sort_order, "sortOrder", allowed_sort_order) + ) + + if status_equals and status_equals not in allowed_status_equals: + errors.append( + format_enum_error( + status_equals, "statusEquals", allowed_status_equals + ) + ) + + if errors != []: + raise AWSValidationException( + f"{len(errors)} validation errors detected: {';'.join(errors)}" + ) + + response = self.sagemaker_backend.list_processing_jobs( + next_token=next_token, + max_results=max_results, + creation_time_after=self._get_param("CreationTimeAfter"), + creation_time_before=self._get_param("CreationTimeBefore"), + last_modified_time_after=self._get_param("LastModifiedTimeAfter"), + last_modified_time_before=self._get_param("LastModifiedTimeBefore"), + name_contains=self._get_param("NameContains"), + status_equals=status_equals, + sort_by=sort_by, + sort_order=sort_order, + ) + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + @amzn_request_id def list_training_jobs(self): max_results_range = range(1, 101) diff --git a/tests/test_sagemaker/test_sagemaker_processing.py b/tests/test_sagemaker/test_sagemaker_processing.py new file mode 100644 index 000000000..ce4d2974d --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_processing.py @@ -0,0 +1,334 @@ +import boto3 +from botocore.exceptions import ClientError +import datetime +import pytest + +from moto import mock_sagemaker +from moto.sts.models import ACCOUNT_ID + +FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) +TEST_REGION_NAME = "us-east-1" + + +class MyProcessingJobModel(object): + def __init__( + self, + processing_job_name, + role_arn, + container=None, + bucket=None, + prefix=None, + app_specification=None, + network_config=None, + processing_inputs=None, + processing_output_config=None, + processing_resources=None, + stopping_condition=None, + ): + self.processing_job_name = processing_job_name + self.role_arn = role_arn + self.container = ( + container + or "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3" + ) + self.bucket = bucket or "my-bucket" + self.prefix = prefix or "sagemaker" + self.app_specification = app_specification or { + "ImageUri": self.container, + "ContainerEntrypoint": ["python3",], + } + self.network_config = network_config or { + "EnableInterContainerTrafficEncryption": False, + "EnableNetworkIsolation": False, + } + self.processing_inputs = processing_inputs or [ + { + 
"InputName": "input", + "AppManaged": False, + "S3Input": { + "S3Uri": "s3://{}/{}/processing/".format(self.bucket, self.prefix), + "LocalPath": "/opt/ml/processing/input", + "S3DataType": "S3Prefix", + "S3InputMode": "File", + "S3DataDistributionType": "FullyReplicated", + "S3CompressionType": "None", + }, + } + ] + self.processing_output_config = processing_output_config or { + "Outputs": [ + { + "OutputName": "output", + "S3Output": { + "S3Uri": "s3://{}/{}/processing/".format( + self.bucket, self.prefix + ), + "LocalPath": "/opt/ml/processing/output", + "S3UploadMode": "EndOfJob", + }, + "AppManaged": False, + } + ] + } + self.processing_resources = processing_resources or { + "ClusterConfig": { + "InstanceCount": 1, + "InstanceType": "ml.m5.large", + "VolumeSizeInGB": 10, + }, + } + self.stopping_condition = stopping_condition or { + "MaxRuntimeInSeconds": 3600, + } + + def save(self): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + params = { + "AppSpecification": self.app_specification, + "NetworkConfig": self.network_config, + "ProcessingInputs": self.processing_inputs, + "ProcessingJobName": self.processing_job_name, + "ProcessingOutputConfig": self.processing_output_config, + "ProcessingResources": self.processing_resources, + "RoleArn": self.role_arn, + "StoppingCondition": self.stopping_condition, + } + + return sagemaker.create_processing_job(**params) + + +@mock_sagemaker +def test_create_processing_job(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + processing_job_name = "MyProcessingJob" + role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) + container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" + bucket = "my-bucket" + prefix = "my-prefix" + app_specification = { + "ImageUri": container, + "ContainerEntrypoint": ["python3", "app.py"], + } + processing_resources = { + "ClusterConfig": { + "InstanceCount": 2, + "InstanceType": "ml.m5.xlarge", + "VolumeSizeInGB": 20, + }, + } + stopping_condition = {"MaxRuntimeInSeconds": 60 * 60} + + job = MyProcessingJobModel( + processing_job_name, + role_arn, + container=container, + bucket=bucket, + prefix=prefix, + app_specification=app_specification, + processing_resources=processing_resources, + stopping_condition=stopping_condition, + ) + resp = job.save() + resp["ProcessingJobArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name) + ) + + resp = sagemaker.describe_processing_job(ProcessingJobName=processing_job_name) + resp["ProcessingJobName"].should.equal(processing_job_name) + resp["ProcessingJobArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name) + ) + assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"] + assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"] + assert resp["RoleArn"] == role_arn + assert resp["ProcessingJobStatus"] == "Completed" + assert isinstance(resp["CreationTime"], datetime.datetime) + assert isinstance(resp["LastModifiedTime"], datetime.datetime) + + +@mock_sagemaker +def test_list_processing_jobs(): + client = boto3.client("sagemaker", region_name="us-east-1") + name = "blah" + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" + test_processing_job = MyProcessingJobModel(processing_job_name=name, role_arn=arn) + test_processing_job.save() + processing_jobs = client.list_processing_jobs() + assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1) + assert 
processing_jobs["ProcessingJobSummaries"][0][ + "ProcessingJobName" + ].should.equal(name) + + assert processing_jobs["ProcessingJobSummaries"][0][ + "ProcessingJobArn" + ].should.match(r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(name)) + assert processing_jobs.get("NextToken") is None + + +@mock_sagemaker +def test_list_processing_jobs_multiple(): + client = boto3.client("sagemaker", region_name="us-east-1") + name_job_1 = "blah" + arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" + test_processing_job_1 = MyProcessingJobModel( + processing_job_name=name_job_1, role_arn=arn_job_1 + ) + test_processing_job_1.save() + + name_job_2 = "blah2" + arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2" + test_processing_job_2 = MyProcessingJobModel( + processing_job_name=name_job_2, role_arn=arn_job_2 + ) + test_processing_job_2.save() + processing_jobs_limit = client.list_processing_jobs(MaxResults=1) + assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1) + + processing_jobs = client.list_processing_jobs() + assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(2) + assert processing_jobs.get("NextToken").should.be.none + + +@mock_sagemaker +def test_list_processing_jobs_none(): + client = boto3.client("sagemaker", region_name="us-east-1") + processing_jobs = client.list_processing_jobs() + assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0) + + +@mock_sagemaker +def test_list_processing_jobs_should_validate_input(): + client = boto3.client("sagemaker", region_name="us-east-1") + junk_status_equals = "blah" + with pytest.raises(ClientError) as ex: + client.list_processing_jobs(StatusEquals=junk_status_equals) + expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']" + assert ex.value.response["Error"]["Code"] == "ValidationException" + assert ex.value.response["Error"]["Message"] == expected_error + + junk_next_token = "asdf" + with pytest.raises(ClientError) as ex: + client.list_processing_jobs(NextToken=junk_next_token) + assert ex.value.response["Error"]["Code"] == "ValidationException" + assert ( + ex.value.response["Error"]["Message"] + == 'Invalid pagination token because "{0}".' 
+ ) + + +@mock_sagemaker +def test_list_processing_jobs_with_name_filters(): + client = boto3.client("sagemaker", region_name="us-east-1") + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() + for i in range(5): + name = "vgg-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() + xgboost_processing_jobs = client.list_processing_jobs(NameContains="xgboost") + assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5) + + processing_jobs_with_2 = client.list_processing_jobs(NameContains="2") + assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2) + + +@mock_sagemaker +def test_list_processing_jobs_paginated(): + client = boto3.client("sagemaker", region_name="us-east-1") + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() + xgboost_processing_job_1 = client.list_processing_jobs( + NameContains="xgboost", MaxResults=1 + ) + assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1) + assert xgboost_processing_job_1["ProcessingJobSummaries"][0][ + "ProcessingJobName" + ].should.equal("xgboost-0") + assert xgboost_processing_job_1.get("NextToken").should_not.be.none + + xgboost_processing_job_next = client.list_processing_jobs( + NameContains="xgboost", + MaxResults=1, + NextToken=xgboost_processing_job_1.get("NextToken"), + ) + assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1) + assert xgboost_processing_job_next["ProcessingJobSummaries"][0][ + "ProcessingJobName" + ].should.equal("xgboost-1") + assert xgboost_processing_job_next.get("NextToken").should_not.be.none + + +@mock_sagemaker +def test_list_processing_jobs_paginated_with_target_in_middle(): + client = boto3.client("sagemaker", region_name="us-east-1") + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() + for i in range(5): + name = "vgg-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() + + vgg_processing_job_1 = client.list_processing_jobs(NameContains="vgg", MaxResults=1) + assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0) + assert vgg_processing_job_1.get("NextToken").should_not.be.none + + vgg_processing_job_6 = client.list_processing_jobs(NameContains="vgg", MaxResults=6) + + assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1) + assert vgg_processing_job_6["ProcessingJobSummaries"][0][ + "ProcessingJobName" + ].should.equal("vgg-0") + assert vgg_processing_job_6.get("NextToken").should_not.be.none + + vgg_processing_job_10 = client.list_processing_jobs( + NameContains="vgg", MaxResults=10 + ) + + assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5) + assert vgg_processing_job_10["ProcessingJobSummaries"][-1][ + "ProcessingJobName" + ].should.equal("vgg-4") + assert vgg_processing_job_10.get("NextToken").should.be.none + + +@mock_sagemaker +def test_list_processing_jobs_paginated_with_fragmented_targets(): + client = 
boto3.client("sagemaker", region_name="us-east-1") + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() + for i in range(5): + name = "vgg-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() + + processing_jobs_with_2 = client.list_processing_jobs(NameContains="2", MaxResults=8) + assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2) + assert processing_jobs_with_2.get("NextToken").should_not.be.none + + processing_jobs_with_2_next = client.list_processing_jobs( + NameContains="2", + MaxResults=1, + NextToken=processing_jobs_with_2.get("NextToken"), + ) + assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0) + assert processing_jobs_with_2_next.get("NextToken").should_not.be.none + + processing_jobs_with_2_next_next = client.list_processing_jobs( + NameContains="2", + MaxResults=1, + NextToken=processing_jobs_with_2_next.get("NextToken"), + ) + assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal( + 0 + ) + assert processing_jobs_with_2_next_next.get("NextToken").should.be.none