From e703ee9a76235ffda514123cf5f1832ed7704610 Mon Sep 17 00:00:00 2001 From: Killian O'Daly Date: Wed, 27 Apr 2022 12:56:08 +0100 Subject: [PATCH] Increase Tag support for Sagemaker. (#5052) --- IMPLEMENTATION_COVERAGE.md | 8 +- docs/docs/services/sagemaker.rst | 6 +- moto/sagemaker/models.py | 188 ++----- moto/sagemaker/responses.py | 45 +- .../test_sagemaker/test_sagemaker_endpoint.py | 263 ++++++---- .../test_sagemaker_experiment.py | 74 ++- tests/test_sagemaker/test_sagemaker_models.py | 142 ++--- .../test_sagemaker_notebooks.py | 229 +++++---- .../test_sagemaker_processing.py | 483 ++++++++++-------- tests/test_sagemaker/test_sagemaker_search.py | 100 ++-- .../test_sagemaker/test_sagemaker_training.py | 44 ++ tests/test_sagemaker/test_sagemaker_trial.py | 32 ++ .../test_sagemaker_trial_component.py | 29 ++ 13 files changed, 907 insertions(+), 736 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index efb05c941..5bb83a9ad 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -4904,10 +4904,10 @@ ## sagemaker
-15% implemented +16% implemented - [ ] add_association -- [ ] add_tags +- [X] add_tags - [X] associate_trial_component - [ ] batch_describe_model_package - [ ] create_action @@ -4988,7 +4988,7 @@ - [ ] delete_pipeline - [ ] delete_project - [ ] delete_studio_lifecycle_config -- [ ] delete_tags +- [X] delete_tags - [X] delete_trial - [X] delete_trial_component - [ ] delete_user_profile @@ -5100,7 +5100,7 @@ - [ ] list_projects - [ ] list_studio_lifecycle_configs - [ ] list_subscribed_workteams -- [ ] list_tags +- [X] list_tags - [X] list_training_jobs - [ ] list_training_jobs_for_hyper_parameter_tuning_job - [ ] list_transform_jobs diff --git a/docs/docs/services/sagemaker.rst b/docs/docs/services/sagemaker.rst index b7d2ff3a2..681aeb497 100644 --- a/docs/docs/services/sagemaker.rst +++ b/docs/docs/services/sagemaker.rst @@ -26,7 +26,7 @@ sagemaker |start-h3| Implemented features for this service |end-h3| - [ ] add_association -- [ ] add_tags +- [X] add_tags - [X] associate_trial_component - [ ] batch_describe_model_package - [ ] create_action @@ -107,7 +107,7 @@ sagemaker - [ ] delete_pipeline - [ ] delete_project - [ ] delete_studio_lifecycle_config -- [ ] delete_tags +- [X] delete_tags - [X] delete_trial - [X] delete_trial_component - [ ] delete_user_profile @@ -219,7 +219,7 @@ sagemaker - [ ] list_projects - [ ] list_studio_lifecycle_configs - [ ] list_subscribed_workteams -- [ ] list_tags +- [X] list_tags - [X] list_training_jobs - [ ] list_training_jobs_for_hyper_parameter_tuning_job - [ ] list_transform_jobs diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index fb2307615..8adb8779f 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -2,7 +2,6 @@ import json import os from datetime import datetime from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel -from moto.core.exceptions import RESTError from moto.core.utils import BackendDict from moto.sagemaker import validators from moto.utilities.paginator import paginate @@ -36,6 +35,13 @@ PAGINATION_MODEL = { "unique_attribute": "trial_component_arn", "fail_on_invalid_token": True, }, + "list_tags": { + "input_token": "NextToken", + "limit_key": "MaxResults", + "limit_default": 50, + "unique_attribute": "Key", + "fail_on_invalid_token": True, + }, } @@ -76,6 +82,7 @@ class FakeProcessingJob(BaseObject): processing_output_config, region_name, role_arn, + tags, stopping_condition, ): self.processing_job_name = processing_job_name @@ -87,7 +94,7 @@ class FakeProcessingJob(BaseObject): self.creation_time = now_string self.last_modified_time = now_string self.processing_end_time = now_string - + self.tags = tags or [] self.role_arn = role_arn self.app_specification = app_specification self.experiment_config = experiment_config @@ -152,7 +159,7 @@ class FakeTrainingJob(BaseObject): self.resource_config = resource_config self.vpc_config = vpc_config self.stopping_condition = stopping_condition - self.tags = tags + self.tags = tags or [] self.enable_network_isolation = enable_network_isolation self.enable_inter_container_traffic_encryption = ( enable_inter_container_traffic_encryption @@ -1095,51 +1102,38 @@ class SageMakerModelBackend(BaseBackend): "LastModifiedTime": experiment_data.last_modified_time, } - def add_tags_to_experiment(self, experiment_arn, tags): - experiment = [ - self.experiments[i] - for i in self.experiments - if self.experiments[i].experiment_arn == experiment_arn - ][0] - experiment.tags.extend(tags) + def _get_resource_from_arn(self, arn): + resources = { + 
"model": self._models, + "notebook-instance": self.notebook_instances, + "endpoint": self.endpoints, + "endpoint-config": self.endpoint_configs, + "training-job": self.training_jobs, + "experiment": self.experiments, + "experiment-trial": self.trials, + "experiment-trial-component": self.trial_components, + "processing-job": self.processing_jobs, + } + target_resource, target_name = arn.split(":")[-1].split("/") + try: + resource = resources.get(target_resource).get(target_name) + except KeyError: + message = f"Could not find {target_resource} with name {target_name}" + raise ValidationError(message=message) + return resource - def add_tags_to_trial(self, trial_arn, tags): - trial = [ - self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn - ][0] - trial.tags.extend(tags) + def add_tags(self, arn, tags): + resource = self._get_resource_from_arn(arn) + resource.tags.extend(tags) - def add_tags_to_trial_component(self, trial_component_arn, tags): - trial_component = [ - self.trial_components[i] - for i in self.trial_components - if self.trial_components[i].trial_component_arn == trial_component_arn - ][0] - trial_component.tags.extend(tags) + @paginate(pagination_model=PAGINATION_MODEL) + def list_tags(self, arn): + resource = self._get_resource_from_arn(arn) + return resource.tags - def delete_tags_from_experiment(self, experiment_arn, tag_keys): - experiment = [ - self.experiments[i] - for i in self.experiments - if self.experiments[i].experiment_arn == experiment_arn - ][0] - experiment.tags = [tag for tag in experiment.tags if tag["Key"] not in tag_keys] - - def delete_tags_from_trial(self, trial_arn, tag_keys): - trial = [ - self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn - ][0] - trial.tags = [tag for tag in trial.tags if tag["Key"] not in tag_keys] - - def delete_tags_from_trial_component(self, trial_component_arn, tag_keys): - trial_component = [ - self.trial_components[i] - for i in self.trial_components - if self.trial_components[i].trial_component_arn == trial_component_arn - ][0] - trial_component.tags = [ - tag for tag in trial_component.tags if tag["Key"] not in tag_keys - ] + def delete_tags(self, arn, tag_keys): + resource = self._get_resource_from_arn(arn) + resource.tags = [tag for tag in resource.tags if tag["Key"] not in tag_keys] @paginate(pagination_model=PAGINATION_MODEL) def list_experiments(self): @@ -1281,24 +1275,6 @@ class SageMakerModelBackend(BaseBackend): ) raise ValidationError(message=message) - def get_experiment_by_arn(self, arn): - experiments = [ - experiment - for experiment in self.experiments.values() - if experiment.experiment_arn == arn - ] - if len(experiments) == 0: - message = "RecordNotFound" - raise ValidationError(message=message) - return experiments[0] - - def get_experiment_tags(self, arn): - try: - experiment = self.get_experiment_by_arn(arn) - return experiment.tags or [] - except RESTError: - return [] - def create_trial(self, trial_name, experiment_name): trial = FakeTrial( region_name=self.region_name, @@ -1328,20 +1304,6 @@ class SageMakerModelBackend(BaseBackend): ) raise ValidationError(message=message) - def get_trial_by_arn(self, arn): - trials = [trial for trial in self.trials.values() if trial.trial_arn == arn] - if len(trials) == 0: - message = "RecordNotFound" - raise ValidationError(message=message) - return trials[0] - - def get_trial_tags(self, arn): - try: - trial = self.get_trial_by_arn(arn) - return trial.tags or [] - except RESTError: - return [] - 
@paginate(pagination_model=PAGINATION_MODEL) def list_trials(self, experiment_name=None, trial_component_name=None): trials_fetched = list(self.trials.values()) @@ -1382,24 +1344,6 @@ class SageMakerModelBackend(BaseBackend): ) raise ValidationError(message=message) - def get_trial_component_by_arn(self, arn): - trial_components = [ - trial_component - for trial_component in self.trial_components.values() - if trial_component.trial_component_arn == arn - ] - if len(trial_components) == 0: - message = "RecordNotFound" - raise ValidationError(message=message) - return trial_components[0] - - def get_trial_component_tags(self, arn): - try: - trial_component = self.get_trial_component_by_arn(arn) - return trial_component.tags or [] - except RESTError: - return [] - def describe_trial_component(self, trial_component_name): try: return self.trial_components[trial_component_name].response_object @@ -1518,16 +1462,6 @@ class SageMakerModelBackend(BaseBackend): except KeyError: raise ValidationError(message="RecordNotFound") - def get_notebook_instance_by_arn(self, arn): - instances = [ - notebook_instance - for notebook_instance in self.notebook_instances.values() - if notebook_instance.arn == arn - ] - if len(instances) == 0: - raise ValidationError(message="RecordNotFound") - return instances[0] - def start_notebook_instance(self, notebook_instance_name): notebook_instance = self.get_notebook_instance(notebook_instance_name) notebook_instance.start() @@ -1545,13 +1479,6 @@ class SageMakerModelBackend(BaseBackend): raise ValidationError(message=message) del self.notebook_instances[notebook_instance_name] - def get_notebook_instance_tags(self, arn): - try: - notebook_instance = self.get_notebook_instance_by_arn(arn) - return notebook_instance.tags or [] - except RESTError: - return [] - def create_notebook_instance_lifecycle_config( self, notebook_instance_lifecycle_config_name, on_create, on_start ): @@ -1694,24 +1621,6 @@ class SageMakerModelBackend(BaseBackend): ) raise ValidationError(message=message) - def get_endpoint_by_arn(self, arn): - endpoints = [ - endpoint - for endpoint in self.endpoints.values() - if endpoint.endpoint_arn == arn - ] - if len(endpoints) == 0: - message = "RecordNotFound" - raise ValidationError(message=message) - return endpoints[0] - - def get_endpoint_tags(self, arn): - try: - endpoint = self.get_endpoint_by_arn(arn) - return endpoint.tags or [] - except RESTError: - return [] - def create_processing_job( self, app_specification, @@ -1721,6 +1630,7 @@ class SageMakerModelBackend(BaseBackend): processing_job_name, processing_output_config, role_arn, + tags, stopping_condition, ): processing_job = FakeProcessingJob( @@ -1733,6 +1643,7 @@ class SageMakerModelBackend(BaseBackend): region_name=self.region_name, role_arn=role_arn, stopping_condition=stopping_condition, + tags=tags, ) self.processing_jobs[processing_job_name] = processing_job return processing_job @@ -1894,23 +1805,6 @@ class SageMakerModelBackend(BaseBackend): ) raise ValidationError(message=message) - def get_training_job_by_arn(self, arn): - training_jobs = [ - training_job - for training_job in self.training_jobs.values() - if training_job.training_job_arn == arn - ] - if len(training_jobs) == 0: - raise ValidationError(message="RecordNotFound") - return training_jobs[0] - - def get_training_job_tags(self, arn): - try: - training_job = self.get_training_job_by_arn(arn) - return training_job.tags or [] - except RESTError: - return [] - def _update_training_job_details(self, training_job_name, 
details_json): self.training_jobs[training_job_name].update(details_json) diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 63bc213be..78f268e53 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -1,7 +1,6 @@ import json from moto.sagemaker.exceptions import AWSValidationException -from moto.core.exceptions import AWSError from moto.core.responses import BaseResponse from moto.core.utils import amzn_request_id from .models import sagemaker_backends @@ -117,48 +116,29 @@ class SageMakerResponse(BaseResponse): @amzn_request_id def list_tags(self): arn = self._get_param("ResourceArn") - try: - if ":notebook-instance/" in arn: - tags = self.sagemaker_backend.get_notebook_instance_tags(arn) - elif ":endpoint/" in arn: - tags = self.sagemaker_backend.get_endpoint_tags(arn) - elif ":training-job/" in arn: - tags = self.sagemaker_backend.get_training_job_tags(arn) - elif ":experiment/" in arn: - tags = self.sagemaker_backend.get_experiment_tags(arn) - elif ":experiment-trial/" in arn: - tags = self.sagemaker_backend.get_trial_tags(arn) - elif ":experiment-trial-component/" in arn: - tags = self.sagemaker_backend.get_trial_component_tags(arn) - else: - tags = [] - except AWSError: - tags = [] - response = {"Tags": tags} + max_results = self._get_param("MaxResults") + next_token = self._get_param("NextToken") + paged_results, next_token = self.sagemaker_backend.list_tags( + arn=arn, MaxResults=max_results, NextToken=next_token + ) + response = {"Tags": paged_results} + if next_token: + response["NextToken"] = next_token return 200, {}, json.dumps(response) @amzn_request_id def add_tags(self): arn = self._get_param("ResourceArn") tags = self._get_param("Tags") - if ":experiment/" in arn: - self.sagemaker_backend.add_tags_to_experiment(arn, tags) - elif ":experiment-trial/" in arn: - self.sagemaker_backend.add_tags_to_trial(arn, tags) - elif ":experiment-trial-component/" in arn: - self.sagemaker_backend.add_tags_to_trial_component(arn, tags) - return 200, {}, json.dumps({}) + tags = self.sagemaker_backend.add_tags(arn, tags) + response = {"Tags": tags} + return 200, {}, json.dumps(response) @amzn_request_id def delete_tags(self): arn = self._get_param("ResourceArn") tag_keys = self._get_param("TagKeys") - if ":experiment/" in arn: - self.sagemaker_backend.delete_tags_from_experiment(arn, tag_keys) - elif ":experiment-trial/" in arn: - self.sagemaker_backend.delete_tags_from_trial(arn, tag_keys) - elif ":experiment-trial-component/" in arn: - self.sagemaker_backend.delete_tags_from_trial_component(arn, tag_keys) + self.sagemaker_backend.delete_tags(arn, tag_keys) return 200, {}, json.dumps({}) @amzn_request_id @@ -222,6 +202,7 @@ class SageMakerResponse(BaseResponse): processing_output_config=self._get_param("ProcessingOutputConfig"), role_arn=self._get_param("RoleArn"), stopping_condition=self._get_param("StoppingCondition"), + tags=self._get_param("Tags"), ) response = { "ProcessingJobArn": processing_job.processing_job_arn, diff --git a/tests/test_sagemaker/test_sagemaker_endpoint.py b/tests/test_sagemaker/test_sagemaker_endpoint.py index 549959bb5..4a1b30218 100644 --- a/tests/test_sagemaker/test_sagemaker_endpoint.py +++ b/tests/test_sagemaker/test_sagemaker_endpoint.py @@ -1,4 +1,6 @@ import datetime +import uuid + import boto3 from botocore.exceptions import ClientError import sure # noqa # pylint: disable=unused-import @@ -8,115 +10,114 @@ from moto.sts.models import ACCOUNT_ID import pytest TEST_REGION_NAME = "us-east-1" -FAKE_ROLE_ARN = 
"arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) +TEST_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) GENERIC_TAGS_PARAM = [ {"Key": "newkey1", "Value": "newval1"}, {"Key": "newkey2", "Value": "newval2"}, ] +TEST_MODEL_NAME = "MyModel" +TEST_ENDPOINT_NAME = "MyEndpoint" +TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig" +TEST_PRODUCTION_VARIANTS = [ + { + "VariantName": "MyProductionVariant", + "ModelName": TEST_MODEL_NAME, + "InitialInstanceCount": 1, + "InstanceType": "ml.t2.medium", + }, +] + + +@pytest.fixture +def sagemaker_client(): + return boto3.client("sagemaker", region_name=TEST_REGION_NAME) @mock_sagemaker -def test_create_endpoint_config(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - model_name = "MyModel" - production_variants = [ - { - "VariantName": "MyProductionVariant", - "ModelName": model_name, - "InitialInstanceCount": 1, - "InstanceType": "ml.t2.medium", - }, - ] - - endpoint_config_name = "MyEndpointConfig" +def test_create_endpoint_config(sagemaker_client): with pytest.raises(ClientError) as e: - sagemaker.create_endpoint_config( - EndpointConfigName=endpoint_config_name, - ProductionVariants=production_variants, + sagemaker_client.create_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, + ProductionVariants=TEST_PRODUCTION_VARIANTS, ) assert e.value.response["Error"]["Message"].startswith("Could not find model") - _create_model(sagemaker, model_name) - resp = sagemaker.create_endpoint_config( - EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants + _create_model(sagemaker_client, TEST_MODEL_NAME) + resp = sagemaker_client.create_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, + ProductionVariants=TEST_PRODUCTION_VARIANTS, ) resp["EndpointConfigArn"].should.match( - r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name) + r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format( + TEST_ENDPOINT_CONFIG_NAME + ) ) - resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name) - resp["EndpointConfigArn"].should.match( - r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name) + resp = sagemaker_client.describe_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME ) - resp["EndpointConfigName"].should.equal(endpoint_config_name) - resp["ProductionVariants"].should.equal(production_variants) + resp["EndpointConfigArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format( + TEST_ENDPOINT_CONFIG_NAME + ) + ) + resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME) + resp["ProductionVariants"].should.equal(TEST_PRODUCTION_VARIANTS) @mock_sagemaker -def test_delete_endpoint_config(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - model_name = "MyModel" - _create_model(sagemaker, model_name) - - endpoint_config_name = "MyEndpointConfig" - production_variants = [ - { - "VariantName": "MyProductionVariant", - "ModelName": model_name, - "InitialInstanceCount": 1, - "InstanceType": "ml.t2.medium", - }, - ] - - resp = sagemaker.create_endpoint_config( - EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants +def test_delete_endpoint_config(sagemaker_client): + _create_model(sagemaker_client, TEST_MODEL_NAME) + resp = sagemaker_client.create_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, + ProductionVariants=TEST_PRODUCTION_VARIANTS, ) resp["EndpointConfigArn"].should.match( - 
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name) + r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format( + TEST_ENDPOINT_CONFIG_NAME + ) ) - resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name) + resp = sagemaker_client.describe_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME + ) resp["EndpointConfigArn"].should.match( - r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name) + r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format( + TEST_ENDPOINT_CONFIG_NAME + ) ) - resp = sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME + ) with pytest.raises(ClientError) as e: - sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name) + sagemaker_client.describe_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME + ) assert e.value.response["Error"]["Message"].startswith( "Could not find endpoint configuration" ) with pytest.raises(ClientError) as e: - sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME + ) assert e.value.response["Error"]["Message"].startswith( "Could not find endpoint configuration" ) @mock_sagemaker -def test_create_endpoint_invalid_instance_type(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - model_name = "MyModel" - _create_model(sagemaker, model_name) +def test_create_endpoint_invalid_instance_type(sagemaker_client): + _create_model(sagemaker_client, TEST_MODEL_NAME) instance_type = "InvalidInstanceType" - production_variants = [ - { - "VariantName": "MyProductionVariant", - "ModelName": model_name, - "InitialInstanceCount": 1, - "InstanceType": instance_type, - }, - ] + production_variants = TEST_PRODUCTION_VARIANTS + production_variants[0]["InstanceType"] = instance_type - endpoint_config_name = "MyEndpointConfig" with pytest.raises(ClientError) as e: - sagemaker.create_endpoint_config( - EndpointConfigName=endpoint_config_name, + sagemaker_client.create_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, ProductionVariants=production_variants, ) assert e.value.response["Error"]["Code"] == "ValidationException" @@ -127,71 +128,131 @@ def test_create_endpoint_invalid_instance_type(): @mock_sagemaker -def test_create_endpoint(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - endpoint_name = "MyEndpoint" +def test_create_endpoint(sagemaker_client): with pytest.raises(ClientError) as e: - sagemaker.create_endpoint( - EndpointName=endpoint_name, EndpointConfigName="NonexistentEndpointConfig" + sagemaker_client.create_endpoint( + EndpointName=TEST_ENDPOINT_NAME, + EndpointConfigName="NonexistentEndpointConfig", ) assert e.value.response["Error"]["Message"].startswith( "Could not find endpoint configuration" ) - model_name = "MyModel" - _create_model(sagemaker, model_name) + _create_model(sagemaker_client, TEST_MODEL_NAME) - endpoint_config_name = "MyEndpointConfig" - _create_endpoint_config(sagemaker, endpoint_config_name, model_name) + _create_endpoint_config( + sagemaker_client, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME + ) - resp = sagemaker.create_endpoint( - EndpointName=endpoint_name, - EndpointConfigName=endpoint_config_name, + resp = sagemaker_client.create_endpoint( + EndpointName=TEST_ENDPOINT_NAME, + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, 
Tags=GENERIC_TAGS_PARAM, ) resp["EndpointArn"].should.match( - r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name) + r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME) ) - resp = sagemaker.describe_endpoint(EndpointName=endpoint_name) + resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) resp["EndpointArn"].should.match( - r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name) + r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME) ) - resp["EndpointName"].should.equal(endpoint_name) - resp["EndpointConfigName"].should.equal(endpoint_config_name) + resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME) + resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME) resp["EndpointStatus"].should.equal("InService") assert isinstance(resp["CreationTime"], datetime.datetime) assert isinstance(resp["LastModifiedTime"], datetime.datetime) resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant") - resp = sagemaker.list_tags(ResourceArn=resp["EndpointArn"]) + resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"]) assert resp["Tags"] == GENERIC_TAGS_PARAM @mock_sagemaker -def test_delete_endpoint(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) +def test_delete_endpoint(sagemaker_client): + _set_up_sagemaker_resources( + sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME + ) - model_name = "MyModel" - _create_model(sagemaker, model_name) - - endpoint_config_name = "MyEndpointConfig" - _create_endpoint_config(sagemaker, endpoint_config_name, model_name) - - endpoint_name = "MyEndpoint" - _create_endpoint(sagemaker, endpoint_name, endpoint_config_name) - - sagemaker.delete_endpoint(EndpointName=endpoint_name) + sagemaker_client.delete_endpoint(EndpointName=TEST_ENDPOINT_NAME) with pytest.raises(ClientError) as e: - sagemaker.describe_endpoint(EndpointName=endpoint_name) + sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) assert e.value.response["Error"]["Message"].startswith("Could not find endpoint") with pytest.raises(ClientError) as e: - sagemaker.delete_endpoint(EndpointName=endpoint_name) + sagemaker_client.delete_endpoint(EndpointName=TEST_ENDPOINT_NAME) assert e.value.response["Error"]["Message"].startswith("Could not find endpoint") +@mock_sagemaker +def test_add_tags_endpoint(sagemaker_client): + _set_up_sagemaker_resources( + sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME + ) + + resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}" + response = sagemaker_client.add_tags( + ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM + ) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + response = sagemaker_client.list_tags(ResourceArn=resource_arn) + assert response["Tags"] == GENERIC_TAGS_PARAM + + +@mock_sagemaker +def test_delete_tags_endpoint(sagemaker_client): + _set_up_sagemaker_resources( + sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME + ) + + resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}" + response = sagemaker_client.add_tags( + ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM + ) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + tag_keys = [tag["Key"] for tag in GENERIC_TAGS_PARAM] + response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) + assert 
response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + response = sagemaker_client.list_tags(ResourceArn=resource_arn) + assert response["Tags"] == [] + + +@mock_sagemaker +def test_list_tags_endpoint(sagemaker_client): + _set_up_sagemaker_resources( + sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME + ) + + tags = [] + for _ in range(80): + tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"}) + + resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}" + response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + response = sagemaker_client.list_tags(ResourceArn=resource_arn) + assert len(response["Tags"]) == 50 + assert response["Tags"] == tags[:50] + + response = sagemaker_client.list_tags( + ResourceArn=resource_arn, NextToken=response["NextToken"] + ) + assert len(response["Tags"]) == 30 + assert response["Tags"] == tags[50:] + + +def _set_up_sagemaker_resources( + boto_client, endpoint_name, endpoint_config_name, model_name +): + _create_model(boto_client, model_name) + _create_endpoint_config(boto_client, endpoint_config_name, model_name) + _create_endpoint(boto_client, endpoint_name, endpoint_config_name) + + def _create_model(boto_client, model_name): resp = boto_client.create_model( ModelName=model_name, @@ -199,7 +260,7 @@ def _create_model(boto_client, model_name): "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1", "ModelDataUrl": "s3://MyBucket/model.tar.gz", }, - ExecutionRoleArn=FAKE_ROLE_ARN, + ExecutionRoleArn=TEST_ROLE_ARN, ) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 diff --git a/tests/test_sagemaker/test_sagemaker_experiment.py b/tests/test_sagemaker/test_sagemaker_experiment.py index 810a82ee3..6e9d92b0e 100644 --- a/tests/test_sagemaker/test_sagemaker_experiment.py +++ b/tests/test_sagemaker/test_sagemaker_experiment.py @@ -1,54 +1,56 @@ import boto3 +import pytest from moto import mock_sagemaker from moto.sts.models import ACCOUNT_ID TEST_REGION_NAME = "us-east-1" +TEST_EXPERIMENT_NAME = "MyExperimentName" + + +@pytest.fixture +def sagemaker_client(): + return boto3.client("sagemaker", region_name=TEST_REGION_NAME) @mock_sagemaker -def test_create_experiment(): - client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - experiment_name = "some-experiment-name" - - resp = client.create_experiment(ExperimentName=experiment_name) +def test_create_experiment(sagemaker_client): + resp = sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 - resp = client.list_experiments() + resp = sagemaker_client.list_experiments() assert len(resp["ExperimentSummaries"]) == 1 - assert resp["ExperimentSummaries"][0]["ExperimentName"] == experiment_name + assert resp["ExperimentSummaries"][0]["ExperimentName"] == TEST_EXPERIMENT_NAME assert ( resp["ExperimentSummaries"][0]["ExperimentArn"] - == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment/{experiment_name}" + == f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment/{TEST_EXPERIMENT_NAME}" ) @mock_sagemaker -def test_list_experiments(): - client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) +def test_list_experiments(sagemaker_client): experiment_names = [f"some-experiment-name-{i}" for i in range(10)] for experiment_name in experiment_names: - resp = client.create_experiment(ExperimentName=experiment_name) + 
resp = sagemaker_client.create_experiment(ExperimentName=experiment_name) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 - resp = client.list_experiments(MaxResults=1) + resp = sagemaker_client.list_experiments(MaxResults=1) assert len(resp["ExperimentSummaries"]) == 1 next_token = resp["NextToken"] - resp = client.list_experiments(MaxResults=2, NextToken=next_token) + resp = sagemaker_client.list_experiments(MaxResults=2, NextToken=next_token) assert len(resp["ExperimentSummaries"]) == 2 next_token = resp["NextToken"] - resp = client.list_experiments(NextToken=next_token) + resp = sagemaker_client.list_experiments(NextToken=next_token) assert len(resp["ExperimentSummaries"]) == 7 @@ -56,65 +58,53 @@ def test_list_experiments(): @mock_sagemaker -def test_delete_experiment(): - client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) +def test_delete_experiment(sagemaker_client): + sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME) - experiment_name = "some-experiment-name" - - resp = client.create_experiment(ExperimentName=experiment_name) - - resp = client.delete_experiment(ExperimentName=experiment_name) + resp = sagemaker_client.delete_experiment(ExperimentName=TEST_EXPERIMENT_NAME) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 - resp = client.list_experiments() + resp = sagemaker_client.list_experiments() assert len(resp["ExperimentSummaries"]) == 0 @mock_sagemaker -def test_add_tags_to_experiment(): - client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) +def test_add_tags_to_experiment(sagemaker_client): + sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME) - experiment_name = "some-experiment-name" - - resp = client.create_experiment(ExperimentName=experiment_name) - - resp = client.describe_experiment(ExperimentName=experiment_name) + resp = sagemaker_client.describe_experiment(ExperimentName=TEST_EXPERIMENT_NAME) arn = resp["ExperimentArn"] tags = [{"Key": "name", "Value": "value"}] - client.add_tags(ResourceArn=arn, Tags=tags) + sagemaker_client.add_tags(ResourceArn=arn, Tags=tags) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 - resp = client.list_tags(ResourceArn=arn) + resp = sagemaker_client.list_tags(ResourceArn=arn) assert resp["Tags"] == tags @mock_sagemaker -def test_delete_tags_to_experiment(): - client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) +def test_delete_tags_to_experiment(sagemaker_client): + sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME) - experiment_name = "some-experiment-name" - - resp = client.create_experiment(ExperimentName=experiment_name) - - resp = client.describe_experiment(ExperimentName=experiment_name) + resp = sagemaker_client.describe_experiment(ExperimentName=TEST_EXPERIMENT_NAME) arn = resp["ExperimentArn"] tags = [{"Key": "name", "Value": "value"}] - client.add_tags(ResourceArn=arn, Tags=tags) + sagemaker_client.add_tags(ResourceArn=arn, Tags=tags) - client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags]) + sagemaker_client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags]) assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 - resp = client.list_tags(ResourceArn=arn) + resp = sagemaker_client.list_tags(ResourceArn=arn) assert resp["Tags"] == [] diff --git a/tests/test_sagemaker/test_sagemaker_models.py b/tests/test_sagemaker/test_sagemaker_models.py index 89d9f4273..25668e01f 100644 --- a/tests/test_sagemaker/test_sagemaker_models.py +++ b/tests/test_sagemaker/test_sagemaker_models.py 
@@ -7,118 +7,142 @@ import sure # noqa # pylint: disable=unused-import from moto.sagemaker.models import VpcConfig +TEST_REGION_NAME = "us-east-1" +TEST_ARN = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar" +TEST_MODEL_NAME = "MyModelName" + + +@pytest.fixture +def sagemaker_client(): + return boto3.client("sagemaker", region_name=TEST_REGION_NAME) + class MySageMakerModel(object): - def __init__(self, name, arn, container=None, vpc_config=None): - self.name = name - self.arn = arn + def __init__(self, name=None, arn=None, container=None, vpc_config=None): + self.name = name or TEST_MODEL_NAME + self.arn = arn or TEST_ARN self.container = container or {} self.vpc_config = vpc_config or {"sg-groups": ["sg-123"], "subnets": ["123"]} - def save(self): - client = boto3.client("sagemaker", region_name="us-east-1") + def save(self, sagemaker_client): vpc_config = VpcConfig( self.vpc_config.get("sg-groups"), self.vpc_config.get("subnets") ) - client.create_model( + resp = sagemaker_client.create_model( ModelName=self.name, ExecutionRoleArn=self.arn, VpcConfig=vpc_config.response_object, ) + return resp @mock_sagemaker -def test_describe_model(): - client = boto3.client("sagemaker", region_name="us-east-1") - test_model = MySageMakerModel( - name="blah", - arn="arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar", - vpc_config={"sg-groups": ["sg-123"], "subnets": ["123"]}, - ) - test_model.save() - model = client.describe_model(ModelName="blah") - assert model.get("ModelName").should.equal("blah") +def test_describe_model(sagemaker_client): + test_model = MySageMakerModel() + test_model.save(sagemaker_client) + model = sagemaker_client.describe_model(ModelName=TEST_MODEL_NAME) + assert model.get("ModelName").should.equal(TEST_MODEL_NAME) @mock_sagemaker -def test_describe_model_not_found(): - client = boto3.client("sagemaker", region_name="us-east-1") +def test_describe_model_not_found(sagemaker_client): with pytest.raises(ClientError) as err: - client.describe_model(ModelName="unknown") + sagemaker_client.describe_model(ModelName="unknown") assert err.value.response["Error"]["Message"].should.contain("Could not find model") @mock_sagemaker -def test_create_model(): - client = boto3.client("sagemaker", region_name="us-east-1") +def test_create_model(sagemaker_client): vpc_config = VpcConfig(["sg-foobar"], ["subnet-xxx"]) - exec_role_arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar" - name = "blah" - model = client.create_model( - ModelName=name, - ExecutionRoleArn=exec_role_arn, + model = sagemaker_client.create_model( + ModelName=TEST_MODEL_NAME, + ExecutionRoleArn=TEST_ARN, VpcConfig=vpc_config.response_object, ) - - model["ModelArn"].should.match(r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name)) + model["ModelArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:model/{}$".format(TEST_MODEL_NAME) + ) @mock_sagemaker -def test_delete_model(): - client = boto3.client("sagemaker", region_name="us-east-1") - name = "blah" - arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar" - test_model = MySageMakerModel(name=name, arn=arn) - test_model.save() +def test_delete_model(sagemaker_client): + test_model = MySageMakerModel() + test_model.save(sagemaker_client) - assert len(client.list_models()["Models"]).should.equal(1) - client.delete_model(ModelName=name) - assert len(client.list_models()["Models"]).should.equal(0) + assert len(sagemaker_client.list_models()["Models"]).should.equal(1) + sagemaker_client.delete_model(ModelName=TEST_MODEL_NAME) + assert 
len(sagemaker_client.list_models()["Models"]).should.equal(0) @mock_sagemaker -def test_delete_model_not_found(): +def test_delete_model_not_found(sagemaker_client): with pytest.raises(ClientError) as err: - boto3.client("sagemaker", region_name="us-east-1").delete_model( - ModelName="blah" - ) + sagemaker_client.delete_model(ModelName="blah") assert err.value.response["Error"]["Code"].should.equal("404") @mock_sagemaker -def test_list_models(): - client = boto3.client("sagemaker", region_name="us-east-1") - name = "blah" - arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar" - test_model = MySageMakerModel(name=name, arn=arn) - test_model.save() - models = client.list_models() +def test_list_models(sagemaker_client): + test_model = MySageMakerModel() + test_model.save(sagemaker_client) + models = sagemaker_client.list_models() assert len(models["Models"]).should.equal(1) - assert models["Models"][0]["ModelName"].should.equal(name) + assert models["Models"][0]["ModelName"].should.equal(TEST_MODEL_NAME) assert models["Models"][0]["ModelArn"].should.match( - r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name) + r"^arn:aws:sagemaker:.*:.*:model/{}$".format(TEST_MODEL_NAME) ) @mock_sagemaker -def test_list_models_multiple(): - client = boto3.client("sagemaker", region_name="us-east-1") - +def test_list_models_multiple(sagemaker_client): name_model_1 = "blah" arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar" test_model_1 = MySageMakerModel(name=name_model_1, arn=arn_model_1) - test_model_1.save() + test_model_1.save(sagemaker_client) name_model_2 = "blah2" arn_model_2 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar2" test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2) - test_model_2.save() - models = client.list_models() + test_model_2.save(sagemaker_client) + models = sagemaker_client.list_models() assert len(models["Models"]).should.equal(2) @mock_sagemaker -def test_list_models_none(): - client = boto3.client("sagemaker", region_name="us-east-1") - models = client.list_models() +def test_list_models_none(sagemaker_client): + models = sagemaker_client.list_models() assert len(models["Models"]).should.equal(0) + + +@mock_sagemaker +def test_add_tags_to_model(sagemaker_client): + model = MySageMakerModel().save(sagemaker_client) + resource_arn = model["ModelArn"] + + tags = [ + {"Key": "myKey", "Value": "myValue"}, + ] + response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + response = sagemaker_client.list_tags(ResourceArn=resource_arn) + assert response["Tags"] == tags + + +@mock_sagemaker +def test_delete_tags_from_model(sagemaker_client): + model = MySageMakerModel().save(sagemaker_client) + resource_arn = model["ModelArn"] + + tags = [ + {"Key": "myKey", "Value": "myValue"}, + ] + response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + tag_keys = [tag["Key"] for tag in tags] + response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + response = sagemaker_client.list_tags(ResourceArn=resource_arn) + assert response["Tags"] == [] diff --git a/tests/test_sagemaker/test_sagemaker_notebooks.py b/tests/test_sagemaker/test_sagemaker_notebooks.py index 6b46307f8..d2289631f 100644 --- a/tests/test_sagemaker/test_sagemaker_notebooks.py +++ b/tests/test_sagemaker/test_sagemaker_notebooks.py @@ 
-22,34 +22,42 @@ FAKE_ADDL_CODE_REPOS = [ "https://github.com/user/repo2", "https://github.com/user/repo2", ] +FAKE_NAME_PARAM = "MyNotebookInstance" +FAKE_INSTANCE_TYPE_PARAM = "ml.t2.medium" + + +@pytest.fixture +def sagemaker_client(): + return boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + +def _get_notebook_instance_arn(notebook_name): + return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance/{notebook_name}" + + +def _get_notebook_instance_lifecycle_arn(lifecycle_name): + return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance-lifecycle-configuration/{lifecycle_name}" @mock_sagemaker -def test_create_notebook_instance_minimal_params(): - - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - NAME_PARAM = "MyNotebookInstance" - INSTANCE_TYPE_PARAM = "ml.t2.medium" - +def test_create_notebook_instance_minimal_params(sagemaker_client): args = { - "NotebookInstanceName": NAME_PARAM, - "InstanceType": INSTANCE_TYPE_PARAM, + "NotebookInstanceName": FAKE_NAME_PARAM, + "InstanceType": FAKE_INSTANCE_TYPE_PARAM, "RoleArn": FAKE_ROLE_ARN, } - resp = sagemaker.create_notebook_instance(**args) - assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker") - assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]) + resp = sagemaker_client.create_notebook_instance(**args) + expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM) + assert resp["NotebookInstanceArn"] == expected_notebook_arn - resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) - assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker") - assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]) - assert resp["NotebookInstanceName"] == NAME_PARAM - assert resp["NotebookInstanceStatus"] == "InService" - assert resp["Url"] == "{}.notebook.{}.sagemaker.aws".format( - NAME_PARAM, TEST_REGION_NAME + resp = sagemaker_client.describe_notebook_instance( + NotebookInstanceName=FAKE_NAME_PARAM ) - assert resp["InstanceType"] == INSTANCE_TYPE_PARAM + assert resp["NotebookInstanceArn"] == expected_notebook_arn + assert resp["NotebookInstanceName"] == FAKE_NAME_PARAM + assert resp["NotebookInstanceStatus"] == "InService" + assert resp["Url"] == f"{FAKE_NAME_PARAM}.notebook.{TEST_REGION_NAME}.sagemaker.aws" + assert resp["InstanceType"] == FAKE_INSTANCE_TYPE_PARAM assert resp["RoleArn"] == FAKE_ROLE_ARN assert isinstance(resp["LastModifiedTime"], datetime.datetime) assert isinstance(resp["CreationTime"], datetime.datetime) @@ -61,69 +69,60 @@ def test_create_notebook_instance_minimal_params(): @mock_sagemaker -def test_create_notebook_instance_params(): - - sagemaker = boto3.client("sagemaker", region_name="us-east-1") - - NAME_PARAM = "MyNotebookInstance" - INSTANCE_TYPE_PARAM = "ml.t2.medium" - DIRECT_INTERNET_ACCESS_PARAM = "Enabled" - VOLUME_SIZE_IN_GB_PARAM = 7 - ACCELERATOR_TYPES_PARAM = ["ml.eia1.medium", "ml.eia2.medium"] - ROOT_ACCESS_PARAM = "Disabled" +def test_create_notebook_instance_params(sagemaker_client): + fake_direct_internet_access_param = "Enabled" + volume_size_in_gb_param = 7 + accelerator_types_param = ["ml.eia1.medium", "ml.eia2.medium"] + root_access_param = "Disabled" args = { - "NotebookInstanceName": NAME_PARAM, - "InstanceType": INSTANCE_TYPE_PARAM, + "NotebookInstanceName": FAKE_NAME_PARAM, + "InstanceType": FAKE_INSTANCE_TYPE_PARAM, "SubnetId": FAKE_SUBNET_ID, "SecurityGroupIds": FAKE_SECURITY_GROUP_IDS, "RoleArn": FAKE_ROLE_ARN, "KmsKeyId": FAKE_KMS_KEY_ID, 
"Tags": GENERIC_TAGS_PARAM, "LifecycleConfigName": FAKE_LIFECYCLE_CONFIG_NAME, - "DirectInternetAccess": DIRECT_INTERNET_ACCESS_PARAM, - "VolumeSizeInGB": VOLUME_SIZE_IN_GB_PARAM, - "AcceleratorTypes": ACCELERATOR_TYPES_PARAM, + "DirectInternetAccess": fake_direct_internet_access_param, + "VolumeSizeInGB": volume_size_in_gb_param, + "AcceleratorTypes": accelerator_types_param, "DefaultCodeRepository": FAKE_DEFAULT_CODE_REPO, "AdditionalCodeRepositories": FAKE_ADDL_CODE_REPOS, - "RootAccess": ROOT_ACCESS_PARAM, + "RootAccess": root_access_param, } - resp = sagemaker.create_notebook_instance(**args) - assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker") - assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]) + resp = sagemaker_client.create_notebook_instance(**args) + expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM) + assert resp["NotebookInstanceArn"] == expected_notebook_arn - resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) - assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker") - assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]) - assert resp["NotebookInstanceName"] == NAME_PARAM - assert resp["NotebookInstanceStatus"] == "InService" - assert resp["Url"] == "{}.notebook.{}.sagemaker.aws".format( - NAME_PARAM, TEST_REGION_NAME + resp = sagemaker_client.describe_notebook_instance( + NotebookInstanceName=FAKE_NAME_PARAM ) - assert resp["InstanceType"] == INSTANCE_TYPE_PARAM + assert resp["NotebookInstanceArn"] == expected_notebook_arn + assert resp["NotebookInstanceName"] == FAKE_NAME_PARAM + assert resp["NotebookInstanceStatus"] == "InService" + assert resp["Url"] == f"{FAKE_NAME_PARAM}.notebook.{TEST_REGION_NAME}.sagemaker.aws" + assert resp["InstanceType"] == FAKE_INSTANCE_TYPE_PARAM assert resp["RoleArn"] == FAKE_ROLE_ARN assert isinstance(resp["LastModifiedTime"], datetime.datetime) assert isinstance(resp["CreationTime"], datetime.datetime) assert resp["DirectInternetAccess"] == "Enabled" - assert resp["VolumeSizeInGB"] == VOLUME_SIZE_IN_GB_PARAM + assert resp["VolumeSizeInGB"] == volume_size_in_gb_param # assert resp["RootAccess"] == True # ToDo: Not sure if this defaults... 
assert resp["SubnetId"] == FAKE_SUBNET_ID assert resp["SecurityGroups"] == FAKE_SECURITY_GROUP_IDS assert resp["KmsKeyId"] == FAKE_KMS_KEY_ID assert resp["NotebookInstanceLifecycleConfigName"] == FAKE_LIFECYCLE_CONFIG_NAME - assert resp["AcceleratorTypes"] == ACCELERATOR_TYPES_PARAM + assert resp["AcceleratorTypes"] == accelerator_types_param assert resp["DefaultCodeRepository"] == FAKE_DEFAULT_CODE_REPO assert resp["AdditionalCodeRepositories"] == FAKE_ADDL_CODE_REPOS - resp = sagemaker.list_tags(ResourceArn=resp["NotebookInstanceArn"]) + resp = sagemaker_client.list_tags(ResourceArn=resp["NotebookInstanceArn"]) assert resp["Tags"] == GENERIC_TAGS_PARAM @mock_sagemaker -def test_create_notebook_instance_invalid_instance_type(): - - sagemaker = boto3.client("sagemaker", region_name="us-east-1") - +def test_create_notebook_instance_invalid_instance_type(sagemaker_client): instance_type = "undefined_instance_type" args = { "NotebookInstanceName": "MyNotebookInstance", @@ -131,7 +130,7 @@ def test_create_notebook_instance_invalid_instance_type(): "RoleArn": FAKE_ROLE_ARN, } with pytest.raises(ClientError) as ex: - sagemaker.create_notebook_instance(**args) + sagemaker_client.create_notebook_instance(**args) assert ex.value.response["Error"]["Code"] == "ValidationException" expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format( instance_type @@ -141,78 +140,79 @@ def test_create_notebook_instance_invalid_instance_type(): @mock_sagemaker -def test_notebook_instance_lifecycle(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - NAME_PARAM = "MyNotebookInstance" - INSTANCE_TYPE_PARAM = "ml.t2.medium" - +def test_notebook_instance_lifecycle(sagemaker_client): args = { - "NotebookInstanceName": NAME_PARAM, - "InstanceType": INSTANCE_TYPE_PARAM, + "NotebookInstanceName": FAKE_NAME_PARAM, + "InstanceType": FAKE_INSTANCE_TYPE_PARAM, "RoleArn": FAKE_ROLE_ARN, } - resp = sagemaker.create_notebook_instance(**args) - assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker") - assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"]) + resp = sagemaker_client.create_notebook_instance(**args) + expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM) + assert resp["NotebookInstanceArn"] == expected_notebook_arn - resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + resp = sagemaker_client.describe_notebook_instance( + NotebookInstanceName=FAKE_NAME_PARAM + ) notebook_instance_arn = resp["NotebookInstanceArn"] with pytest.raises(ClientError) as ex: - sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM) + sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) assert ex.value.response["Error"]["Code"] == "ValidationException" expected_message = "Status (InService) not in ([Stopped, Failed]). 
Unable to transition to (Deleting) for Notebook Instance ({})".format( notebook_instance_arn ) assert expected_message in ex.value.response["Error"]["Message"] - sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM) + sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) - resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + resp = sagemaker_client.describe_notebook_instance( + NotebookInstanceName=FAKE_NAME_PARAM + ) assert resp["NotebookInstanceStatus"] == "Stopped" - sagemaker.start_notebook_instance(NotebookInstanceName=NAME_PARAM) + sagemaker_client.start_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) - resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + resp = sagemaker_client.describe_notebook_instance( + NotebookInstanceName=FAKE_NAME_PARAM + ) assert resp["NotebookInstanceStatus"] == "InService" - sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM) + sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) - resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + resp = sagemaker_client.describe_notebook_instance( + NotebookInstanceName=FAKE_NAME_PARAM + ) assert resp["NotebookInstanceStatus"] == "Stopped" - sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM) + sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM) with pytest.raises(ClientError) as ex: - sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM) + sagemaker_client.describe_notebook_instance( + NotebookInstanceName=FAKE_NAME_PARAM + ) assert ex.value.response["Error"]["Message"] == "RecordNotFound" @mock_sagemaker -def test_describe_nonexistent_model(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - +def test_describe_nonexistent_model(sagemaker_client): with pytest.raises(ClientError) as e: - sagemaker.describe_model(ModelName="Nonexistent") + sagemaker_client.describe_model(ModelName="Nonexistent") assert e.value.response["Error"]["Message"].startswith("Could not find model") @mock_sagemaker -def test_notebook_instance_lifecycle_config(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - +def test_notebook_instance_lifecycle_config(sagemaker_client): name = "MyLifeCycleConfig" on_create = [{"Content": "Create Script Line 1"}] on_start = [{"Content": "Start Script Line 1"}] - resp = sagemaker.create_notebook_instance_lifecycle_config( + resp = sagemaker_client.create_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name, OnCreate=on_create, OnStart=on_start ) - assert resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker") - assert resp["NotebookInstanceLifecycleConfigArn"].endswith(name) + expected_arn = _get_notebook_instance_lifecycle_arn(name) + assert resp["NotebookInstanceLifecycleConfigArn"] == expected_arn with pytest.raises(ClientError) as e: - resp = sagemaker.create_notebook_instance_lifecycle_config( + sagemaker_client.create_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name, OnCreate=on_create, OnStart=on_start, @@ -221,23 +221,22 @@ def test_notebook_instance_lifecycle_config(): "Notebook Instance Lifecycle Config already exists.)" ) - resp = sagemaker.describe_notebook_instance_lifecycle_config( + resp = sagemaker_client.describe_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name ) assert resp["NotebookInstanceLifecycleConfigName"] == name - assert 
resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker") - assert resp["NotebookInstanceLifecycleConfigArn"].endswith(name) + assert resp["NotebookInstanceLifecycleConfigArn"] == expected_arn assert resp["OnStart"] == on_start assert resp["OnCreate"] == on_create assert isinstance(resp["LastModifiedTime"], datetime.datetime) assert isinstance(resp["CreationTime"], datetime.datetime) - sagemaker.delete_notebook_instance_lifecycle_config( + sagemaker_client.delete_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name ) with pytest.raises(ClientError) as e: - sagemaker.describe_notebook_instance_lifecycle_config( + sagemaker_client.describe_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name ) assert e.value.response["Error"]["Message"].endswith( @@ -245,9 +244,53 @@ def test_notebook_instance_lifecycle_config(): ) with pytest.raises(ClientError) as e: - sagemaker.delete_notebook_instance_lifecycle_config( + sagemaker_client.delete_notebook_instance_lifecycle_config( NotebookInstanceLifecycleConfigName=name ) assert e.value.response["Error"]["Message"].endswith( "Notebook Instance Lifecycle Config does not exist.)" ) + + +@mock_sagemaker +def test_add_tags_to_notebook(sagemaker_client): + args = { + "NotebookInstanceName": FAKE_NAME_PARAM, + "InstanceType": FAKE_INSTANCE_TYPE_PARAM, + "RoleArn": FAKE_ROLE_ARN, + } + resp = sagemaker_client.create_notebook_instance(**args) + resource_arn = resp["NotebookInstanceArn"] + + tags = [ + {"Key": "myKey", "Value": "myValue"}, + ] + response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + response = sagemaker_client.list_tags(ResourceArn=resource_arn) + assert response["Tags"] == tags + + +@mock_sagemaker +def test_delete_tags_from_notebook(sagemaker_client): + args = { + "NotebookInstanceName": FAKE_NAME_PARAM, + "InstanceType": FAKE_INSTANCE_TYPE_PARAM, + "RoleArn": FAKE_ROLE_ARN, + } + resp = sagemaker_client.create_notebook_instance(**args) + resource_arn = resp["NotebookInstanceArn"] + + tags = [ + {"Key": "myKey", "Value": "myValue"}, + ] + response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + tag_keys = [tag["Key"] for tag in tags] + response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + + response = sagemaker_client.list_tags(ResourceArn=resource_arn) + assert response["Tags"] == [] diff --git a/tests/test_sagemaker/test_sagemaker_processing.py b/tests/test_sagemaker/test_sagemaker_processing.py index 5266cc4a2..f0c0a200f 100644 --- a/tests/test_sagemaker/test_sagemaker_processing.py +++ b/tests/test_sagemaker/test_sagemaker_processing.py @@ -6,10 +6,17 @@ import pytest from moto import mock_sagemaker from moto.sts.models import ACCOUNT_ID -FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) +FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" +FAKE_PROCESSING_JOB_NAME = "MyProcessingJob" +FAKE_CONTAINER = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" TEST_REGION_NAME = "us-east-1" +@pytest.fixture +def sagemaker_client(): + return boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + class MyProcessingJobModel(object): def __init__( self, @@ -81,9 +88,7 @@ class MyProcessingJobModel(object): "MaxRuntimeInSeconds": 3600, } - def save(self): - sagemaker = 
boto3.client("sagemaker", region_name=TEST_REGION_NAME) - + def save(self, sagemaker_client): params = { "AppSpecification": self.app_specification, "NetworkConfig": self.network_config, @@ -95,13 +100,265 @@ class MyProcessingJobModel(object): "StoppingCondition": self.stopping_condition, } - return sagemaker.create_processing_job(**params) + return sagemaker_client.create_processing_job(**params) @mock_sagemaker -def test_create_processing_job(): - sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) +def test_create_processing_job(sagemaker_client): + bucket = "my-bucket" + prefix = "my-prefix" + app_specification = { + "ImageUri": FAKE_CONTAINER, + "ContainerEntrypoint": ["python3", "app.py"], + } + processing_resources = { + "ClusterConfig": { + "InstanceCount": 2, + "InstanceType": "ml.m5.xlarge", + "VolumeSizeInGB": 20, + }, + } + stopping_condition = {"MaxRuntimeInSeconds": 60 * 60} + job = MyProcessingJobModel( + processing_job_name=FAKE_PROCESSING_JOB_NAME, + role_arn=FAKE_ROLE_ARN, + container=FAKE_CONTAINER, + bucket=bucket, + prefix=prefix, + app_specification=app_specification, + processing_resources=processing_resources, + stopping_condition=stopping_condition, + ) + resp = job.save(sagemaker_client) + resp["ProcessingJobArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(FAKE_PROCESSING_JOB_NAME) + ) + + resp = sagemaker_client.describe_processing_job( + ProcessingJobName=FAKE_PROCESSING_JOB_NAME + ) + resp["ProcessingJobName"].should.equal(FAKE_PROCESSING_JOB_NAME) + resp["ProcessingJobArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(FAKE_PROCESSING_JOB_NAME) + ) + assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"] + assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"] + assert resp["RoleArn"] == FAKE_ROLE_ARN + assert resp["ProcessingJobStatus"] == "Completed" + assert isinstance(resp["CreationTime"], datetime.datetime) + assert isinstance(resp["LastModifiedTime"], datetime.datetime) + + +@mock_sagemaker +def test_list_processing_jobs(sagemaker_client): + test_processing_job = MyProcessingJobModel( + processing_job_name=FAKE_PROCESSING_JOB_NAME, role_arn=FAKE_ROLE_ARN + ) + test_processing_job.save(sagemaker_client) + processing_jobs = sagemaker_client.list_processing_jobs() + assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1) + assert processing_jobs["ProcessingJobSummaries"][0][ + "ProcessingJobName" + ].should.equal(FAKE_PROCESSING_JOB_NAME) + + assert processing_jobs["ProcessingJobSummaries"][0][ + "ProcessingJobArn" + ].should.match( + r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(FAKE_PROCESSING_JOB_NAME) + ) + assert processing_jobs.get("NextToken") is None + + +@mock_sagemaker +def test_list_processing_jobs_multiple(sagemaker_client): + name_job_1 = "blah" + arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" + test_processing_job_1 = MyProcessingJobModel( + processing_job_name=name_job_1, role_arn=arn_job_1 + ) + test_processing_job_1.save(sagemaker_client) + + name_job_2 = "blah2" + arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2" + test_processing_job_2 = MyProcessingJobModel( + processing_job_name=name_job_2, role_arn=arn_job_2 + ) + test_processing_job_2.save(sagemaker_client) + processing_jobs_limit = sagemaker_client.list_processing_jobs(MaxResults=1) + assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1) + + processing_jobs = sagemaker_client.list_processing_jobs() + assert 
len(processing_jobs["ProcessingJobSummaries"]).should.equal(2) + assert processing_jobs.get("NextToken").should.be.none + + +@mock_sagemaker +def test_list_processing_jobs_none(sagemaker_client): + processing_jobs = sagemaker_client.list_processing_jobs() + assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0) + + +@mock_sagemaker +def test_list_processing_jobs_should_validate_input(sagemaker_client): + junk_status_equals = "blah" + with pytest.raises(ClientError) as ex: + sagemaker_client.list_processing_jobs(StatusEquals=junk_status_equals) + expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']" + assert ex.value.response["Error"]["Code"] == "ValidationException" + assert ex.value.response["Error"]["Message"] == expected_error + + junk_next_token = "asdf" + with pytest.raises(ClientError) as ex: + sagemaker_client.list_processing_jobs(NextToken=junk_next_token) + assert ex.value.response["Error"]["Code"] == "ValidationException" + assert ( + ex.value.response["Error"]["Message"] + == 'Invalid pagination token because "{0}".' + ) + + +@mock_sagemaker +def test_list_processing_jobs_with_name_filters(sagemaker_client): + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( + sagemaker_client + ) + + for i in range(5): + name = "vgg-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( + sagemaker_client + ) + + xgboost_processing_jobs = sagemaker_client.list_processing_jobs( + NameContains="xgboost" + ) + assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5) + + processing_jobs_with_2 = sagemaker_client.list_processing_jobs(NameContains="2") + assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2) + + +@mock_sagemaker +def test_list_processing_jobs_paginated(sagemaker_client): + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( + sagemaker_client + ) + + xgboost_processing_job_1 = sagemaker_client.list_processing_jobs( + NameContains="xgboost", MaxResults=1 + ) + assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1) + assert xgboost_processing_job_1["ProcessingJobSummaries"][0][ + "ProcessingJobName" + ].should.equal("xgboost-0") + assert xgboost_processing_job_1.get("NextToken").should_not.be.none + + xgboost_processing_job_next = sagemaker_client.list_processing_jobs( + NameContains="xgboost", + MaxResults=1, + NextToken=xgboost_processing_job_1.get("NextToken"), + ) + assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1) + assert xgboost_processing_job_next["ProcessingJobSummaries"][0][ + "ProcessingJobName" + ].should.equal("xgboost-1") + assert xgboost_processing_job_next.get("NextToken").should_not.be.none + + +@mock_sagemaker +def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client): + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( + sagemaker_client + 
) + + for i in range(5): + name = "vgg-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( + sagemaker_client + ) + + vgg_processing_job_1 = sagemaker_client.list_processing_jobs( + NameContains="vgg", MaxResults=1 + ) + assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0) + assert vgg_processing_job_1.get("NextToken").should_not.be.none + + vgg_processing_job_6 = sagemaker_client.list_processing_jobs( + NameContains="vgg", MaxResults=6 + ) + + assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1) + assert vgg_processing_job_6["ProcessingJobSummaries"][0][ + "ProcessingJobName" + ].should.equal("vgg-0") + assert vgg_processing_job_6.get("NextToken").should_not.be.none + + vgg_processing_job_10 = sagemaker_client.list_processing_jobs( + NameContains="vgg", MaxResults=10 + ) + + assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5) + assert vgg_processing_job_10["ProcessingJobSummaries"][-1][ + "ProcessingJobName" + ].should.equal("vgg-4") + assert vgg_processing_job_10.get("NextToken").should.be.none + + +@mock_sagemaker +def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client): + for i in range(5): + name = "xgboost-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( + sagemaker_client + ) + + for i in range(5): + name = "vgg-{}".format(i) + arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) + MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( + sagemaker_client + ) + + processing_jobs_with_2 = sagemaker_client.list_processing_jobs( + NameContains="2", MaxResults=8 + ) + assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2) + assert processing_jobs_with_2.get("NextToken").should_not.be.none + + processing_jobs_with_2_next = sagemaker_client.list_processing_jobs( + NameContains="2", + MaxResults=1, + NextToken=processing_jobs_with_2.get("NextToken"), + ) + assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0) + assert processing_jobs_with_2_next.get("NextToken").should_not.be.none + + processing_jobs_with_2_next_next = sagemaker_client.list_processing_jobs( + NameContains="2", + MaxResults=1, + NextToken=processing_jobs_with_2_next.get("NextToken"), + ) + assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal( + 0 + ) + assert processing_jobs_with_2_next_next.get("NextToken").should.be.none + + +@mock_sagemaker +def test_add_and_delete_tags_in_training_job(sagemaker_client): processing_job_name = "MyProcessingJob" role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" @@ -130,205 +387,21 @@ def test_create_processing_job(): processing_resources=processing_resources, stopping_condition=stopping_condition, ) - resp = job.save() - resp["ProcessingJobArn"].should.match( - r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name) - ) + resp = job.save(sagemaker_client) + resource_arn = resp["ProcessingJobArn"] - resp = sagemaker.describe_processing_job(ProcessingJobName=processing_job_name) - resp["ProcessingJobName"].should.equal(processing_job_name) - resp["ProcessingJobArn"].should.match( - r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name) - ) - assert 
"python3" in resp["AppSpecification"]["ContainerEntrypoint"] - assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"] - assert resp["RoleArn"] == role_arn - assert resp["ProcessingJobStatus"] == "Completed" - assert isinstance(resp["CreationTime"], datetime.datetime) - assert isinstance(resp["LastModifiedTime"], datetime.datetime) + tags = [ + {"Key": "myKey", "Value": "myValue"}, + ] + response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + response = sagemaker_client.list_tags(ResourceArn=resource_arn) + assert response["Tags"] == tags -@mock_sagemaker -def test_list_processing_jobs(): - client = boto3.client("sagemaker", region_name="us-east-1") - name = "blah" - arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" - test_processing_job = MyProcessingJobModel(processing_job_name=name, role_arn=arn) - test_processing_job.save() - processing_jobs = client.list_processing_jobs() - assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1) - assert processing_jobs["ProcessingJobSummaries"][0][ - "ProcessingJobName" - ].should.equal(name) + tag_keys = [tag["Key"] for tag in tags] + response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 - assert processing_jobs["ProcessingJobSummaries"][0][ - "ProcessingJobArn" - ].should.match(r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(name)) - assert processing_jobs.get("NextToken") is None - - -@mock_sagemaker -def test_list_processing_jobs_multiple(): - client = boto3.client("sagemaker", region_name="us-east-1") - name_job_1 = "blah" - arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" - test_processing_job_1 = MyProcessingJobModel( - processing_job_name=name_job_1, role_arn=arn_job_1 - ) - test_processing_job_1.save() - - name_job_2 = "blah2" - arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2" - test_processing_job_2 = MyProcessingJobModel( - processing_job_name=name_job_2, role_arn=arn_job_2 - ) - test_processing_job_2.save() - processing_jobs_limit = client.list_processing_jobs(MaxResults=1) - assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1) - - processing_jobs = client.list_processing_jobs() - assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(2) - assert processing_jobs.get("NextToken").should.be.none - - -@mock_sagemaker -def test_list_processing_jobs_none(): - client = boto3.client("sagemaker", region_name="us-east-1") - processing_jobs = client.list_processing_jobs() - assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0) - - -@mock_sagemaker -def test_list_processing_jobs_should_validate_input(): - client = boto3.client("sagemaker", region_name="us-east-1") - junk_status_equals = "blah" - with pytest.raises(ClientError) as ex: - client.list_processing_jobs(StatusEquals=junk_status_equals) - expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']" - assert ex.value.response["Error"]["Code"] == "ValidationException" - assert ex.value.response["Error"]["Message"] == expected_error - - junk_next_token = "asdf" - with pytest.raises(ClientError) as ex: - client.list_processing_jobs(NextToken=junk_next_token) - assert ex.value.response["Error"]["Code"] == "ValidationException" - assert ( 
- ex.value.response["Error"]["Message"] - == 'Invalid pagination token because "{0}".' - ) - - -@mock_sagemaker -def test_list_processing_jobs_with_name_filters(): - client = boto3.client("sagemaker", region_name="us-east-1") - for i in range(5): - name = "xgboost-{}".format(i) - arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) - MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() - for i in range(5): - name = "vgg-{}".format(i) - arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) - MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() - xgboost_processing_jobs = client.list_processing_jobs(NameContains="xgboost") - assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5) - - processing_jobs_with_2 = client.list_processing_jobs(NameContains="2") - assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2) - - -@mock_sagemaker -def test_list_processing_jobs_paginated(): - client = boto3.client("sagemaker", region_name="us-east-1") - for i in range(5): - name = "xgboost-{}".format(i) - arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) - MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() - xgboost_processing_job_1 = client.list_processing_jobs( - NameContains="xgboost", MaxResults=1 - ) - assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1) - assert xgboost_processing_job_1["ProcessingJobSummaries"][0][ - "ProcessingJobName" - ].should.equal("xgboost-0") - assert xgboost_processing_job_1.get("NextToken").should_not.be.none - - xgboost_processing_job_next = client.list_processing_jobs( - NameContains="xgboost", - MaxResults=1, - NextToken=xgboost_processing_job_1.get("NextToken"), - ) - assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1) - assert xgboost_processing_job_next["ProcessingJobSummaries"][0][ - "ProcessingJobName" - ].should.equal("xgboost-1") - assert xgboost_processing_job_next.get("NextToken").should_not.be.none - - -@mock_sagemaker -def test_list_processing_jobs_paginated_with_target_in_middle(): - client = boto3.client("sagemaker", region_name="us-east-1") - for i in range(5): - name = "xgboost-{}".format(i) - arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) - MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() - for i in range(5): - name = "vgg-{}".format(i) - arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) - MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() - - vgg_processing_job_1 = client.list_processing_jobs(NameContains="vgg", MaxResults=1) - assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0) - assert vgg_processing_job_1.get("NextToken").should_not.be.none - - vgg_processing_job_6 = client.list_processing_jobs(NameContains="vgg", MaxResults=6) - - assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1) - assert vgg_processing_job_6["ProcessingJobSummaries"][0][ - "ProcessingJobName" - ].should.equal("vgg-0") - assert vgg_processing_job_6.get("NextToken").should_not.be.none - - vgg_processing_job_10 = client.list_processing_jobs( - NameContains="vgg", MaxResults=10 - ) - - assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5) - assert vgg_processing_job_10["ProcessingJobSummaries"][-1][ - "ProcessingJobName" - ].should.equal("vgg-4") - assert vgg_processing_job_10.get("NextToken").should.be.none - - -@mock_sagemaker 
-def test_list_processing_jobs_paginated_with_fragmented_targets(): - client = boto3.client("sagemaker", region_name="us-east-1") - for i in range(5): - name = "xgboost-{}".format(i) - arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i) - MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() - for i in range(5): - name = "vgg-{}".format(i) - arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i) - MyProcessingJobModel(processing_job_name=name, role_arn=arn).save() - - processing_jobs_with_2 = client.list_processing_jobs(NameContains="2", MaxResults=8) - assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2) - assert processing_jobs_with_2.get("NextToken").should_not.be.none - - processing_jobs_with_2_next = client.list_processing_jobs( - NameContains="2", - MaxResults=1, - NextToken=processing_jobs_with_2.get("NextToken"), - ) - assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0) - assert processing_jobs_with_2_next.get("NextToken").should_not.be.none - - processing_jobs_with_2_next_next = client.list_processing_jobs( - NameContains="2", - MaxResults=1, - NextToken=processing_jobs_with_2_next.get("NextToken"), - ) - assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal( - 0 - ) - assert processing_jobs_with_2_next_next.get("NextToken").should.be.none + response = sagemaker_client.list_tags(ResourceArn=resource_arn) + assert response["Tags"] == [] diff --git a/tests/test_sagemaker/test_sagemaker_search.py b/tests/test_sagemaker/test_sagemaker_search.py index a5d3628f2..a139080d7 100644 --- a/tests/test_sagemaker/test_sagemaker_search.py +++ b/tests/test_sagemaker/test_sagemaker_search.py @@ -8,39 +8,35 @@ from moto import mock_sagemaker TEST_REGION_NAME = "us-east-1" +@pytest.fixture +def sagemaker_client(): + return boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + @mock_sagemaker -def test_search(): - client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - experiment_name = "some-experiment-name" - - resp = client.create_experiment(ExperimentName=experiment_name) - - trial_name = "some-trial-name" - - resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) - - trial_component_name = "some-trial-component-name" - another_trial_component_name = "another-trial-component-name" - - resp = client.create_trial_component(TrialComponentName=trial_component_name) - resp = client.create_trial_component( - TrialComponentName=another_trial_component_name +def test_search(sagemaker_client): + experiment_name = "experiment_name" + trial_component_name = "trial_component_name" + trial_name = "trial_name" + _set_up_trial_component( + sagemaker_client, + experiment_name=experiment_name, + trial_component_name=trial_component_name, + trial_name=trial_name, ) - resp = client.search(Resource="ExperimentTrialComponent") - + resp = sagemaker_client.search(Resource="ExperimentTrialComponent") assert len(resp["Results"]) == 2 - resp = client.describe_trial_component(TrialComponentName=trial_component_name) - + resp = sagemaker_client.describe_trial_component( + TrialComponentName=trial_component_name + ) trial_component_arn = resp["TrialComponentArn"] tags = [{"Key": "key-name", "Value": "some-value"}] + sagemaker_client.add_tags(ResourceArn=trial_component_arn, Tags=tags) - client.add_tags(ResourceArn=trial_component_arn, Tags=tags) - - resp = client.search( + resp = sagemaker_client.search( Resource="ExperimentTrialComponent", 
SearchExpression={ "Filters": [ @@ -55,49 +51,38 @@ def test_search(): == trial_component_name ) - resp = client.search(Resource="Experiment") + resp = sagemaker_client.search(Resource="Experiment") assert len(resp["Results"]) == 1 assert resp["Results"][0]["Experiment"]["ExperimentName"] == experiment_name - resp = client.search(Resource="ExperimentTrial") + resp = sagemaker_client.search(Resource="ExperimentTrial") assert len(resp["Results"]) == 1 assert resp["Results"][0]["Trial"]["TrialName"] == trial_name @mock_sagemaker -def test_search_trial_component_with_experiment_name(): - client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) - - experiment_name = "some-experiment-name" - - resp = client.create_experiment(ExperimentName=experiment_name) - - trial_name = "some-trial-name" - - resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) - - trial_component_name = "some-trial-component-name" - another_trial_component_name = "another-trial-component-name" - - resp = client.create_trial_component(TrialComponentName=trial_component_name) - resp = client.create_trial_component( - TrialComponentName=another_trial_component_name +def test_search_trial_component_with_experiment_name(sagemaker_client): + experiment_name = "experiment_name" + trial_component_name = "trial_component_name" + _set_up_trial_component( + sagemaker_client, + experiment_name=experiment_name, + trial_component_name=trial_component_name, ) - resp = client.search(Resource="ExperimentTrialComponent") - + resp = sagemaker_client.search(Resource="ExperimentTrialComponent") assert len(resp["Results"]) == 2 - resp = client.describe_trial_component(TrialComponentName=trial_component_name) - + resp = sagemaker_client.describe_trial_component( + TrialComponentName=trial_component_name + ) trial_component_arn = resp["TrialComponentArn"] tags = [{"Key": "key-name", "Value": "some-value"}] - - client.add_tags(ResourceArn=trial_component_arn, Tags=tags) + sagemaker_client.add_tags(ResourceArn=trial_component_arn, Tags=tags) with pytest.raises(ClientError) as ex: - client.search( + sagemaker_client.search( Resource="ExperimentTrialComponent", SearchExpression={ "Filters": [ @@ -115,3 +100,18 @@ def test_search_trial_component_with_experiment_name(): "Unknown property name: ExperimentName" ) ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + + +def _set_up_trial_component( + sagemaker_client, + experiment_name="some-experiment-name", + trial_component_name="some-trial-component-name", + trial_name="some-trial-name", + another_trial_component_name="another-trial-component-name", +): + sagemaker_client.create_experiment(ExperimentName=experiment_name) + sagemaker_client.create_trial(ExperimentName=experiment_name, TrialName=trial_name) + sagemaker_client.create_trial_component(TrialComponentName=trial_component_name) + sagemaker_client.create_trial_component( + TrialComponentName=another_trial_component_name + ) diff --git a/tests/test_sagemaker/test_sagemaker_training.py b/tests/test_sagemaker/test_sagemaker_training.py index 76fa3170c..bd96d67b1 100644 --- a/tests/test_sagemaker/test_sagemaker_training.py +++ b/tests/test_sagemaker/test_sagemaker_training.py @@ -397,3 +397,47 @@ def test_list_training_jobs_paginated_with_fragmented_targets(): ) assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]).should.equal(0) assert training_jobs_with_2_next_next.get("NextToken").should.be.none + + +@mock_sagemaker +def test_add_tags_to_training_job(): + client = 
boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+    name = "blah"
+    resource_arn = f"arn:aws:sagemaker:us-east-1:000000000000:training-job/{name}"
+    test_training_job = MyTrainingJobModel(
+        training_job_name=name, role_arn=resource_arn
+    )
+    test_training_job.save()
+
+    tags = [
+        {"Key": "myKey", "Value": "myValue"},
+    ]
+    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
+    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
+
+    response = client.list_tags(ResourceArn=resource_arn)
+    assert response["Tags"] == tags
+
+
+@mock_sagemaker
+def test_delete_tags_from_training_job():
+    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+    name = "blah"
+    resource_arn = f"arn:aws:sagemaker:us-east-1:000000000000:training-job/{name}"
+    test_training_job = MyTrainingJobModel(
+        training_job_name=name, role_arn=resource_arn
+    )
+    test_training_job.save()
+
+    tags = [
+        {"Key": "myKey", "Value": "myValue"},
+    ]
+    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
+    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
+
+    tag_keys = [tag["Key"] for tag in tags]
+    response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
+    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
+
+    response = client.list_tags(ResourceArn=resource_arn)
+    assert response["Tags"] == []
diff --git a/tests/test_sagemaker/test_sagemaker_trial.py b/tests/test_sagemaker/test_sagemaker_trial.py
index cf02b26bd..c1054c560 100644
--- a/tests/test_sagemaker/test_sagemaker_trial.py
+++ b/tests/test_sagemaker/test_sagemaker_trial.py
@@ -1,3 +1,5 @@
+import uuid
+
 import boto3
 
 from moto import mock_sagemaker
@@ -158,3 +160,33 @@ def test_delete_tags_to_trial():
 
     resp = client.list_tags(ResourceArn=arn)
     assert resp["Tags"] == []
+
+
+@mock_sagemaker
+def test_list_trial_tags():
+    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+
+    experiment_name = "some-experiment-name"
+    client.create_experiment(ExperimentName=experiment_name)
+
+    trial_name = "some-trial-name"
+    client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
+    resp = client.describe_trial(TrialName=trial_name)
+    resource_arn = resp["TrialArn"]
+
+    tags = []
+    for _ in range(80):
+        tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"})
+
+    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
+    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
+
+    response = client.list_tags(ResourceArn=resource_arn)
+    assert len(response["Tags"]) == 50
+    assert response["Tags"] == tags[:50]
+
+    response = client.list_tags(
+        ResourceArn=resource_arn, NextToken=response["NextToken"]
+    )
+    assert len(response["Tags"]) == 30
+    assert response["Tags"] == tags[50:]
diff --git a/tests/test_sagemaker/test_sagemaker_trial_component.py b/tests/test_sagemaker/test_sagemaker_trial_component.py
index e120f7ccd..3bdd1f0fc 100644
--- a/tests/test_sagemaker/test_sagemaker_trial_component.py
+++ b/tests/test_sagemaker/test_sagemaker_trial_component.py
@@ -1,3 +1,5 @@
+import uuid
+
 import boto3
 import pytest
 
@@ -123,6 +125,33 @@ def test_delete_tags_to_trial_component():
     assert resp["Tags"] == []
 
 
+@mock_sagemaker
+def test_list_trial_component_tags():
+    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
+
+    trial_component_name = "some-trial-component-name"
+    client.create_trial_component(TrialComponentName=trial_component_name)
+    resp = client.describe_trial_component(TrialComponentName=trial_component_name)
+    resource_arn = resp["TrialComponentArn"]
+
+    tags = []
+    for _ in range(80):
+        tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"})
+
+    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
+    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
+
+    response = client.list_tags(ResourceArn=resource_arn)
+    assert len(response["Tags"]) == 50
+    assert response["Tags"] == tags[:50]
+
+    response = client.list_tags(
+        ResourceArn=resource_arn, NextToken=response["NextToken"]
+    )
+    assert len(response["Tags"]) == 30
+    assert response["Tags"] == tags[50:]
+
+
 @mock_sagemaker
 def test_associate_trial_component():
     client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
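
Taken together, the new tests exercise one tagging flow against the mocked backend: create a resource, call add_tags with its ARN, read tags back with list_tags (which the trial and trial-component tests show paging at 50 tags per call via NextToken), and remove them with delete_tags. The snippet below is a minimal standalone sketch of that flow, assuming the same mock_sagemaker decorator and a notebook instance as the tagged resource; the drain_tags helper, the notebook name, and the role ARN are illustrative choices made here, not part of the patch.

import boto3
from moto import mock_sagemaker


def drain_tags(client, resource_arn):
    # Illustrative helper: collect every page of tags until NextToken is exhausted.
    tags, token = [], None
    while True:
        kwargs = {"ResourceArn": resource_arn}
        if token:
            kwargs["NextToken"] = token
        page = client.list_tags(**kwargs)
        tags.extend(page["Tags"])
        token = page.get("NextToken")
        if not token:
            return tags


@mock_sagemaker
def tagging_round_trip():
    client = boto3.client("sagemaker", region_name="us-east-1")
    resp = client.create_notebook_instance(
        NotebookInstanceName="demo",
        InstanceType="ml.t2.medium",
        RoleArn="arn:aws:iam::123456789012:role/FakeRole",
    )
    arn = resp["NotebookInstanceArn"]

    # Attach a tag, confirm it is listed, then remove it by key.
    client.add_tags(ResourceArn=arn, Tags=[{"Key": "team", "Value": "ml"}])
    assert drain_tags(client, arn) == [{"Key": "team", "Value": "ml"}]

    client.delete_tags(ResourceArn=arn, TagKeys=["team"])
    assert drain_tags(client, arn) == []


tagging_round_trip()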