Improve sagemaker (#4517)

This commit is contained in:
Bogdan Girman 2021-11-01 23:30:07 +01:00 committed by GitHub
parent 6264fb292c
commit 7764a94491
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 236 additions and 13 deletions

View File

@ -35,3 +35,8 @@ class ValidationError(JsonRESTError):
class AWSValidationException(AWSError):
TYPE = "ValidationException"
class ResourceNotFound(JsonRESTError):
def __init__(self, message, **kwargs):
super(ResourceNotFound, self).__init__(__class__.__name__, message, **kwargs)

View File

@ -5,7 +5,12 @@ from datetime import datetime
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
from moto.core.exceptions import RESTError
from moto.sagemaker import validators
from .exceptions import MissingModel, ValidationError, AWSValidationException
from .exceptions import (
MissingModel,
ValidationError,
AWSValidationException,
ResourceNotFound,
)
class BaseObject(BaseModel):
@ -1102,8 +1107,13 @@ class SageMakerModelBackend(BaseBackend):
if f["Name"] == "ExperimentName":
experiment_name = f["Value"]
if getattr(item, "experiment_name") != experiment_name:
return False
if hasattr(item, "experiment_name"):
if getattr(item, "experiment_name") != experiment_name:
return False
else:
raise ValidationError(
message="Unknown property name: ExperimentName"
)
if f["Name"] == "TrialName":
raise AWSValidationException(
@ -1209,6 +1219,7 @@ class SageMakerModelBackend(BaseBackend):
trial_name=trial_name,
experiment_name=experiment_name,
tags=[],
trial_components=[],
)
self.trials[trial_name] = trial
return trial.response_create
@ -1245,11 +1256,22 @@ class SageMakerModelBackend(BaseBackend):
except RESTError:
return []
def list_trials(self, experiment_name=None):
def list_trials(self, experiment_name=None, trial_component_name=None):
next_index = None
trials_fetched = list(self.trials.values())
def evaluate_filter_expression(trial_data):
if experiment_name is not None:
if trial_data.experiment_name != experiment_name:
return False
if trial_component_name is not None:
if trial_component_name not in trial_data.trial_components:
return False
return True
trial_summaries = [
{
"TrialName": trial_data.trial_name,
@ -1258,7 +1280,7 @@ class SageMakerModelBackend(BaseBackend):
"LastModifiedTime": trial_data.last_modified_time,
}
for trial_data in trials_fetched
if experiment_name is None or trial_data.experiment_name == experiment_name
if evaluate_filter_expression(trial_data)
]
return {
@ -1342,12 +1364,42 @@ class SageMakerModelBackend(BaseBackend):
trial_name = params["TrialName"]
trial_component_name = params["TrialComponentName"]
self.trial_components[trial_component_name].trial_name = trial_name
if trial_name in self.trials.keys():
self.trials[trial_name].trial_components.extend([trial_component_name])
else:
raise ResourceNotFound(
message=f"Trial 'arn:aws:sagemaker:{self.region_name}:{ACCOUNT_ID}:experiment-trial/{trial_name}' does not exist."
)
if trial_component_name in self.trial_components.keys():
self.trial_components[trial_component_name].trial_name = trial_name
return {
"TrialComponentArn": self.trial_components[
trial_component_name
].trial_component_arn,
"TrialArn": self.trials[trial_name].trial_arn,
}
def disassociate_trial_component(self, params):
trial_component_name = params["TrialComponentName"]
trial_name = params["TrialName"]
self.trial_components[trial_component_name].trial_name = None
if trial_component_name in self.trial_components.keys():
self.trial_components[trial_component_name].trial_name = None
if trial_name in self.trials.keys():
self.trials[trial_name].trial_components = list(
filter(
lambda x: x != trial_component_name,
self.trials[trial_name].trial_components,
)
)
return {
"TrialComponentArn": f"arn:aws:sagemaker:{self.region_name}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}",
"TrialArn": f"arn:aws:sagemaker:{self.region_name}:{ACCOUNT_ID}:experiment-trial/{trial_name}",
}
def create_notebook_instance(
self,
@ -1804,11 +1856,12 @@ class FakeExperiment(BaseObject):
class FakeTrial(BaseObject):
def __init__(
self, region_name, trial_name, experiment_name, tags,
self, region_name, trial_name, experiment_name, tags, trial_components,
):
self.trial_name = trial_name
self.trial_arn = FakeTrial.arn_formatter(trial_name, region_name)
self.tags = tags
self.trial_components = trial_components
self.experiment_name = experiment_name
self.creation_time = self.last_modified_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"

View File

@ -346,6 +346,7 @@ class SageMakerResponse(BaseResponse):
def list_trials(self):
response = self.sagemaker_backend.list_trials(
experiment_name=self._get_param("ExperimentName"),
trial_component_name=self._get_param("TrialComponentName"),
)
return 200, {}, json.dumps(response)
@ -404,14 +405,14 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def associate_trial_component(self, *args, **kwargs):
self.sagemaker_backend.associate_trial_component(self.request_params)
response = {}
response = self.sagemaker_backend.associate_trial_component(self.request_params)
return 200, {}, json.dumps(response)
@amzn_request_id
def disassociate_trial_component(self, *args, **kwargs):
self.sagemaker_backend.disassociate_trial_component(self.request_params)
response = {}
response = self.sagemaker_backend.disassociate_trial_component(
self.request_params
)
return 200, {}, json.dumps(response)
@amzn_request_id

View File

@ -1,4 +1,7 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_sagemaker
@ -59,3 +62,56 @@ def test_search():
resp = client.search(Resource="ExperimentTrial")
assert len(resp["Results"]) == 1
assert resp["Results"][0]["Trial"]["TrialName"] == trial_name
@mock_sagemaker
def test_search_trial_component_with_experiment_name():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_name = "some-trial-name"
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
trial_component_name = "some-trial-component-name"
another_trial_component_name = "another-trial-component-name"
resp = client.create_trial_component(TrialComponentName=trial_component_name)
resp = client.create_trial_component(
TrialComponentName=another_trial_component_name
)
resp = client.search(Resource="ExperimentTrialComponent")
assert len(resp["Results"]) == 2
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
trial_component_arn = resp["TrialComponentArn"]
tags = [{"Key": "key-name", "Value": "some-value"}]
client.add_tags(ResourceArn=trial_component_arn, Tags=tags)
with pytest.raises(ClientError) as ex:
client.search(
Resource="ExperimentTrialComponent",
SearchExpression={
"Filters": [
{
"Name": "ExperimentName",
"Operator": "Equals",
"Value": experiment_name,
}
]
},
)
ex.value.response["Error"]["Code"].should.equal("ValidationException")
ex.value.response["Error"]["Message"].should.equal(
"Unknown property name: ExperimentName"
)
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)

View File

@ -30,6 +30,28 @@ def test_create_trial():
)
@mock_sagemaker
def test_list_trials_by_trial_component_name():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_name = "some-trial-name"
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
trial_component_name = "some-trial-component-name"
resp = client.create_trial_component(TrialComponentName=trial_component_name)
resp = client.list_trials(TrialComponentName=trial_component_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert len(resp["TrialSummaries"]) == 0
@mock_sagemaker
def test_delete_trial():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

View File

@ -1,4 +1,7 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
@ -108,11 +111,19 @@ def test_associate_trial_component():
resp = client.create_trial_component(TrialComponentName=trial_component_name)
client.associate_trial_component(
resp = client.associate_trial_component(
TrialComponentName=trial_component_name, TrialName=trial_name
)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert (
resp["TrialComponentArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
)
assert (
resp["TrialArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}"
)
resp = client.list_trial_components(TrialName=trial_name)
@ -120,3 +131,78 @@ def test_associate_trial_component():
assert (
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
)
resp = client.list_trials(TrialComponentName=trial_component_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert resp["TrialSummaries"][0]["TrialName"] == trial_name
with pytest.raises(ClientError) as ex:
resp = client.associate_trial_component(
TrialComponentName="does-not-exist", TrialName="does-not-exist"
)
ex.value.response["Error"]["Code"].should.equal("ResourceNotFound")
ex.value.response["Error"]["Message"].should.equal(
f"Trial 'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/does-not-exist' does not exist."
)
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
@mock_sagemaker
def test_disassociate_trial_component():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
experiment_name = "some-experiment-name"
resp = client.create_experiment(ExperimentName=experiment_name)
trial_name = "some-trial-name"
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
trial_component_name = "some-trial-component-name"
resp = client.create_trial_component(TrialComponentName=trial_component_name)
client.associate_trial_component(
TrialComponentName=trial_component_name, TrialName=trial_name
)
resp = client.disassociate_trial_component(
TrialComponentName=trial_component_name, TrialName=trial_name
)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert (
resp["TrialComponentArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
)
assert (
resp["TrialArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}"
)
resp = client.list_trial_components(TrialName=trial_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert len(resp["TrialComponentSummaries"]) == 0
resp = client.list_trials(TrialComponentName=trial_component_name)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert len(resp["TrialSummaries"]) == 0
resp = client.disassociate_trial_component(
TrialComponentName="does-not-exist", TrialName="does-not-exist"
)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
assert (
resp["TrialComponentArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/does-not-exist"
)
assert (
resp["TrialArn"]
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/does-not-exist"
)