Increase Tag support for Sagemaker. (#5052)
This commit is contained in:
parent
d12254309d
commit
e703ee9a76
@ -4904,10 +4904,10 @@
|
||||
|
||||
## sagemaker
|
||||
<details>
|
||||
<summary>15% implemented</summary>
|
||||
<summary>16% implemented</summary>
|
||||
|
||||
- [ ] add_association
|
||||
- [ ] add_tags
|
||||
- [X] add_tags
|
||||
- [X] associate_trial_component
|
||||
- [ ] batch_describe_model_package
|
||||
- [ ] create_action
|
||||
@ -4988,7 +4988,7 @@
|
||||
- [ ] delete_pipeline
|
||||
- [ ] delete_project
|
||||
- [ ] delete_studio_lifecycle_config
|
||||
- [ ] delete_tags
|
||||
- [X] delete_tags
|
||||
- [X] delete_trial
|
||||
- [X] delete_trial_component
|
||||
- [ ] delete_user_profile
|
||||
@ -5100,7 +5100,7 @@
|
||||
- [ ] list_projects
|
||||
- [ ] list_studio_lifecycle_configs
|
||||
- [ ] list_subscribed_workteams
|
||||
- [ ] list_tags
|
||||
- [X] list_tags
|
||||
- [X] list_training_jobs
|
||||
- [ ] list_training_jobs_for_hyper_parameter_tuning_job
|
||||
- [ ] list_transform_jobs
|
||||
|
@ -26,7 +26,7 @@ sagemaker
|
||||
|start-h3| Implemented features for this service |end-h3|
|
||||
|
||||
- [ ] add_association
|
||||
- [ ] add_tags
|
||||
- [X] add_tags
|
||||
- [X] associate_trial_component
|
||||
- [ ] batch_describe_model_package
|
||||
- [ ] create_action
|
||||
@ -107,7 +107,7 @@ sagemaker
|
||||
- [ ] delete_pipeline
|
||||
- [ ] delete_project
|
||||
- [ ] delete_studio_lifecycle_config
|
||||
- [ ] delete_tags
|
||||
- [X] delete_tags
|
||||
- [X] delete_trial
|
||||
- [X] delete_trial_component
|
||||
- [ ] delete_user_profile
|
||||
@ -219,7 +219,7 @@ sagemaker
|
||||
- [ ] list_projects
|
||||
- [ ] list_studio_lifecycle_configs
|
||||
- [ ] list_subscribed_workteams
|
||||
- [ ] list_tags
|
||||
- [X] list_tags
|
||||
- [X] list_training_jobs
|
||||
- [ ] list_training_jobs_for_hyper_parameter_tuning_job
|
||||
- [ ] list_transform_jobs
|
||||
|
@ -2,7 +2,6 @@ import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
|
||||
from moto.core.exceptions import RESTError
|
||||
from moto.core.utils import BackendDict
|
||||
from moto.sagemaker import validators
|
||||
from moto.utilities.paginator import paginate
|
||||
@ -36,6 +35,13 @@ PAGINATION_MODEL = {
|
||||
"unique_attribute": "trial_component_arn",
|
||||
"fail_on_invalid_token": True,
|
||||
},
|
||||
"list_tags": {
|
||||
"input_token": "NextToken",
|
||||
"limit_key": "MaxResults",
|
||||
"limit_default": 50,
|
||||
"unique_attribute": "Key",
|
||||
"fail_on_invalid_token": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -76,6 +82,7 @@ class FakeProcessingJob(BaseObject):
|
||||
processing_output_config,
|
||||
region_name,
|
||||
role_arn,
|
||||
tags,
|
||||
stopping_condition,
|
||||
):
|
||||
self.processing_job_name = processing_job_name
|
||||
@ -87,7 +94,7 @@ class FakeProcessingJob(BaseObject):
|
||||
self.creation_time = now_string
|
||||
self.last_modified_time = now_string
|
||||
self.processing_end_time = now_string
|
||||
|
||||
self.tags = tags or []
|
||||
self.role_arn = role_arn
|
||||
self.app_specification = app_specification
|
||||
self.experiment_config = experiment_config
|
||||
@ -152,7 +159,7 @@ class FakeTrainingJob(BaseObject):
|
||||
self.resource_config = resource_config
|
||||
self.vpc_config = vpc_config
|
||||
self.stopping_condition = stopping_condition
|
||||
self.tags = tags
|
||||
self.tags = tags or []
|
||||
self.enable_network_isolation = enable_network_isolation
|
||||
self.enable_inter_container_traffic_encryption = (
|
||||
enable_inter_container_traffic_encryption
|
||||
@ -1095,51 +1102,38 @@ class SageMakerModelBackend(BaseBackend):
|
||||
"LastModifiedTime": experiment_data.last_modified_time,
|
||||
}
|
||||
|
||||
def add_tags_to_experiment(self, experiment_arn, tags):
|
||||
experiment = [
|
||||
self.experiments[i]
|
||||
for i in self.experiments
|
||||
if self.experiments[i].experiment_arn == experiment_arn
|
||||
][0]
|
||||
experiment.tags.extend(tags)
|
||||
def _get_resource_from_arn(self, arn):
|
||||
resources = {
|
||||
"model": self._models,
|
||||
"notebook-instance": self.notebook_instances,
|
||||
"endpoint": self.endpoints,
|
||||
"endpoint-config": self.endpoint_configs,
|
||||
"training-job": self.training_jobs,
|
||||
"experiment": self.experiments,
|
||||
"experiment-trial": self.trials,
|
||||
"experiment-trial-component": self.trial_components,
|
||||
"processing-job": self.processing_jobs,
|
||||
}
|
||||
target_resource, target_name = arn.split(":")[-1].split("/")
|
||||
try:
|
||||
resource = resources.get(target_resource).get(target_name)
|
||||
except KeyError:
|
||||
message = f"Could not find {target_resource} with name {target_name}"
|
||||
raise ValidationError(message=message)
|
||||
return resource
|
||||
|
||||
def add_tags_to_trial(self, trial_arn, tags):
|
||||
trial = [
|
||||
self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn
|
||||
][0]
|
||||
trial.tags.extend(tags)
|
||||
def add_tags(self, arn, tags):
|
||||
resource = self._get_resource_from_arn(arn)
|
||||
resource.tags.extend(tags)
|
||||
|
||||
def add_tags_to_trial_component(self, trial_component_arn, tags):
|
||||
trial_component = [
|
||||
self.trial_components[i]
|
||||
for i in self.trial_components
|
||||
if self.trial_components[i].trial_component_arn == trial_component_arn
|
||||
][0]
|
||||
trial_component.tags.extend(tags)
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_tags(self, arn):
|
||||
resource = self._get_resource_from_arn(arn)
|
||||
return resource.tags
|
||||
|
||||
def delete_tags_from_experiment(self, experiment_arn, tag_keys):
|
||||
experiment = [
|
||||
self.experiments[i]
|
||||
for i in self.experiments
|
||||
if self.experiments[i].experiment_arn == experiment_arn
|
||||
][0]
|
||||
experiment.tags = [tag for tag in experiment.tags if tag["Key"] not in tag_keys]
|
||||
|
||||
def delete_tags_from_trial(self, trial_arn, tag_keys):
|
||||
trial = [
|
||||
self.trials[i] for i in self.trials if self.trials[i].trial_arn == trial_arn
|
||||
][0]
|
||||
trial.tags = [tag for tag in trial.tags if tag["Key"] not in tag_keys]
|
||||
|
||||
def delete_tags_from_trial_component(self, trial_component_arn, tag_keys):
|
||||
trial_component = [
|
||||
self.trial_components[i]
|
||||
for i in self.trial_components
|
||||
if self.trial_components[i].trial_component_arn == trial_component_arn
|
||||
][0]
|
||||
trial_component.tags = [
|
||||
tag for tag in trial_component.tags if tag["Key"] not in tag_keys
|
||||
]
|
||||
def delete_tags(self, arn, tag_keys):
|
||||
resource = self._get_resource_from_arn(arn)
|
||||
resource.tags = [tag for tag in resource.tags if tag["Key"] not in tag_keys]
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_experiments(self):
|
||||
@ -1281,24 +1275,6 @@ class SageMakerModelBackend(BaseBackend):
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_experiment_by_arn(self, arn):
|
||||
experiments = [
|
||||
experiment
|
||||
for experiment in self.experiments.values()
|
||||
if experiment.experiment_arn == arn
|
||||
]
|
||||
if len(experiments) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise ValidationError(message=message)
|
||||
return experiments[0]
|
||||
|
||||
def get_experiment_tags(self, arn):
|
||||
try:
|
||||
experiment = self.get_experiment_by_arn(arn)
|
||||
return experiment.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def create_trial(self, trial_name, experiment_name):
|
||||
trial = FakeTrial(
|
||||
region_name=self.region_name,
|
||||
@ -1328,20 +1304,6 @@ class SageMakerModelBackend(BaseBackend):
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_trial_by_arn(self, arn):
|
||||
trials = [trial for trial in self.trials.values() if trial.trial_arn == arn]
|
||||
if len(trials) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise ValidationError(message=message)
|
||||
return trials[0]
|
||||
|
||||
def get_trial_tags(self, arn):
|
||||
try:
|
||||
trial = self.get_trial_by_arn(arn)
|
||||
return trial.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_trials(self, experiment_name=None, trial_component_name=None):
|
||||
trials_fetched = list(self.trials.values())
|
||||
@ -1382,24 +1344,6 @@ class SageMakerModelBackend(BaseBackend):
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_trial_component_by_arn(self, arn):
|
||||
trial_components = [
|
||||
trial_component
|
||||
for trial_component in self.trial_components.values()
|
||||
if trial_component.trial_component_arn == arn
|
||||
]
|
||||
if len(trial_components) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise ValidationError(message=message)
|
||||
return trial_components[0]
|
||||
|
||||
def get_trial_component_tags(self, arn):
|
||||
try:
|
||||
trial_component = self.get_trial_component_by_arn(arn)
|
||||
return trial_component.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def describe_trial_component(self, trial_component_name):
|
||||
try:
|
||||
return self.trial_components[trial_component_name].response_object
|
||||
@ -1518,16 +1462,6 @@ class SageMakerModelBackend(BaseBackend):
|
||||
except KeyError:
|
||||
raise ValidationError(message="RecordNotFound")
|
||||
|
||||
def get_notebook_instance_by_arn(self, arn):
|
||||
instances = [
|
||||
notebook_instance
|
||||
for notebook_instance in self.notebook_instances.values()
|
||||
if notebook_instance.arn == arn
|
||||
]
|
||||
if len(instances) == 0:
|
||||
raise ValidationError(message="RecordNotFound")
|
||||
return instances[0]
|
||||
|
||||
def start_notebook_instance(self, notebook_instance_name):
|
||||
notebook_instance = self.get_notebook_instance(notebook_instance_name)
|
||||
notebook_instance.start()
|
||||
@ -1545,13 +1479,6 @@ class SageMakerModelBackend(BaseBackend):
|
||||
raise ValidationError(message=message)
|
||||
del self.notebook_instances[notebook_instance_name]
|
||||
|
||||
def get_notebook_instance_tags(self, arn):
|
||||
try:
|
||||
notebook_instance = self.get_notebook_instance_by_arn(arn)
|
||||
return notebook_instance.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def create_notebook_instance_lifecycle_config(
|
||||
self, notebook_instance_lifecycle_config_name, on_create, on_start
|
||||
):
|
||||
@ -1694,24 +1621,6 @@ class SageMakerModelBackend(BaseBackend):
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_endpoint_by_arn(self, arn):
|
||||
endpoints = [
|
||||
endpoint
|
||||
for endpoint in self.endpoints.values()
|
||||
if endpoint.endpoint_arn == arn
|
||||
]
|
||||
if len(endpoints) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise ValidationError(message=message)
|
||||
return endpoints[0]
|
||||
|
||||
def get_endpoint_tags(self, arn):
|
||||
try:
|
||||
endpoint = self.get_endpoint_by_arn(arn)
|
||||
return endpoint.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def create_processing_job(
|
||||
self,
|
||||
app_specification,
|
||||
@ -1721,6 +1630,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
processing_job_name,
|
||||
processing_output_config,
|
||||
role_arn,
|
||||
tags,
|
||||
stopping_condition,
|
||||
):
|
||||
processing_job = FakeProcessingJob(
|
||||
@ -1733,6 +1643,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
region_name=self.region_name,
|
||||
role_arn=role_arn,
|
||||
stopping_condition=stopping_condition,
|
||||
tags=tags,
|
||||
)
|
||||
self.processing_jobs[processing_job_name] = processing_job
|
||||
return processing_job
|
||||
@ -1894,23 +1805,6 @@ class SageMakerModelBackend(BaseBackend):
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
|
||||
def get_training_job_by_arn(self, arn):
|
||||
training_jobs = [
|
||||
training_job
|
||||
for training_job in self.training_jobs.values()
|
||||
if training_job.training_job_arn == arn
|
||||
]
|
||||
if len(training_jobs) == 0:
|
||||
raise ValidationError(message="RecordNotFound")
|
||||
return training_jobs[0]
|
||||
|
||||
def get_training_job_tags(self, arn):
|
||||
try:
|
||||
training_job = self.get_training_job_by_arn(arn)
|
||||
return training_job.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def _update_training_job_details(self, training_job_name, details_json):
|
||||
self.training_jobs[training_job_name].update(details_json)
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
import json
|
||||
from moto.sagemaker.exceptions import AWSValidationException
|
||||
|
||||
from moto.core.exceptions import AWSError
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.core.utils import amzn_request_id
|
||||
from .models import sagemaker_backends
|
||||
@ -117,48 +116,29 @@ class SageMakerResponse(BaseResponse):
|
||||
@amzn_request_id
|
||||
def list_tags(self):
|
||||
arn = self._get_param("ResourceArn")
|
||||
try:
|
||||
if ":notebook-instance/" in arn:
|
||||
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
|
||||
elif ":endpoint/" in arn:
|
||||
tags = self.sagemaker_backend.get_endpoint_tags(arn)
|
||||
elif ":training-job/" in arn:
|
||||
tags = self.sagemaker_backend.get_training_job_tags(arn)
|
||||
elif ":experiment/" in arn:
|
||||
tags = self.sagemaker_backend.get_experiment_tags(arn)
|
||||
elif ":experiment-trial/" in arn:
|
||||
tags = self.sagemaker_backend.get_trial_tags(arn)
|
||||
elif ":experiment-trial-component/" in arn:
|
||||
tags = self.sagemaker_backend.get_trial_component_tags(arn)
|
||||
else:
|
||||
tags = []
|
||||
except AWSError:
|
||||
tags = []
|
||||
response = {"Tags": tags}
|
||||
max_results = self._get_param("MaxResults")
|
||||
next_token = self._get_param("NextToken")
|
||||
paged_results, next_token = self.sagemaker_backend.list_tags(
|
||||
arn=arn, MaxResults=max_results, NextToken=next_token
|
||||
)
|
||||
response = {"Tags": paged_results}
|
||||
if next_token:
|
||||
response["NextToken"] = next_token
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def add_tags(self):
|
||||
arn = self._get_param("ResourceArn")
|
||||
tags = self._get_param("Tags")
|
||||
if ":experiment/" in arn:
|
||||
self.sagemaker_backend.add_tags_to_experiment(arn, tags)
|
||||
elif ":experiment-trial/" in arn:
|
||||
self.sagemaker_backend.add_tags_to_trial(arn, tags)
|
||||
elif ":experiment-trial-component/" in arn:
|
||||
self.sagemaker_backend.add_tags_to_trial_component(arn, tags)
|
||||
return 200, {}, json.dumps({})
|
||||
tags = self.sagemaker_backend.add_tags(arn, tags)
|
||||
response = {"Tags": tags}
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_tags(self):
|
||||
arn = self._get_param("ResourceArn")
|
||||
tag_keys = self._get_param("TagKeys")
|
||||
if ":experiment/" in arn:
|
||||
self.sagemaker_backend.delete_tags_from_experiment(arn, tag_keys)
|
||||
elif ":experiment-trial/" in arn:
|
||||
self.sagemaker_backend.delete_tags_from_trial(arn, tag_keys)
|
||||
elif ":experiment-trial-component/" in arn:
|
||||
self.sagemaker_backend.delete_tags_from_trial_component(arn, tag_keys)
|
||||
self.sagemaker_backend.delete_tags(arn, tag_keys)
|
||||
return 200, {}, json.dumps({})
|
||||
|
||||
@amzn_request_id
|
||||
@ -222,6 +202,7 @@ class SageMakerResponse(BaseResponse):
|
||||
processing_output_config=self._get_param("ProcessingOutputConfig"),
|
||||
role_arn=self._get_param("RoleArn"),
|
||||
stopping_condition=self._get_param("StoppingCondition"),
|
||||
tags=self._get_param("Tags"),
|
||||
)
|
||||
response = {
|
||||
"ProcessingJobArn": processing_job.processing_job_arn,
|
||||
|
@ -1,4 +1,6 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import sure # noqa # pylint: disable=unused-import
|
||||
@ -8,115 +10,114 @@ from moto.sts.models import ACCOUNT_ID
|
||||
import pytest
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
TEST_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
GENERIC_TAGS_PARAM = [
|
||||
{"Key": "newkey1", "Value": "newval1"},
|
||||
{"Key": "newkey2", "Value": "newval2"},
|
||||
]
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_endpoint_config():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
model_name = "MyModel"
|
||||
production_variants = [
|
||||
TEST_MODEL_NAME = "MyModel"
|
||||
TEST_ENDPOINT_NAME = "MyEndpoint"
|
||||
TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig"
|
||||
TEST_PRODUCTION_VARIANTS = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant",
|
||||
"ModelName": model_name,
|
||||
"ModelName": TEST_MODEL_NAME,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": "ml.t2.medium",
|
||||
},
|
||||
]
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
|
||||
@pytest.fixture
|
||||
def sagemaker_client():
|
||||
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_endpoint_config(sagemaker_client):
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name,
|
||||
ProductionVariants=production_variants,
|
||||
sagemaker_client.create_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
ProductionVariants=TEST_PRODUCTION_VARIANTS,
|
||||
)
|
||||
assert e.value.response["Error"]["Message"].startswith("Could not find model")
|
||||
|
||||
_create_model(sagemaker, model_name)
|
||||
resp = sagemaker.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
|
||||
_create_model(sagemaker_client, TEST_MODEL_NAME)
|
||||
resp = sagemaker_client.create_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
ProductionVariants=TEST_PRODUCTION_VARIANTS,
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
|
||||
TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
|
||||
resp = sagemaker_client.describe_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
resp["EndpointConfigName"].should.equal(endpoint_config_name)
|
||||
resp["ProductionVariants"].should.equal(production_variants)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
|
||||
TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
)
|
||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||
resp["ProductionVariants"].should.equal(TEST_PRODUCTION_VARIANTS)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_endpoint_config():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
model_name = "MyModel"
|
||||
_create_model(sagemaker, model_name)
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
production_variants = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant",
|
||||
"ModelName": model_name,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": "ml.t2.medium",
|
||||
},
|
||||
]
|
||||
|
||||
resp = sagemaker.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
|
||||
def test_delete_endpoint_config(sagemaker_client):
|
||||
_create_model(sagemaker_client, TEST_MODEL_NAME)
|
||||
resp = sagemaker_client.create_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
ProductionVariants=TEST_PRODUCTION_VARIANTS,
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
|
||||
TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
resp = sagemaker_client.describe_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
|
||||
TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
)
|
||||
|
||||
resp = sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
sagemaker_client.delete_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
sagemaker_client.describe_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
assert e.value.response["Error"]["Message"].startswith(
|
||||
"Could not find endpoint configuration"
|
||||
)
|
||||
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
sagemaker_client.delete_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
|
||||
)
|
||||
assert e.value.response["Error"]["Message"].startswith(
|
||||
"Could not find endpoint configuration"
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_endpoint_invalid_instance_type():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
model_name = "MyModel"
|
||||
_create_model(sagemaker, model_name)
|
||||
def test_create_endpoint_invalid_instance_type(sagemaker_client):
|
||||
_create_model(sagemaker_client, TEST_MODEL_NAME)
|
||||
|
||||
instance_type = "InvalidInstanceType"
|
||||
production_variants = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant",
|
||||
"ModelName": model_name,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": instance_type,
|
||||
},
|
||||
]
|
||||
production_variants = TEST_PRODUCTION_VARIANTS
|
||||
production_variants[0]["InstanceType"] = instance_type
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name,
|
||||
sagemaker_client.create_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
ProductionVariants=production_variants,
|
||||
)
|
||||
assert e.value.response["Error"]["Code"] == "ValidationException"
|
||||
@ -127,71 +128,131 @@ def test_create_endpoint_invalid_instance_type():
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_endpoint():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
endpoint_name = "MyEndpoint"
|
||||
def test_create_endpoint(sagemaker_client):
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.create_endpoint(
|
||||
EndpointName=endpoint_name, EndpointConfigName="NonexistentEndpointConfig"
|
||||
sagemaker_client.create_endpoint(
|
||||
EndpointName=TEST_ENDPOINT_NAME,
|
||||
EndpointConfigName="NonexistentEndpointConfig",
|
||||
)
|
||||
assert e.value.response["Error"]["Message"].startswith(
|
||||
"Could not find endpoint configuration"
|
||||
)
|
||||
|
||||
model_name = "MyModel"
|
||||
_create_model(sagemaker, model_name)
|
||||
_create_model(sagemaker_client, TEST_MODEL_NAME)
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
_create_endpoint_config(sagemaker, endpoint_config_name, model_name)
|
||||
_create_endpoint_config(
|
||||
sagemaker_client, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
resp = sagemaker.create_endpoint(
|
||||
EndpointName=endpoint_name,
|
||||
EndpointConfigName=endpoint_config_name,
|
||||
resp = sagemaker_client.create_endpoint(
|
||||
EndpointName=TEST_ENDPOINT_NAME,
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
Tags=GENERIC_TAGS_PARAM,
|
||||
)
|
||||
resp["EndpointArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_endpoint(EndpointName=endpoint_name)
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
resp["EndpointArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)
|
||||
)
|
||||
resp["EndpointName"].should.equal(endpoint_name)
|
||||
resp["EndpointConfigName"].should.equal(endpoint_config_name)
|
||||
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
|
||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||
resp["EndpointStatus"].should.equal("InService")
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant")
|
||||
|
||||
resp = sagemaker.list_tags(ResourceArn=resp["EndpointArn"])
|
||||
resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"])
|
||||
assert resp["Tags"] == GENERIC_TAGS_PARAM
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_endpoint():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
def test_delete_endpoint(sagemaker_client):
|
||||
_set_up_sagemaker_resources(
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
model_name = "MyModel"
|
||||
_create_model(sagemaker, model_name)
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
_create_endpoint_config(sagemaker, endpoint_config_name, model_name)
|
||||
|
||||
endpoint_name = "MyEndpoint"
|
||||
_create_endpoint(sagemaker, endpoint_name, endpoint_config_name)
|
||||
|
||||
sagemaker.delete_endpoint(EndpointName=endpoint_name)
|
||||
sagemaker_client.delete_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.describe_endpoint(EndpointName=endpoint_name)
|
||||
sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")
|
||||
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.delete_endpoint(EndpointName=endpoint_name)
|
||||
sagemaker_client.delete_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_add_tags_endpoint(sagemaker_client):
|
||||
_set_up_sagemaker_resources(
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
|
||||
response = sagemaker_client.add_tags(
|
||||
ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM
|
||||
)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == GENERIC_TAGS_PARAM
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_tags_endpoint(sagemaker_client):
|
||||
_set_up_sagemaker_resources(
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
|
||||
response = sagemaker_client.add_tags(
|
||||
ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM
|
||||
)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
tag_keys = [tag["Key"] for tag in GENERIC_TAGS_PARAM]
|
||||
response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == []
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_tags_endpoint(sagemaker_client):
|
||||
_set_up_sagemaker_resources(
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
tags = []
|
||||
for _ in range(80):
|
||||
tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"})
|
||||
|
||||
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
|
||||
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
|
||||
assert len(response["Tags"]) == 50
|
||||
assert response["Tags"] == tags[:50]
|
||||
|
||||
response = sagemaker_client.list_tags(
|
||||
ResourceArn=resource_arn, NextToken=response["NextToken"]
|
||||
)
|
||||
assert len(response["Tags"]) == 30
|
||||
assert response["Tags"] == tags[50:]
|
||||
|
||||
|
||||
def _set_up_sagemaker_resources(
|
||||
boto_client, endpoint_name, endpoint_config_name, model_name
|
||||
):
|
||||
_create_model(boto_client, model_name)
|
||||
_create_endpoint_config(boto_client, endpoint_config_name, model_name)
|
||||
_create_endpoint(boto_client, endpoint_name, endpoint_config_name)
|
||||
|
||||
|
||||
def _create_model(boto_client, model_name):
|
||||
resp = boto_client.create_model(
|
||||
ModelName=model_name,
|
||||
@ -199,7 +260,7 @@ def _create_model(boto_client, model_name):
|
||||
"Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
|
||||
"ModelDataUrl": "s3://MyBucket/model.tar.gz",
|
||||
},
|
||||
ExecutionRoleArn=FAKE_ROLE_ARN,
|
||||
ExecutionRoleArn=TEST_ROLE_ARN,
|
||||
)
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
|
@ -1,54 +1,56 @@
|
||||
import boto3
|
||||
import pytest
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
TEST_EXPERIMENT_NAME = "MyExperimentName"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sagemaker_client():
|
||||
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_experiment():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
def test_create_experiment(sagemaker_client):
|
||||
resp = sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_experiments()
|
||||
resp = sagemaker_client.list_experiments()
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 1
|
||||
assert resp["ExperimentSummaries"][0]["ExperimentName"] == experiment_name
|
||||
assert resp["ExperimentSummaries"][0]["ExperimentName"] == TEST_EXPERIMENT_NAME
|
||||
assert (
|
||||
resp["ExperimentSummaries"][0]["ExperimentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment/{experiment_name}"
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment/{TEST_EXPERIMENT_NAME}"
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_experiments():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
def test_list_experiments(sagemaker_client):
|
||||
|
||||
experiment_names = [f"some-experiment-name-{i}" for i in range(10)]
|
||||
|
||||
for experiment_name in experiment_names:
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
resp = sagemaker_client.create_experiment(ExperimentName=experiment_name)
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_experiments(MaxResults=1)
|
||||
resp = sagemaker_client.list_experiments(MaxResults=1)
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 1
|
||||
|
||||
next_token = resp["NextToken"]
|
||||
|
||||
resp = client.list_experiments(MaxResults=2, NextToken=next_token)
|
||||
resp = sagemaker_client.list_experiments(MaxResults=2, NextToken=next_token)
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 2
|
||||
|
||||
next_token = resp["NextToken"]
|
||||
|
||||
resp = client.list_experiments(NextToken=next_token)
|
||||
resp = sagemaker_client.list_experiments(NextToken=next_token)
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 7
|
||||
|
||||
@ -56,65 +58,53 @@ def test_list_experiments():
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_experiment():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
def test_delete_experiment(sagemaker_client):
|
||||
sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
resp = client.delete_experiment(ExperimentName=experiment_name)
|
||||
resp = sagemaker_client.delete_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_experiments()
|
||||
resp = sagemaker_client.list_experiments()
|
||||
|
||||
assert len(resp["ExperimentSummaries"]) == 0
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_add_tags_to_experiment():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
def test_add_tags_to_experiment(sagemaker_client):
|
||||
sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
resp = client.describe_experiment(ExperimentName=experiment_name)
|
||||
resp = sagemaker_client.describe_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
|
||||
|
||||
arn = resp["ExperimentArn"]
|
||||
|
||||
tags = [{"Key": "name", "Value": "value"}]
|
||||
|
||||
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
sagemaker_client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_tags(ResourceArn=arn)
|
||||
resp = sagemaker_client.list_tags(ResourceArn=arn)
|
||||
|
||||
assert resp["Tags"] == tags
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_tags_to_experiment():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
def test_delete_tags_to_experiment(sagemaker_client):
|
||||
sagemaker_client.create_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
resp = client.describe_experiment(ExperimentName=experiment_name)
|
||||
resp = sagemaker_client.describe_experiment(ExperimentName=TEST_EXPERIMENT_NAME)
|
||||
|
||||
arn = resp["ExperimentArn"]
|
||||
|
||||
tags = [{"Key": "name", "Value": "value"}]
|
||||
|
||||
client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
sagemaker_client.add_tags(ResourceArn=arn, Tags=tags)
|
||||
|
||||
client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
|
||||
sagemaker_client.delete_tags(ResourceArn=arn, TagKeys=[i["Key"] for i in tags])
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
resp = client.list_tags(ResourceArn=arn)
|
||||
resp = sagemaker_client.list_tags(ResourceArn=arn)
|
||||
|
||||
assert resp["Tags"] == []
|
||||
|
@ -7,118 +7,142 @@ import sure # noqa # pylint: disable=unused-import
|
||||
|
||||
from moto.sagemaker.models import VpcConfig
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
TEST_ARN = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
|
||||
TEST_MODEL_NAME = "MyModelName"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sagemaker_client():
|
||||
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
class MySageMakerModel(object):
|
||||
def __init__(self, name, arn, container=None, vpc_config=None):
|
||||
self.name = name
|
||||
self.arn = arn
|
||||
def __init__(self, name=None, arn=None, container=None, vpc_config=None):
|
||||
self.name = name or TEST_MODEL_NAME
|
||||
self.arn = arn or TEST_ARN
|
||||
self.container = container or {}
|
||||
self.vpc_config = vpc_config or {"sg-groups": ["sg-123"], "subnets": ["123"]}
|
||||
|
||||
def save(self):
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
def save(self, sagemaker_client):
|
||||
vpc_config = VpcConfig(
|
||||
self.vpc_config.get("sg-groups"), self.vpc_config.get("subnets")
|
||||
)
|
||||
client.create_model(
|
||||
resp = sagemaker_client.create_model(
|
||||
ModelName=self.name,
|
||||
ExecutionRoleArn=self.arn,
|
||||
VpcConfig=vpc_config.response_object,
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_describe_model():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
test_model = MySageMakerModel(
|
||||
name="blah",
|
||||
arn="arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar",
|
||||
vpc_config={"sg-groups": ["sg-123"], "subnets": ["123"]},
|
||||
)
|
||||
test_model.save()
|
||||
model = client.describe_model(ModelName="blah")
|
||||
assert model.get("ModelName").should.equal("blah")
|
||||
def test_describe_model(sagemaker_client):
|
||||
test_model = MySageMakerModel()
|
||||
test_model.save(sagemaker_client)
|
||||
model = sagemaker_client.describe_model(ModelName=TEST_MODEL_NAME)
|
||||
assert model.get("ModelName").should.equal(TEST_MODEL_NAME)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_describe_model_not_found():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
def test_describe_model_not_found(sagemaker_client):
|
||||
with pytest.raises(ClientError) as err:
|
||||
client.describe_model(ModelName="unknown")
|
||||
sagemaker_client.describe_model(ModelName="unknown")
|
||||
assert err.value.response["Error"]["Message"].should.contain("Could not find model")
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_model():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
def test_create_model(sagemaker_client):
|
||||
vpc_config = VpcConfig(["sg-foobar"], ["subnet-xxx"])
|
||||
exec_role_arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
|
||||
name = "blah"
|
||||
model = client.create_model(
|
||||
ModelName=name,
|
||||
ExecutionRoleArn=exec_role_arn,
|
||||
model = sagemaker_client.create_model(
|
||||
ModelName=TEST_MODEL_NAME,
|
||||
ExecutionRoleArn=TEST_ARN,
|
||||
VpcConfig=vpc_config.response_object,
|
||||
)
|
||||
|
||||
model["ModelArn"].should.match(r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name))
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_model():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name = "blah"
|
||||
arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
|
||||
test_model = MySageMakerModel(name=name, arn=arn)
|
||||
test_model.save()
|
||||
|
||||
assert len(client.list_models()["Models"]).should.equal(1)
|
||||
client.delete_model(ModelName=name)
|
||||
assert len(client.list_models()["Models"]).should.equal(0)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_model_not_found():
|
||||
with pytest.raises(ClientError) as err:
|
||||
boto3.client("sagemaker", region_name="us-east-1").delete_model(
|
||||
ModelName="blah"
|
||||
model["ModelArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:model/{}$".format(TEST_MODEL_NAME)
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_model(sagemaker_client):
|
||||
test_model = MySageMakerModel()
|
||||
test_model.save(sagemaker_client)
|
||||
|
||||
assert len(sagemaker_client.list_models()["Models"]).should.equal(1)
|
||||
sagemaker_client.delete_model(ModelName=TEST_MODEL_NAME)
|
||||
assert len(sagemaker_client.list_models()["Models"]).should.equal(0)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_model_not_found(sagemaker_client):
|
||||
with pytest.raises(ClientError) as err:
|
||||
sagemaker_client.delete_model(ModelName="blah")
|
||||
assert err.value.response["Error"]["Code"].should.equal("404")
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_models():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name = "blah"
|
||||
arn = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
|
||||
test_model = MySageMakerModel(name=name, arn=arn)
|
||||
test_model.save()
|
||||
models = client.list_models()
|
||||
def test_list_models(sagemaker_client):
|
||||
test_model = MySageMakerModel()
|
||||
test_model.save(sagemaker_client)
|
||||
models = sagemaker_client.list_models()
|
||||
assert len(models["Models"]).should.equal(1)
|
||||
assert models["Models"][0]["ModelName"].should.equal(name)
|
||||
assert models["Models"][0]["ModelName"].should.equal(TEST_MODEL_NAME)
|
||||
assert models["Models"][0]["ModelArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:model/{}$".format(name)
|
||||
r"^arn:aws:sagemaker:.*:.*:model/{}$".format(TEST_MODEL_NAME)
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_models_multiple():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
def test_list_models_multiple(sagemaker_client):
|
||||
name_model_1 = "blah"
|
||||
arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
|
||||
test_model_1 = MySageMakerModel(name=name_model_1, arn=arn_model_1)
|
||||
test_model_1.save()
|
||||
test_model_1.save(sagemaker_client)
|
||||
|
||||
name_model_2 = "blah2"
|
||||
arn_model_2 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar2"
|
||||
test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2)
|
||||
test_model_2.save()
|
||||
models = client.list_models()
|
||||
test_model_2.save(sagemaker_client)
|
||||
models = sagemaker_client.list_models()
|
||||
assert len(models["Models"]).should.equal(2)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_models_none():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
models = client.list_models()
|
||||
def test_list_models_none(sagemaker_client):
|
||||
models = sagemaker_client.list_models()
|
||||
assert len(models["Models"]).should.equal(0)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_add_tags_to_model(sagemaker_client):
|
||||
model = MySageMakerModel().save(sagemaker_client)
|
||||
resource_arn = model["ModelArn"]
|
||||
|
||||
tags = [
|
||||
{"Key": "myKey", "Value": "myValue"},
|
||||
]
|
||||
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == tags
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_tags_from_model(sagemaker_client):
|
||||
model = MySageMakerModel().save(sagemaker_client)
|
||||
resource_arn = model["ModelArn"]
|
||||
|
||||
tags = [
|
||||
{"Key": "myKey", "Value": "myValue"},
|
||||
]
|
||||
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
tag_keys = [tag["Key"] for tag in tags]
|
||||
response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == []
|
||||
|
@ -22,34 +22,42 @@ FAKE_ADDL_CODE_REPOS = [
|
||||
"https://github.com/user/repo2",
|
||||
"https://github.com/user/repo2",
|
||||
]
|
||||
FAKE_NAME_PARAM = "MyNotebookInstance"
|
||||
FAKE_INSTANCE_TYPE_PARAM = "ml.t2.medium"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sagemaker_client():
|
||||
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
def _get_notebook_instance_arn(notebook_name):
|
||||
return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance/{notebook_name}"
|
||||
|
||||
|
||||
def _get_notebook_instance_lifecycle_arn(lifecycle_name):
|
||||
return f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:notebook-instance-lifecycle-configuration/{lifecycle_name}"
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_notebook_instance_minimal_params():
|
||||
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
NAME_PARAM = "MyNotebookInstance"
|
||||
INSTANCE_TYPE_PARAM = "ml.t2.medium"
|
||||
|
||||
def test_create_notebook_instance_minimal_params(sagemaker_client):
|
||||
args = {
|
||||
"NotebookInstanceName": NAME_PARAM,
|
||||
"InstanceType": INSTANCE_TYPE_PARAM,
|
||||
"NotebookInstanceName": FAKE_NAME_PARAM,
|
||||
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
}
|
||||
resp = sagemaker.create_notebook_instance(**args)
|
||||
assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")
|
||||
assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])
|
||||
resp = sagemaker_client.create_notebook_instance(**args)
|
||||
expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM)
|
||||
assert resp["NotebookInstanceArn"] == expected_notebook_arn
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")
|
||||
assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])
|
||||
assert resp["NotebookInstanceName"] == NAME_PARAM
|
||||
assert resp["NotebookInstanceStatus"] == "InService"
|
||||
assert resp["Url"] == "{}.notebook.{}.sagemaker.aws".format(
|
||||
NAME_PARAM, TEST_REGION_NAME
|
||||
resp = sagemaker_client.describe_notebook_instance(
|
||||
NotebookInstanceName=FAKE_NAME_PARAM
|
||||
)
|
||||
assert resp["InstanceType"] == INSTANCE_TYPE_PARAM
|
||||
assert resp["NotebookInstanceArn"] == expected_notebook_arn
|
||||
assert resp["NotebookInstanceName"] == FAKE_NAME_PARAM
|
||||
assert resp["NotebookInstanceStatus"] == "InService"
|
||||
assert resp["Url"] == f"{FAKE_NAME_PARAM}.notebook.{TEST_REGION_NAME}.sagemaker.aws"
|
||||
assert resp["InstanceType"] == FAKE_INSTANCE_TYPE_PARAM
|
||||
assert resp["RoleArn"] == FAKE_ROLE_ARN
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
@ -61,69 +69,60 @@ def test_create_notebook_instance_minimal_params():
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_notebook_instance_params():
|
||||
|
||||
sagemaker = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
NAME_PARAM = "MyNotebookInstance"
|
||||
INSTANCE_TYPE_PARAM = "ml.t2.medium"
|
||||
DIRECT_INTERNET_ACCESS_PARAM = "Enabled"
|
||||
VOLUME_SIZE_IN_GB_PARAM = 7
|
||||
ACCELERATOR_TYPES_PARAM = ["ml.eia1.medium", "ml.eia2.medium"]
|
||||
ROOT_ACCESS_PARAM = "Disabled"
|
||||
def test_create_notebook_instance_params(sagemaker_client):
|
||||
fake_direct_internet_access_param = "Enabled"
|
||||
volume_size_in_gb_param = 7
|
||||
accelerator_types_param = ["ml.eia1.medium", "ml.eia2.medium"]
|
||||
root_access_param = "Disabled"
|
||||
|
||||
args = {
|
||||
"NotebookInstanceName": NAME_PARAM,
|
||||
"InstanceType": INSTANCE_TYPE_PARAM,
|
||||
"NotebookInstanceName": FAKE_NAME_PARAM,
|
||||
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
|
||||
"SubnetId": FAKE_SUBNET_ID,
|
||||
"SecurityGroupIds": FAKE_SECURITY_GROUP_IDS,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"KmsKeyId": FAKE_KMS_KEY_ID,
|
||||
"Tags": GENERIC_TAGS_PARAM,
|
||||
"LifecycleConfigName": FAKE_LIFECYCLE_CONFIG_NAME,
|
||||
"DirectInternetAccess": DIRECT_INTERNET_ACCESS_PARAM,
|
||||
"VolumeSizeInGB": VOLUME_SIZE_IN_GB_PARAM,
|
||||
"AcceleratorTypes": ACCELERATOR_TYPES_PARAM,
|
||||
"DirectInternetAccess": fake_direct_internet_access_param,
|
||||
"VolumeSizeInGB": volume_size_in_gb_param,
|
||||
"AcceleratorTypes": accelerator_types_param,
|
||||
"DefaultCodeRepository": FAKE_DEFAULT_CODE_REPO,
|
||||
"AdditionalCodeRepositories": FAKE_ADDL_CODE_REPOS,
|
||||
"RootAccess": ROOT_ACCESS_PARAM,
|
||||
"RootAccess": root_access_param,
|
||||
}
|
||||
resp = sagemaker.create_notebook_instance(**args)
|
||||
assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")
|
||||
assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])
|
||||
resp = sagemaker_client.create_notebook_instance(**args)
|
||||
expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM)
|
||||
assert resp["NotebookInstanceArn"] == expected_notebook_arn
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")
|
||||
assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])
|
||||
assert resp["NotebookInstanceName"] == NAME_PARAM
|
||||
assert resp["NotebookInstanceStatus"] == "InService"
|
||||
assert resp["Url"] == "{}.notebook.{}.sagemaker.aws".format(
|
||||
NAME_PARAM, TEST_REGION_NAME
|
||||
resp = sagemaker_client.describe_notebook_instance(
|
||||
NotebookInstanceName=FAKE_NAME_PARAM
|
||||
)
|
||||
assert resp["InstanceType"] == INSTANCE_TYPE_PARAM
|
||||
assert resp["NotebookInstanceArn"] == expected_notebook_arn
|
||||
assert resp["NotebookInstanceName"] == FAKE_NAME_PARAM
|
||||
assert resp["NotebookInstanceStatus"] == "InService"
|
||||
assert resp["Url"] == f"{FAKE_NAME_PARAM}.notebook.{TEST_REGION_NAME}.sagemaker.aws"
|
||||
assert resp["InstanceType"] == FAKE_INSTANCE_TYPE_PARAM
|
||||
assert resp["RoleArn"] == FAKE_ROLE_ARN
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert resp["DirectInternetAccess"] == "Enabled"
|
||||
assert resp["VolumeSizeInGB"] == VOLUME_SIZE_IN_GB_PARAM
|
||||
assert resp["VolumeSizeInGB"] == volume_size_in_gb_param
|
||||
# assert resp["RootAccess"] == True # ToDo: Not sure if this defaults...
|
||||
assert resp["SubnetId"] == FAKE_SUBNET_ID
|
||||
assert resp["SecurityGroups"] == FAKE_SECURITY_GROUP_IDS
|
||||
assert resp["KmsKeyId"] == FAKE_KMS_KEY_ID
|
||||
assert resp["NotebookInstanceLifecycleConfigName"] == FAKE_LIFECYCLE_CONFIG_NAME
|
||||
assert resp["AcceleratorTypes"] == ACCELERATOR_TYPES_PARAM
|
||||
assert resp["AcceleratorTypes"] == accelerator_types_param
|
||||
assert resp["DefaultCodeRepository"] == FAKE_DEFAULT_CODE_REPO
|
||||
assert resp["AdditionalCodeRepositories"] == FAKE_ADDL_CODE_REPOS
|
||||
|
||||
resp = sagemaker.list_tags(ResourceArn=resp["NotebookInstanceArn"])
|
||||
resp = sagemaker_client.list_tags(ResourceArn=resp["NotebookInstanceArn"])
|
||||
assert resp["Tags"] == GENERIC_TAGS_PARAM
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_notebook_instance_invalid_instance_type():
|
||||
|
||||
sagemaker = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
||||
def test_create_notebook_instance_invalid_instance_type(sagemaker_client):
|
||||
instance_type = "undefined_instance_type"
|
||||
args = {
|
||||
"NotebookInstanceName": "MyNotebookInstance",
|
||||
@ -131,7 +130,7 @@ def test_create_notebook_instance_invalid_instance_type():
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
}
|
||||
with pytest.raises(ClientError) as ex:
|
||||
sagemaker.create_notebook_instance(**args)
|
||||
sagemaker_client.create_notebook_instance(**args)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
|
||||
instance_type
|
||||
@ -141,78 +140,79 @@ def test_create_notebook_instance_invalid_instance_type():
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_notebook_instance_lifecycle():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
NAME_PARAM = "MyNotebookInstance"
|
||||
INSTANCE_TYPE_PARAM = "ml.t2.medium"
|
||||
|
||||
def test_notebook_instance_lifecycle(sagemaker_client):
|
||||
args = {
|
||||
"NotebookInstanceName": NAME_PARAM,
|
||||
"InstanceType": INSTANCE_TYPE_PARAM,
|
||||
"NotebookInstanceName": FAKE_NAME_PARAM,
|
||||
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
}
|
||||
resp = sagemaker.create_notebook_instance(**args)
|
||||
assert resp["NotebookInstanceArn"].startswith("arn:aws:sagemaker")
|
||||
assert resp["NotebookInstanceArn"].endswith(args["NotebookInstanceName"])
|
||||
resp = sagemaker_client.create_notebook_instance(**args)
|
||||
expected_notebook_arn = _get_notebook_instance_arn(FAKE_NAME_PARAM)
|
||||
assert resp["NotebookInstanceArn"] == expected_notebook_arn
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
resp = sagemaker_client.describe_notebook_instance(
|
||||
NotebookInstanceName=FAKE_NAME_PARAM
|
||||
)
|
||||
notebook_instance_arn = resp["NotebookInstanceArn"]
|
||||
|
||||
with pytest.raises(ClientError) as ex:
|
||||
sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
expected_message = "Status (InService) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({})".format(
|
||||
notebook_instance_arn
|
||||
)
|
||||
assert expected_message in ex.value.response["Error"]["Message"]
|
||||
|
||||
sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
resp = sagemaker_client.describe_notebook_instance(
|
||||
NotebookInstanceName=FAKE_NAME_PARAM
|
||||
)
|
||||
assert resp["NotebookInstanceStatus"] == "Stopped"
|
||||
|
||||
sagemaker.start_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
sagemaker_client.start_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
resp = sagemaker_client.describe_notebook_instance(
|
||||
NotebookInstanceName=FAKE_NAME_PARAM
|
||||
)
|
||||
assert resp["NotebookInstanceStatus"] == "InService"
|
||||
|
||||
sagemaker.stop_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
sagemaker_client.stop_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
|
||||
|
||||
resp = sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
resp = sagemaker_client.describe_notebook_instance(
|
||||
NotebookInstanceName=FAKE_NAME_PARAM
|
||||
)
|
||||
assert resp["NotebookInstanceStatus"] == "Stopped"
|
||||
|
||||
sagemaker.delete_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
sagemaker_client.delete_notebook_instance(NotebookInstanceName=FAKE_NAME_PARAM)
|
||||
|
||||
with pytest.raises(ClientError) as ex:
|
||||
sagemaker.describe_notebook_instance(NotebookInstanceName=NAME_PARAM)
|
||||
sagemaker_client.describe_notebook_instance(
|
||||
NotebookInstanceName=FAKE_NAME_PARAM
|
||||
)
|
||||
assert ex.value.response["Error"]["Message"] == "RecordNotFound"
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_describe_nonexistent_model():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
def test_describe_nonexistent_model(sagemaker_client):
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.describe_model(ModelName="Nonexistent")
|
||||
sagemaker_client.describe_model(ModelName="Nonexistent")
|
||||
assert e.value.response["Error"]["Message"].startswith("Could not find model")
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_notebook_instance_lifecycle_config():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
def test_notebook_instance_lifecycle_config(sagemaker_client):
|
||||
name = "MyLifeCycleConfig"
|
||||
on_create = [{"Content": "Create Script Line 1"}]
|
||||
on_start = [{"Content": "Start Script Line 1"}]
|
||||
resp = sagemaker.create_notebook_instance_lifecycle_config(
|
||||
resp = sagemaker_client.create_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name, OnCreate=on_create, OnStart=on_start
|
||||
)
|
||||
assert resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker")
|
||||
assert resp["NotebookInstanceLifecycleConfigArn"].endswith(name)
|
||||
expected_arn = _get_notebook_instance_lifecycle_arn(name)
|
||||
assert resp["NotebookInstanceLifecycleConfigArn"] == expected_arn
|
||||
|
||||
with pytest.raises(ClientError) as e:
|
||||
resp = sagemaker.create_notebook_instance_lifecycle_config(
|
||||
sagemaker_client.create_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name,
|
||||
OnCreate=on_create,
|
||||
OnStart=on_start,
|
||||
@ -221,23 +221,22 @@ def test_notebook_instance_lifecycle_config():
|
||||
"Notebook Instance Lifecycle Config already exists.)"
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_notebook_instance_lifecycle_config(
|
||||
resp = sagemaker_client.describe_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name
|
||||
)
|
||||
assert resp["NotebookInstanceLifecycleConfigName"] == name
|
||||
assert resp["NotebookInstanceLifecycleConfigArn"].startswith("arn:aws:sagemaker")
|
||||
assert resp["NotebookInstanceLifecycleConfigArn"].endswith(name)
|
||||
assert resp["NotebookInstanceLifecycleConfigArn"] == expected_arn
|
||||
assert resp["OnStart"] == on_start
|
||||
assert resp["OnCreate"] == on_create
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
|
||||
sagemaker.delete_notebook_instance_lifecycle_config(
|
||||
sagemaker_client.delete_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name
|
||||
)
|
||||
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.describe_notebook_instance_lifecycle_config(
|
||||
sagemaker_client.describe_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name
|
||||
)
|
||||
assert e.value.response["Error"]["Message"].endswith(
|
||||
@ -245,9 +244,53 @@ def test_notebook_instance_lifecycle_config():
|
||||
)
|
||||
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker.delete_notebook_instance_lifecycle_config(
|
||||
sagemaker_client.delete_notebook_instance_lifecycle_config(
|
||||
NotebookInstanceLifecycleConfigName=name
|
||||
)
|
||||
assert e.value.response["Error"]["Message"].endswith(
|
||||
"Notebook Instance Lifecycle Config does not exist.)"
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_add_tags_to_notebook(sagemaker_client):
|
||||
args = {
|
||||
"NotebookInstanceName": FAKE_NAME_PARAM,
|
||||
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
}
|
||||
resp = sagemaker_client.create_notebook_instance(**args)
|
||||
resource_arn = resp["NotebookInstanceArn"]
|
||||
|
||||
tags = [
|
||||
{"Key": "myKey", "Value": "myValue"},
|
||||
]
|
||||
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == tags
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_tags_from_notebook(sagemaker_client):
|
||||
args = {
|
||||
"NotebookInstanceName": FAKE_NAME_PARAM,
|
||||
"InstanceType": FAKE_INSTANCE_TYPE_PARAM,
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
}
|
||||
resp = sagemaker_client.create_notebook_instance(**args)
|
||||
resource_arn = resp["NotebookInstanceArn"]
|
||||
|
||||
tags = [
|
||||
{"Key": "myKey", "Value": "myValue"},
|
||||
]
|
||||
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
tag_keys = [tag["Key"] for tag in tags]
|
||||
response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == []
|
||||
|
@ -6,10 +6,17 @@ import pytest
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
|
||||
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
|
||||
FAKE_PROCESSING_JOB_NAME = "MyProcessingJob"
|
||||
FAKE_CONTAINER = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sagemaker_client():
|
||||
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
class MyProcessingJobModel(object):
|
||||
def __init__(
|
||||
self,
|
||||
@ -81,9 +88,7 @@ class MyProcessingJobModel(object):
|
||||
"MaxRuntimeInSeconds": 3600,
|
||||
}
|
||||
|
||||
def save(self):
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
def save(self, sagemaker_client):
|
||||
params = {
|
||||
"AppSpecification": self.app_specification,
|
||||
"NetworkConfig": self.network_config,
|
||||
@ -95,13 +100,265 @@ class MyProcessingJobModel(object):
|
||||
"StoppingCondition": self.stopping_condition,
|
||||
}
|
||||
|
||||
return sagemaker.create_processing_job(**params)
|
||||
return sagemaker_client.create_processing_job(**params)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_processing_job():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
def test_create_processing_job(sagemaker_client):
|
||||
bucket = "my-bucket"
|
||||
prefix = "my-prefix"
|
||||
app_specification = {
|
||||
"ImageUri": FAKE_CONTAINER,
|
||||
"ContainerEntrypoint": ["python3", "app.py"],
|
||||
}
|
||||
processing_resources = {
|
||||
"ClusterConfig": {
|
||||
"InstanceCount": 2,
|
||||
"InstanceType": "ml.m5.xlarge",
|
||||
"VolumeSizeInGB": 20,
|
||||
},
|
||||
}
|
||||
stopping_condition = {"MaxRuntimeInSeconds": 60 * 60}
|
||||
|
||||
job = MyProcessingJobModel(
|
||||
processing_job_name=FAKE_PROCESSING_JOB_NAME,
|
||||
role_arn=FAKE_ROLE_ARN,
|
||||
container=FAKE_CONTAINER,
|
||||
bucket=bucket,
|
||||
prefix=prefix,
|
||||
app_specification=app_specification,
|
||||
processing_resources=processing_resources,
|
||||
stopping_condition=stopping_condition,
|
||||
)
|
||||
resp = job.save(sagemaker_client)
|
||||
resp["ProcessingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(FAKE_PROCESSING_JOB_NAME)
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_processing_job(
|
||||
ProcessingJobName=FAKE_PROCESSING_JOB_NAME
|
||||
)
|
||||
resp["ProcessingJobName"].should.equal(FAKE_PROCESSING_JOB_NAME)
|
||||
resp["ProcessingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(FAKE_PROCESSING_JOB_NAME)
|
||||
)
|
||||
assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"]
|
||||
assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"]
|
||||
assert resp["RoleArn"] == FAKE_ROLE_ARN
|
||||
assert resp["ProcessingJobStatus"] == "Completed"
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs(sagemaker_client):
|
||||
test_processing_job = MyProcessingJobModel(
|
||||
processing_job_name=FAKE_PROCESSING_JOB_NAME, role_arn=FAKE_ROLE_ARN
|
||||
)
|
||||
test_processing_job.save(sagemaker_client)
|
||||
processing_jobs = sagemaker_client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert processing_jobs["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal(FAKE_PROCESSING_JOB_NAME)
|
||||
|
||||
assert processing_jobs["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobArn"
|
||||
].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(FAKE_PROCESSING_JOB_NAME)
|
||||
)
|
||||
assert processing_jobs.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_multiple(sagemaker_client):
|
||||
name_job_1 = "blah"
|
||||
arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
|
||||
test_processing_job_1 = MyProcessingJobModel(
|
||||
processing_job_name=name_job_1, role_arn=arn_job_1
|
||||
)
|
||||
test_processing_job_1.save(sagemaker_client)
|
||||
|
||||
name_job_2 = "blah2"
|
||||
arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2"
|
||||
test_processing_job_2 = MyProcessingJobModel(
|
||||
processing_job_name=name_job_2, role_arn=arn_job_2
|
||||
)
|
||||
test_processing_job_2.save(sagemaker_client)
|
||||
processing_jobs_limit = sagemaker_client.list_processing_jobs(MaxResults=1)
|
||||
assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1)
|
||||
|
||||
processing_jobs = sagemaker_client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(2)
|
||||
assert processing_jobs.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_none(sagemaker_client):
|
||||
processing_jobs = sagemaker_client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_should_validate_input(sagemaker_client):
|
||||
junk_status_equals = "blah"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
sagemaker_client.list_processing_jobs(StatusEquals=junk_status_equals)
|
||||
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert ex.value.response["Error"]["Message"] == expected_error
|
||||
|
||||
junk_next_token = "asdf"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
sagemaker_client.list_processing_jobs(NextToken=junk_next_token)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert (
|
||||
ex.value.response["Error"]["Message"]
|
||||
== 'Invalid pagination token because "{0}".'
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_with_name_filters(sagemaker_client):
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save(
|
||||
sagemaker_client
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save(
|
||||
sagemaker_client
|
||||
)
|
||||
|
||||
xgboost_processing_jobs = sagemaker_client.list_processing_jobs(
|
||||
NameContains="xgboost"
|
||||
)
|
||||
assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5)
|
||||
|
||||
processing_jobs_with_2 = sagemaker_client.list_processing_jobs(NameContains="2")
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_paginated(sagemaker_client):
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save(
|
||||
sagemaker_client
|
||||
)
|
||||
|
||||
xgboost_processing_job_1 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="xgboost", MaxResults=1
|
||||
)
|
||||
assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_processing_job_1["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("xgboost-0")
|
||||
assert xgboost_processing_job_1.get("NextToken").should_not.be.none
|
||||
|
||||
xgboost_processing_job_next = sagemaker_client.list_processing_jobs(
|
||||
NameContains="xgboost",
|
||||
MaxResults=1,
|
||||
NextToken=xgboost_processing_job_1.get("NextToken"),
|
||||
)
|
||||
assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_processing_job_next["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("xgboost-1")
|
||||
assert xgboost_processing_job_next.get("NextToken").should_not.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save(
|
||||
sagemaker_client
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save(
|
||||
sagemaker_client
|
||||
)
|
||||
|
||||
vgg_processing_job_1 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="vgg", MaxResults=1
|
||||
)
|
||||
assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0)
|
||||
assert vgg_processing_job_1.get("NextToken").should_not.be.none
|
||||
|
||||
vgg_processing_job_6 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="vgg", MaxResults=6
|
||||
)
|
||||
|
||||
assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert vgg_processing_job_6["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("vgg-0")
|
||||
assert vgg_processing_job_6.get("NextToken").should_not.be.none
|
||||
|
||||
vgg_processing_job_10 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="vgg", MaxResults=10
|
||||
)
|
||||
|
||||
assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5)
|
||||
assert vgg_processing_job_10["ProcessingJobSummaries"][-1][
|
||||
"ProcessingJobName"
|
||||
].should.equal("vgg-4")
|
||||
assert vgg_processing_job_10.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client):
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save(
|
||||
sagemaker_client
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save(
|
||||
sagemaker_client
|
||||
)
|
||||
|
||||
processing_jobs_with_2 = sagemaker_client.list_processing_jobs(
|
||||
NameContains="2", MaxResults=8
|
||||
)
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
|
||||
assert processing_jobs_with_2.get("NextToken").should_not.be.none
|
||||
|
||||
processing_jobs_with_2_next = sagemaker_client.list_processing_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=processing_jobs_with_2.get("NextToken"),
|
||||
)
|
||||
assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0)
|
||||
assert processing_jobs_with_2_next.get("NextToken").should_not.be.none
|
||||
|
||||
processing_jobs_with_2_next_next = sagemaker_client.list_processing_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=processing_jobs_with_2_next.get("NextToken"),
|
||||
)
|
||||
assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal(
|
||||
0
|
||||
)
|
||||
assert processing_jobs_with_2_next_next.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_add_and_delete_tags_in_training_job(sagemaker_client):
|
||||
processing_job_name = "MyProcessingJob"
|
||||
role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
|
||||
@ -130,205 +387,21 @@ def test_create_processing_job():
|
||||
processing_resources=processing_resources,
|
||||
stopping_condition=stopping_condition,
|
||||
)
|
||||
resp = job.save()
|
||||
resp["ProcessingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name)
|
||||
)
|
||||
resp = job.save(sagemaker_client)
|
||||
resource_arn = resp["ProcessingJobArn"]
|
||||
|
||||
resp = sagemaker.describe_processing_job(ProcessingJobName=processing_job_name)
|
||||
resp["ProcessingJobName"].should.equal(processing_job_name)
|
||||
resp["ProcessingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(processing_job_name)
|
||||
)
|
||||
assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"]
|
||||
assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"]
|
||||
assert resp["RoleArn"] == role_arn
|
||||
assert resp["ProcessingJobStatus"] == "Completed"
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
tags = [
|
||||
{"Key": "myKey", "Value": "myValue"},
|
||||
]
|
||||
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == tags
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name = "blah"
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
|
||||
test_processing_job = MyProcessingJobModel(processing_job_name=name, role_arn=arn)
|
||||
test_processing_job.save()
|
||||
processing_jobs = client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert processing_jobs["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal(name)
|
||||
tag_keys = [tag["Key"] for tag in tags]
|
||||
response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
assert processing_jobs["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobArn"
|
||||
].should.match(r"^arn:aws:sagemaker:.*:.*:processing-job/{}$".format(name))
|
||||
assert processing_jobs.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_multiple():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name_job_1 = "blah"
|
||||
arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
|
||||
test_processing_job_1 = MyProcessingJobModel(
|
||||
processing_job_name=name_job_1, role_arn=arn_job_1
|
||||
)
|
||||
test_processing_job_1.save()
|
||||
|
||||
name_job_2 = "blah2"
|
||||
arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2"
|
||||
test_processing_job_2 = MyProcessingJobModel(
|
||||
processing_job_name=name_job_2, role_arn=arn_job_2
|
||||
)
|
||||
test_processing_job_2.save()
|
||||
processing_jobs_limit = client.list_processing_jobs(MaxResults=1)
|
||||
assert len(processing_jobs_limit["ProcessingJobSummaries"]).should.equal(1)
|
||||
|
||||
processing_jobs = client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(2)
|
||||
assert processing_jobs.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_none():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
processing_jobs = client.list_processing_jobs()
|
||||
assert len(processing_jobs["ProcessingJobSummaries"]).should.equal(0)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_should_validate_input():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
junk_status_equals = "blah"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.list_processing_jobs(StatusEquals=junk_status_equals)
|
||||
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert ex.value.response["Error"]["Message"] == expected_error
|
||||
|
||||
junk_next_token = "asdf"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.list_processing_jobs(NextToken=junk_next_token)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert (
|
||||
ex.value.response["Error"]["Message"]
|
||||
== 'Invalid pagination token because "{0}".'
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_with_name_filters():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
xgboost_processing_jobs = client.list_processing_jobs(NameContains="xgboost")
|
||||
assert len(xgboost_processing_jobs["ProcessingJobSummaries"]).should.equal(5)
|
||||
|
||||
processing_jobs_with_2 = client.list_processing_jobs(NameContains="2")
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_paginated():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
xgboost_processing_job_1 = client.list_processing_jobs(
|
||||
NameContains="xgboost", MaxResults=1
|
||||
)
|
||||
assert len(xgboost_processing_job_1["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_processing_job_1["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("xgboost-0")
|
||||
assert xgboost_processing_job_1.get("NextToken").should_not.be.none
|
||||
|
||||
xgboost_processing_job_next = client.list_processing_jobs(
|
||||
NameContains="xgboost",
|
||||
MaxResults=1,
|
||||
NextToken=xgboost_processing_job_1.get("NextToken"),
|
||||
)
|
||||
assert len(xgboost_processing_job_next["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_processing_job_next["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("xgboost-1")
|
||||
assert xgboost_processing_job_next.get("NextToken").should_not.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_paginated_with_target_in_middle():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
|
||||
vgg_processing_job_1 = client.list_processing_jobs(NameContains="vgg", MaxResults=1)
|
||||
assert len(vgg_processing_job_1["ProcessingJobSummaries"]).should.equal(0)
|
||||
assert vgg_processing_job_1.get("NextToken").should_not.be.none
|
||||
|
||||
vgg_processing_job_6 = client.list_processing_jobs(NameContains="vgg", MaxResults=6)
|
||||
|
||||
assert len(vgg_processing_job_6["ProcessingJobSummaries"]).should.equal(1)
|
||||
assert vgg_processing_job_6["ProcessingJobSummaries"][0][
|
||||
"ProcessingJobName"
|
||||
].should.equal("vgg-0")
|
||||
assert vgg_processing_job_6.get("NextToken").should_not.be.none
|
||||
|
||||
vgg_processing_job_10 = client.list_processing_jobs(
|
||||
NameContains="vgg", MaxResults=10
|
||||
)
|
||||
|
||||
assert len(vgg_processing_job_10["ProcessingJobSummaries"]).should.equal(5)
|
||||
assert vgg_processing_job_10["ProcessingJobSummaries"][-1][
|
||||
"ProcessingJobName"
|
||||
].should.equal("vgg-4")
|
||||
assert vgg_processing_job_10.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_processing_jobs_paginated_with_fragmented_targets():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyProcessingJobModel(processing_job_name=name, role_arn=arn).save()
|
||||
|
||||
processing_jobs_with_2 = client.list_processing_jobs(NameContains="2", MaxResults=8)
|
||||
assert len(processing_jobs_with_2["ProcessingJobSummaries"]).should.equal(2)
|
||||
assert processing_jobs_with_2.get("NextToken").should_not.be.none
|
||||
|
||||
processing_jobs_with_2_next = client.list_processing_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=processing_jobs_with_2.get("NextToken"),
|
||||
)
|
||||
assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]).should.equal(0)
|
||||
assert processing_jobs_with_2_next.get("NextToken").should_not.be.none
|
||||
|
||||
processing_jobs_with_2_next_next = client.list_processing_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=processing_jobs_with_2_next.get("NextToken"),
|
||||
)
|
||||
assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]).should.equal(
|
||||
0
|
||||
)
|
||||
assert processing_jobs_with_2_next_next.get("NextToken").should.be.none
|
||||
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == []
|
||||
|
@ -8,39 +8,35 @@ from moto import mock_sagemaker
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sagemaker_client():
|
||||
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_search():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
another_trial_component_name = "another-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
resp = client.create_trial_component(
|
||||
TrialComponentName=another_trial_component_name
|
||||
def test_search(sagemaker_client):
|
||||
experiment_name = "experiment_name"
|
||||
trial_component_name = "trial_component_name"
|
||||
trial_name = "trial_name"
|
||||
_set_up_trial_component(
|
||||
sagemaker_client,
|
||||
experiment_name=experiment_name,
|
||||
trial_component_name=trial_component_name,
|
||||
trial_name=trial_name,
|
||||
)
|
||||
|
||||
resp = client.search(Resource="ExperimentTrialComponent")
|
||||
|
||||
resp = sagemaker_client.search(Resource="ExperimentTrialComponent")
|
||||
assert len(resp["Results"]) == 2
|
||||
|
||||
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
resp = sagemaker_client.describe_trial_component(
|
||||
TrialComponentName=trial_component_name
|
||||
)
|
||||
trial_component_arn = resp["TrialComponentArn"]
|
||||
|
||||
tags = [{"Key": "key-name", "Value": "some-value"}]
|
||||
sagemaker_client.add_tags(ResourceArn=trial_component_arn, Tags=tags)
|
||||
|
||||
client.add_tags(ResourceArn=trial_component_arn, Tags=tags)
|
||||
|
||||
resp = client.search(
|
||||
resp = sagemaker_client.search(
|
||||
Resource="ExperimentTrialComponent",
|
||||
SearchExpression={
|
||||
"Filters": [
|
||||
@ -55,49 +51,38 @@ def test_search():
|
||||
== trial_component_name
|
||||
)
|
||||
|
||||
resp = client.search(Resource="Experiment")
|
||||
resp = sagemaker_client.search(Resource="Experiment")
|
||||
assert len(resp["Results"]) == 1
|
||||
assert resp["Results"][0]["Experiment"]["ExperimentName"] == experiment_name
|
||||
|
||||
resp = client.search(Resource="ExperimentTrial")
|
||||
resp = sagemaker_client.search(Resource="ExperimentTrial")
|
||||
assert len(resp["Results"]) == 1
|
||||
assert resp["Results"][0]["Trial"]["TrialName"] == trial_name
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_search_trial_component_with_experiment_name():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
another_trial_component_name = "another-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
resp = client.create_trial_component(
|
||||
TrialComponentName=another_trial_component_name
|
||||
def test_search_trial_component_with_experiment_name(sagemaker_client):
|
||||
experiment_name = "experiment_name"
|
||||
trial_component_name = "trial_component_name"
|
||||
_set_up_trial_component(
|
||||
sagemaker_client,
|
||||
experiment_name=experiment_name,
|
||||
trial_component_name=trial_component_name,
|
||||
)
|
||||
|
||||
resp = client.search(Resource="ExperimentTrialComponent")
|
||||
|
||||
resp = sagemaker_client.search(Resource="ExperimentTrialComponent")
|
||||
assert len(resp["Results"]) == 2
|
||||
|
||||
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
resp = sagemaker_client.describe_trial_component(
|
||||
TrialComponentName=trial_component_name
|
||||
)
|
||||
trial_component_arn = resp["TrialComponentArn"]
|
||||
|
||||
tags = [{"Key": "key-name", "Value": "some-value"}]
|
||||
|
||||
client.add_tags(ResourceArn=trial_component_arn, Tags=tags)
|
||||
sagemaker_client.add_tags(ResourceArn=trial_component_arn, Tags=tags)
|
||||
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.search(
|
||||
sagemaker_client.search(
|
||||
Resource="ExperimentTrialComponent",
|
||||
SearchExpression={
|
||||
"Filters": [
|
||||
@ -115,3 +100,18 @@ def test_search_trial_component_with_experiment_name():
|
||||
"Unknown property name: ExperimentName"
|
||||
)
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
|
||||
|
||||
|
||||
def _set_up_trial_component(
|
||||
sagemaker_client,
|
||||
experiment_name="some-experiment-name",
|
||||
trial_component_name="some-trial-component-name",
|
||||
trial_name="some-trial-name",
|
||||
another_trial_component_name="another-trial-component-name",
|
||||
):
|
||||
sagemaker_client.create_experiment(ExperimentName=experiment_name)
|
||||
sagemaker_client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
sagemaker_client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
sagemaker_client.create_trial_component(
|
||||
TrialComponentName=another_trial_component_name
|
||||
)
|
||||
|
@ -397,3 +397,47 @@ def test_list_training_jobs_paginated_with_fragmented_targets():
|
||||
)
|
||||
assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]).should.equal(0)
|
||||
assert training_jobs_with_2_next_next.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_add_tags_to_training_job():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
name = "blah"
|
||||
resource_arn = f"arn:aws:sagemaker:us-east-1:000000000000:training-job/{name}"
|
||||
test_training_job = MyTrainingJobModel(
|
||||
training_job_name=name, role_arn=resource_arn
|
||||
)
|
||||
test_training_job.save()
|
||||
|
||||
tags = [
|
||||
{"Key": "myKey", "Value": "myValue"},
|
||||
]
|
||||
response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == tags
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_tags_from_training_job():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
name = "blah"
|
||||
resource_arn = f"arn:aws:sagemaker:us-east-1:000000000000:training-job/{name}"
|
||||
test_training_job = MyTrainingJobModel(
|
||||
training_job_name=name, role_arn=resource_arn
|
||||
)
|
||||
test_training_job.save()
|
||||
|
||||
tags = [
|
||||
{"Key": "myKey", "Value": "myValue"},
|
||||
]
|
||||
response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
tag_keys = [tag["Key"] for tag in tags]
|
||||
response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = client.list_tags(ResourceArn=resource_arn)
|
||||
assert response["Tags"] == []
|
||||
|
@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
|
||||
import boto3
|
||||
|
||||
from moto import mock_sagemaker
|
||||
@ -158,3 +160,33 @@ def test_delete_tags_to_trial():
|
||||
resp = client.list_tags(ResourceArn=arn)
|
||||
|
||||
assert resp["Tags"] == []
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_trial_tags():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
resp = client.describe_trial(TrialName=trial_name)
|
||||
resource_arn = resp["TrialArn"]
|
||||
|
||||
tags = []
|
||||
for _ in range(80):
|
||||
tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"})
|
||||
|
||||
response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = client.list_tags(ResourceArn=resource_arn)
|
||||
assert len(response["Tags"]) == 50
|
||||
assert response["Tags"] == tags[:50]
|
||||
|
||||
response = client.list_tags(
|
||||
ResourceArn=resource_arn, NextToken=response["NextToken"]
|
||||
)
|
||||
assert len(response["Tags"]) == 30
|
||||
assert response["Tags"] == tags[50:]
|
||||
|
@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
|
||||
@ -123,6 +125,33 @@ def test_delete_tags_to_trial_component():
|
||||
assert resp["Tags"] == []
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_trial_component_tags():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||
resource_arn = resp["TrialComponentArn"]
|
||||
|
||||
tags = []
|
||||
for _ in range(80):
|
||||
tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"})
|
||||
|
||||
response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
|
||||
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
response = client.list_tags(ResourceArn=resource_arn)
|
||||
assert len(response["Tags"]) == 50
|
||||
assert response["Tags"] == tags[:50]
|
||||
|
||||
response = client.list_tags(
|
||||
ResourceArn=resource_arn, NextToken=response["NextToken"]
|
||||
)
|
||||
assert len(response["Tags"]) == 30
|
||||
assert response["Tags"] == tags[50:]
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_associate_trial_component():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
Loading…
x
Reference in New Issue
Block a user