Improve sagemaker (#4517)
This commit is contained in:
parent
6264fb292c
commit
7764a94491
@ -35,3 +35,8 @@ class ValidationError(JsonRESTError):
|
||||
|
||||
class AWSValidationException(AWSError):
|
||||
TYPE = "ValidationException"
|
||||
|
||||
|
||||
class ResourceNotFound(JsonRESTError):
|
||||
def __init__(self, message, **kwargs):
|
||||
super(ResourceNotFound, self).__init__(__class__.__name__, message, **kwargs)
|
||||
|
@ -5,7 +5,12 @@ from datetime import datetime
|
||||
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
|
||||
from moto.core.exceptions import RESTError
|
||||
from moto.sagemaker import validators
|
||||
from .exceptions import MissingModel, ValidationError, AWSValidationException
|
||||
from .exceptions import (
|
||||
MissingModel,
|
||||
ValidationError,
|
||||
AWSValidationException,
|
||||
ResourceNotFound,
|
||||
)
|
||||
|
||||
|
||||
class BaseObject(BaseModel):
|
||||
@ -1102,8 +1107,13 @@ class SageMakerModelBackend(BaseBackend):
|
||||
if f["Name"] == "ExperimentName":
|
||||
experiment_name = f["Value"]
|
||||
|
||||
if hasattr(item, "experiment_name"):
|
||||
if getattr(item, "experiment_name") != experiment_name:
|
||||
return False
|
||||
else:
|
||||
raise ValidationError(
|
||||
message="Unknown property name: ExperimentName"
|
||||
)
|
||||
|
||||
if f["Name"] == "TrialName":
|
||||
raise AWSValidationException(
|
||||
@ -1209,6 +1219,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
trial_name=trial_name,
|
||||
experiment_name=experiment_name,
|
||||
tags=[],
|
||||
trial_components=[],
|
||||
)
|
||||
self.trials[trial_name] = trial
|
||||
return trial.response_create
|
||||
@ -1245,11 +1256,22 @@ class SageMakerModelBackend(BaseBackend):
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def list_trials(self, experiment_name=None):
|
||||
def list_trials(self, experiment_name=None, trial_component_name=None):
|
||||
next_index = None
|
||||
|
||||
trials_fetched = list(self.trials.values())
|
||||
|
||||
def evaluate_filter_expression(trial_data):
|
||||
if experiment_name is not None:
|
||||
if trial_data.experiment_name != experiment_name:
|
||||
return False
|
||||
|
||||
if trial_component_name is not None:
|
||||
if trial_component_name not in trial_data.trial_components:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
trial_summaries = [
|
||||
{
|
||||
"TrialName": trial_data.trial_name,
|
||||
@ -1258,7 +1280,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
"LastModifiedTime": trial_data.last_modified_time,
|
||||
}
|
||||
for trial_data in trials_fetched
|
||||
if experiment_name is None or trial_data.experiment_name == experiment_name
|
||||
if evaluate_filter_expression(trial_data)
|
||||
]
|
||||
|
||||
return {
|
||||
@ -1342,13 +1364,43 @@ class SageMakerModelBackend(BaseBackend):
|
||||
trial_name = params["TrialName"]
|
||||
trial_component_name = params["TrialComponentName"]
|
||||
|
||||
if trial_name in self.trials.keys():
|
||||
self.trials[trial_name].trial_components.extend([trial_component_name])
|
||||
else:
|
||||
raise ResourceNotFound(
|
||||
message=f"Trial 'arn:aws:sagemaker:{self.region_name}:{ACCOUNT_ID}:experiment-trial/{trial_name}' does not exist."
|
||||
)
|
||||
|
||||
if trial_component_name in self.trial_components.keys():
|
||||
self.trial_components[trial_component_name].trial_name = trial_name
|
||||
|
||||
return {
|
||||
"TrialComponentArn": self.trial_components[
|
||||
trial_component_name
|
||||
].trial_component_arn,
|
||||
"TrialArn": self.trials[trial_name].trial_arn,
|
||||
}
|
||||
|
||||
def disassociate_trial_component(self, params):
|
||||
trial_component_name = params["TrialComponentName"]
|
||||
trial_name = params["TrialName"]
|
||||
|
||||
if trial_component_name in self.trial_components.keys():
|
||||
self.trial_components[trial_component_name].trial_name = None
|
||||
|
||||
if trial_name in self.trials.keys():
|
||||
self.trials[trial_name].trial_components = list(
|
||||
filter(
|
||||
lambda x: x != trial_component_name,
|
||||
self.trials[trial_name].trial_components,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"TrialComponentArn": f"arn:aws:sagemaker:{self.region_name}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}",
|
||||
"TrialArn": f"arn:aws:sagemaker:{self.region_name}:{ACCOUNT_ID}:experiment-trial/{trial_name}",
|
||||
}
|
||||
|
||||
def create_notebook_instance(
|
||||
self,
|
||||
notebook_instance_name,
|
||||
@ -1804,11 +1856,12 @@ class FakeExperiment(BaseObject):
|
||||
|
||||
class FakeTrial(BaseObject):
|
||||
def __init__(
|
||||
self, region_name, trial_name, experiment_name, tags,
|
||||
self, region_name, trial_name, experiment_name, tags, trial_components,
|
||||
):
|
||||
self.trial_name = trial_name
|
||||
self.trial_arn = FakeTrial.arn_formatter(trial_name, region_name)
|
||||
self.tags = tags
|
||||
self.trial_components = trial_components
|
||||
self.experiment_name = experiment_name
|
||||
self.creation_time = self.last_modified_time = datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
|
@ -346,6 +346,7 @@ class SageMakerResponse(BaseResponse):
|
||||
def list_trials(self):
|
||||
response = self.sagemaker_backend.list_trials(
|
||||
experiment_name=self._get_param("ExperimentName"),
|
||||
trial_component_name=self._get_param("TrialComponentName"),
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@ -404,14 +405,14 @@ class SageMakerResponse(BaseResponse):
|
||||
|
||||
@amzn_request_id
|
||||
def associate_trial_component(self, *args, **kwargs):
|
||||
self.sagemaker_backend.associate_trial_component(self.request_params)
|
||||
response = {}
|
||||
response = self.sagemaker_backend.associate_trial_component(self.request_params)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def disassociate_trial_component(self, *args, **kwargs):
|
||||
self.sagemaker_backend.disassociate_trial_component(self.request_params)
|
||||
response = {}
|
||||
response = self.sagemaker_backend.disassociate_trial_component(
|
||||
self.request_params
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
|
@ -1,4 +1,7 @@
|
||||
import boto3
|
||||
import pytest
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from moto import mock_sagemaker
|
||||
|
||||
@ -59,3 +62,56 @@ def test_search():
|
||||
resp = client.search(Resource="ExperimentTrial")
|
||||
assert len(resp["Results"]) == 1
|
||||
assert resp["Results"][0]["Trial"]["TrialName"] == trial_name
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_search_trial_component_with_experiment_name():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
another_trial_component_name = "another-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
resp = client.create_trial_component(
|
||||
TrialComponentName=another_trial_component_name
|
||||
)
|
||||
|
||||
resp = client.search(Resource="ExperimentTrialComponent")
|
||||
|
||||
assert len(resp["Results"]) == 2
|
||||
|
||||
resp = client.describe_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
trial_component_arn = resp["TrialComponentArn"]
|
||||
|
||||
tags = [{"Key": "key-name", "Value": "some-value"}]
|
||||
|
||||
client.add_tags(ResourceArn=trial_component_arn, Tags=tags)
|
||||
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.search(
|
||||
Resource="ExperimentTrialComponent",
|
||||
SearchExpression={
|
||||
"Filters": [
|
||||
{
|
||||
"Name": "ExperimentName",
|
||||
"Operator": "Equals",
|
||||
"Value": experiment_name,
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
ex.value.response["Error"]["Code"].should.equal("ValidationException")
|
||||
ex.value.response["Error"]["Message"].should.equal(
|
||||
"Unknown property name: ExperimentName"
|
||||
)
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
|
||||
|
@ -30,6 +30,28 @@ def test_create_trial():
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_trials_by_trial_component_name():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
resp = client.list_trials(TrialComponentName=trial_component_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert len(resp["TrialSummaries"]) == 0
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_trial():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
@ -1,4 +1,7 @@
|
||||
import boto3
|
||||
import pytest
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
@ -108,11 +111,19 @@ def test_associate_trial_component():
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
client.associate_trial_component(
|
||||
resp = client.associate_trial_component(
|
||||
TrialComponentName=trial_component_name, TrialName=trial_name
|
||||
)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert (
|
||||
resp["TrialComponentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
|
||||
)
|
||||
assert (
|
||||
resp["TrialArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}"
|
||||
)
|
||||
|
||||
resp = client.list_trial_components(TrialName=trial_name)
|
||||
|
||||
@ -120,3 +131,78 @@ def test_associate_trial_component():
|
||||
assert (
|
||||
resp["TrialComponentSummaries"][0]["TrialComponentName"] == trial_component_name
|
||||
)
|
||||
|
||||
resp = client.list_trials(TrialComponentName=trial_component_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert resp["TrialSummaries"][0]["TrialName"] == trial_name
|
||||
|
||||
with pytest.raises(ClientError) as ex:
|
||||
resp = client.associate_trial_component(
|
||||
TrialComponentName="does-not-exist", TrialName="does-not-exist"
|
||||
)
|
||||
|
||||
ex.value.response["Error"]["Code"].should.equal("ResourceNotFound")
|
||||
ex.value.response["Error"]["Message"].should.equal(
|
||||
f"Trial 'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/does-not-exist' does not exist."
|
||||
)
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_disassociate_trial_component():
|
||||
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
experiment_name = "some-experiment-name"
|
||||
|
||||
resp = client.create_experiment(ExperimentName=experiment_name)
|
||||
|
||||
trial_name = "some-trial-name"
|
||||
|
||||
resp = client.create_trial(ExperimentName=experiment_name, TrialName=trial_name)
|
||||
|
||||
trial_component_name = "some-trial-component-name"
|
||||
|
||||
resp = client.create_trial_component(TrialComponentName=trial_component_name)
|
||||
|
||||
client.associate_trial_component(
|
||||
TrialComponentName=trial_component_name, TrialName=trial_name
|
||||
)
|
||||
|
||||
resp = client.disassociate_trial_component(
|
||||
TrialComponentName=trial_component_name, TrialName=trial_name
|
||||
)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert (
|
||||
resp["TrialComponentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/{trial_component_name}"
|
||||
)
|
||||
assert (
|
||||
resp["TrialArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/{trial_name}"
|
||||
)
|
||||
|
||||
resp = client.list_trial_components(TrialName=trial_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert len(resp["TrialComponentSummaries"]) == 0
|
||||
|
||||
resp = client.list_trials(TrialComponentName=trial_component_name)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert len(resp["TrialSummaries"]) == 0
|
||||
|
||||
resp = client.disassociate_trial_component(
|
||||
TrialComponentName="does-not-exist", TrialName="does-not-exist"
|
||||
)
|
||||
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
assert (
|
||||
resp["TrialComponentArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial-component/does-not-exist"
|
||||
)
|
||||
assert (
|
||||
resp["TrialArn"]
|
||||
== f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:experiment-trial/does-not-exist"
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user