Sagemaker - Add Transform jobs (#6296)
parent 417561cec6
commit 6b2ee153e6

@@ -6009,7 +6009,7 @@
- [ ] create_space
- [ ] create_studio_lifecycle_config
- [X] create_training_job
- [ ] create_transform_job
- [X] create_transform_job
- [X] create_trial
- [X] create_trial_component
- [ ] create_user_profile
@@ -6113,7 +6113,7 @@
- [ ] describe_studio_lifecycle_config
- [ ] describe_subscribed_workteam
- [X] describe_training_job
- [ ] describe_transform_job
- [X] describe_transform_job
- [X] describe_trial
- [X] describe_trial_component
- [ ] describe_user_profile
@@ -6193,7 +6193,7 @@
- [X] list_tags
- [X] list_training_jobs
- [ ] list_training_jobs_for_hyper_parameter_tuning_job
- [ ] list_transform_jobs
- [X] list_transform_jobs
- [X] list_trial_components
- [X] list_trials
- [ ] list_user_profiles
@@ -3,7 +3,7 @@ import os
import random
import string
from datetime import datetime
from typing import Any, Dict, List, Optional, Iterable
from typing import Any, Dict, List, Optional, Iterable, Union

from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.sagemaker import validators
@@ -717,6 +717,80 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel):
        )


class FakeTransformJob(BaseObject):
    def __init__(
        self,
        account_id: str,
        region_name: str,
        transform_job_name: str,
        model_name: str,
        max_concurrent_transforms: int,
        model_client_config: Dict[str, int],
        max_payload_in_mb: int,
        batch_strategy: str,
        environment: Dict[str, str],
        transform_input: Dict[str, Union[Dict[str, str], str]],
        transform_output: Dict[str, str],
        data_capture_config: Dict[str, Union[str, bool]],
        transform_resources: Dict[str, Union[str, int]],
        data_processing: Dict[str, str],
        tags: Dict[str, str],
        experiment_config: Dict[str, str],
    ):
        self.transform_job_name = transform_job_name
        self.model_name = model_name
        self.max_concurrent_transforms = max_concurrent_transforms
        self.model_client_config = model_client_config
        self.max_payload_in_mb = max_payload_in_mb
        self.batch_strategy = batch_strategy
        self.environment = environment
        self.transform_input = transform_input
        self.transform_output = transform_output
        self.data_capture_config = data_capture_config
        self.transform_resources = transform_resources
        self.data_processing = data_processing
        self.tags = tags
        self.experiment_config = experiment_config
        self.transform_job_arn = FakeTransformJob.arn_formatter(
            transform_job_name, account_id, region_name
        )
        self.transform_job_status = "Completed"
        self.failure_reason = ""
        self.labeling_job_arn = ""
        self.auto_ml_job_arn = ""
        now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.creation_time = now_string
        self.transform_start_time = now_string
        self.transform_end_time = now_string
        self.last_modified_time = now_string

    # Override title case
    def camelCase(self, key: str) -> str:
        words = []
        for word in key.split("_"):
            if word == "mb":
                words.append("MB")
            else:
                words.append(word.title())
        return "".join(words)

    @property
    def response_object(self) -> Dict[str, Any]:  # type: ignore[misc]
        response_object = self.gen_response_object()
        response = {
            k: v for k, v in response_object.items() if v is not None and v != [None]
        }
        return response

    @property
    def response_create(self) -> Dict[str, str]:
        return {"TransformJobArn": self.transform_job_arn}

    @staticmethod
    def arn_formatter(name: str, account_id: str, region_name: str) -> str:
        return arn_formatter("transform-job", name, account_id, region_name)

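The camelCase override above exists because the generic title-casing in BaseObject would turn "max_payload_in_mb" into "MaxPayloadInMb"; special-casing "mb" keeps the AWS response key "MaxPayloadInMB" intact when the response object is generated. A minimal, self-contained sketch of the same conversion (the helper name is illustrative, not part of the commit):

# Illustrative only: mirrors FakeTransformJob.camelCase so the "MB" suffix survives.
def to_response_key(key: str) -> str:
    return "".join("MB" if word == "mb" else word.title() for word in key.split("_"))

assert to_response_key("max_payload_in_mb") == "MaxPayloadInMB"
assert to_response_key("transform_job_name") == "TransformJobName"
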
class Model(BaseObject, CloudFormationModel):
    def __init__(
        self,
@@ -1199,6 +1273,7 @@ class SageMakerModelBackend(BaseBackend):
        self.trials: Dict[str, FakeTrial] = {}
        self.trial_components: Dict[str, FakeTrialComponent] = {}
        self.training_jobs: Dict[str, FakeTrainingJob] = {}
        self.transform_jobs: Dict[str, FakeTransformJob] = {}
        self.notebook_instance_lifecycle_configurations: Dict[
            str, FakeSageMakerNotebookInstanceLifecycleConfig
        ] = {}
@@ -1324,6 +1399,7 @@ class SageMakerModelBackend(BaseBackend):
            "endpoint": self.endpoints,
            "endpoint-config": self.endpoint_configs,
            "training-job": self.training_jobs,
            "transform-job": self.transform_jobs,
            "experiment": self.experiments,
            "experiment-trial": self.trials,
            "experiment-trial-component": self.trial_components,
@@ -2271,6 +2347,137 @@ class SageMakerModelBackend(BaseBackend):
            "NextToken": str(next_index) if next_index is not None else None,
        }

    def create_transform_job(
        self,
        transform_job_name: str,
        model_name: str,
        max_concurrent_transforms: int,
        model_client_config: Dict[str, int],
        max_payload_in_mb: int,
        batch_strategy: str,
        environment: Dict[str, str],
        transform_input: Dict[str, Union[Dict[str, str], str]],
        transform_output: Dict[str, str],
        data_capture_config: Dict[str, Union[str, bool]],
        transform_resources: Dict[str, Union[str, int]],
        data_processing: Dict[str, str],
        tags: Dict[str, str],
        experiment_config: Dict[str, str],
    ) -> FakeTransformJob:
        transform_job = FakeTransformJob(
            account_id=self.account_id,
            region_name=self.region_name,
            transform_job_name=transform_job_name,
            model_name=model_name,
            max_concurrent_transforms=max_concurrent_transforms,
            model_client_config=model_client_config,
            max_payload_in_mb=max_payload_in_mb,
            batch_strategy=batch_strategy,
            environment=environment,
            transform_input=transform_input,
            transform_output=transform_output,
            data_capture_config=data_capture_config,
            transform_resources=transform_resources,
            data_processing=data_processing,
            tags=tags,
            experiment_config=experiment_config,
        )
        self.transform_jobs[transform_job_name] = transform_job
        return transform_job

    def list_transform_jobs(
        self,
        next_token: str,
        max_results: int,
        creation_time_after: str,
        creation_time_before: str,
        last_modified_time_after: str,
        last_modified_time_before: str,
        name_contains: str,
        status_equals: str,
    ) -> Dict[str, Any]:
        if next_token:
            try:
                starting_index = int(next_token)
                if starting_index > len(self.transform_jobs):
                    raise ValueError  # invalid next_token
            except ValueError:
                raise AWSValidationException('Invalid pagination token because "{0}".')
        else:
            starting_index = 0

        if max_results:
            end_index = max_results + starting_index
            transform_jobs_fetched: Iterable[FakeTransformJob] = list(
                self.transform_jobs.values()
            )[starting_index:end_index]
            if end_index >= len(self.transform_jobs):
                next_index = None
            else:
                next_index = end_index
        else:
            transform_jobs_fetched = list(self.transform_jobs.values())
            next_index = None

        if name_contains is not None:
            transform_jobs_fetched = filter(
                lambda x: name_contains in x.transform_job_name, transform_jobs_fetched
            )

        if creation_time_after is not None:
            transform_jobs_fetched = filter(
                lambda x: x.creation_time > creation_time_after, transform_jobs_fetched
            )

        if creation_time_before is not None:
            transform_jobs_fetched = filter(
                lambda x: x.creation_time < creation_time_before, transform_jobs_fetched
            )

        if last_modified_time_after is not None:
            transform_jobs_fetched = filter(
                lambda x: x.last_modified_time > last_modified_time_after,
                transform_jobs_fetched,
            )

        if last_modified_time_before is not None:
            transform_jobs_fetched = filter(
                lambda x: x.last_modified_time < last_modified_time_before,
                transform_jobs_fetched,
            )
        if status_equals is not None:
            transform_jobs_fetched = filter(
                lambda x: x.transform_job_status == status_equals,
                transform_jobs_fetched,
            )

        transform_job_summaries = [
            {
                "TransformJobName": transform_job_data.transform_job_name,
                "TransformJobArn": transform_job_data.transform_job_arn,
                "CreationTime": transform_job_data.creation_time,
                "TransformEndTime": transform_job_data.transform_end_time,
                "LastModifiedTime": transform_job_data.last_modified_time,
                "TransformJobStatus": transform_job_data.transform_job_status,
            }
            for transform_job_data in transform_jobs_fetched
        ]

        return {
            "TransformJobSummaries": transform_job_summaries,
            "NextToken": str(next_index) if next_index is not None else None,
        }

    def describe_transform_job(self, transform_job_name: str) -> Dict[str, Any]:
        try:
            return self.transform_jobs[transform_job_name].response_object
        except KeyError:
            arn = FakeTransformJob.arn_formatter(
                transform_job_name, self.account_id, self.region_name
            )
            message = f"Could not find transform job '{arn}'."
            raise ValidationError(message=message)

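A note on the paging above: NextToken is simply the stringified index of the first job not yet returned, and the MaxResults slice is applied before the name/time/status filters run, so a filtered page can legitimately come back empty while still carrying a token (the pagination tests below rely on this). A rough standalone sketch of the token arithmetic (names are illustrative):

# Hypothetical walk-through of the index-based paging used by list_transform_jobs.
jobs = [f"job-{i}" for i in range(5)]
max_results, start = 2, 0

page = jobs[start:start + max_results]  # ['job-0', 'job-1']
next_token = str(start + max_results) if start + max_results < len(jobs) else None  # '2'

page_2 = jobs[int(next_token):int(next_token) + max_results]  # ['job-2', 'job-3']
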
    def create_training_job(
        self,
        training_job_name: str,
@@ -216,6 +216,37 @@ class SageMakerResponse(BaseResponse):
        response = self.sagemaker_backend.describe_processing_job(processing_job_name)
        return json.dumps(response)

    @amzn_request_id
    def create_transform_job(self) -> TYPE_RESPONSE:
        transform_job = self.sagemaker_backend.create_transform_job(
            transform_job_name=self._get_param("TransformJobName"),
            model_name=self._get_param("ModelName"),
            max_concurrent_transforms=self._get_param("MaxConcurrentTransforms"),
            model_client_config=self._get_param("ModelClientConfig"),
            max_payload_in_mb=self._get_param("MaxPayloadInMB"),
            batch_strategy=self._get_param("BatchStrategy"),
            environment=self._get_param("Environment"),
            transform_input=self._get_param("TransformInput"),
            transform_output=self._get_param("TransformOutput"),
            data_capture_config=self._get_param("DataCaptureConfig"),
            transform_resources=self._get_param("TransformResources"),
            data_processing=self._get_param("DataProcessing"),
            tags=self._get_param("Tags"),
            experiment_config=self._get_param("ExperimentConfig"),
        )
        response = {
            "TransformJobArn": transform_job.transform_job_arn,
        }
        return 200, {}, json.dumps(response)

    @amzn_request_id
    def describe_transform_job(self) -> str:
        transform_job_name = self._get_param("TransformJobName")
        response = self.sagemaker_backend.describe_transform_job(
            transform_job_name=transform_job_name
        )
        return json.dumps(response)

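For context, SageMaker is a JSON-protocol service, so the values read via self._get_param above come from a JSON request body dispatched on the X-Amz-Target header (e.g. SageMaker.CreateTransformJob). A hypothetical body, trimmed for brevity and not taken from the commit:

# Assumed shape of a CreateTransformJob request body; _get_param reads these keys.
request_body = {
    "TransformJobName": "MyTransformJob",
    "ModelName": "MyModelName",
    "MaxPayloadInMB": 1,
    "TransformInput": {"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}}},
    "TransformOutput": {"S3OutputPath": "some-bucket"},
    "TransformResources": {"InstanceType": "ml.m5.2xlarge", "InstanceCount": 1},
}
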
    @amzn_request_id
    def create_training_job(self) -> TYPE_RESPONSE:
        training_job = self.sagemaker_backend.create_training_job(

@@ -651,6 +682,59 @@ class SageMakerResponse(BaseResponse):
        )
        return 200, {}, json.dumps(response)

    @amzn_request_id
    def list_transform_jobs(self) -> TYPE_RESPONSE:
        max_results_range = range(1, 101)
        allowed_sort_by = ["Name", "CreationTime", "Status"]
        allowed_sort_order = ["Ascending", "Descending"]
        allowed_status_equals = [
            "Completed",
            "Stopped",
            "InProgress",
            "Stopping",
            "Failed",
        ]

        max_results = self._get_int_param("MaxResults")
        sort_by = self._get_param("SortBy", "CreationTime")
        sort_order = self._get_param("SortOrder", "Ascending")
        status_equals = self._get_param("StatusEquals")
        next_token = self._get_param("NextToken")
        errors = []
        if max_results and max_results not in max_results_range:
            errors.append(
                f"Value '{max_results}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {max_results_range[-1]}"
            )

        if sort_by not in allowed_sort_by:
            errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
        if sort_order not in allowed_sort_order:
            errors.append(
                format_enum_error(sort_order, "sortOrder", allowed_sort_order)
            )

        if status_equals and status_equals not in allowed_status_equals:
            errors.append(
                format_enum_error(status_equals, "statusEquals", allowed_status_equals)
            )

        if errors != []:
            raise AWSValidationException(
                f"{len(errors)} validation errors detected: {';'.join(errors)}"
            )

        response = self.sagemaker_backend.list_transform_jobs(
            next_token=next_token,
            max_results=max_results,
            creation_time_after=self._get_param("CreationTimeAfter"),
            creation_time_before=self._get_param("CreationTimeBefore"),
            last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
            last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
            name_contains=self._get_param("NameContains"),
            status_equals=status_equals,
        )
        return 200, {}, json.dumps(response)

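Because the handler above collects every problem into errors before raising, multiple bad parameters are reported in a single ValidationException. A hedged sketch of how that surfaces through boto3 (parameter values are illustrative; botocore does not reject enum values client-side, so they reach the mocked backend):

import boto3
from botocore.exceptions import ClientError
from moto import mock_sagemaker

@mock_sagemaker
def show_combined_validation_errors():
    client = boto3.client("sagemaker", region_name="us-east-1")
    try:
        client.list_transform_jobs(SortBy="Oops", StatusEquals="Nope")
    except ClientError as err:
        # Expected along the lines of:
        # "2 validation errors detected: <sortBy error>;<statusEquals error>"
        print(err.response["Error"]["Message"])
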
    @amzn_request_id
    def list_training_jobs(self) -> TYPE_RESPONSE:
        max_results_range = range(1, 101)

tests/test_sagemaker/test_sagemaker_transform.py (new file, 407 lines)
@@ -0,0 +1,407 @@
import boto3
from botocore.exceptions import ClientError
import datetime
import sure  # noqa # pylint: disable=unused-import
import pytest

from moto import mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID

FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
TEST_REGION_NAME = "us-east-1"


class MyTransformJobModel(object):
    def __init__(
        self,
        transform_job_name,
        model_name,
        max_concurrent_transforms=None,
        model_client_config=None,
        max_payload_in_mb=None,
        batch_strategy=None,
        environment=None,
        transform_input=None,
        transform_output=None,
        data_capture_config=None,
        transform_resources=None,
        data_processing=None,
        tags=None,
        experiment_config=None,
    ):
        self.transform_job_name = transform_job_name
        self.model_name = model_name
        self.max_concurrent_transforms = max_concurrent_transforms or 1
        self.model_client_config = model_client_config or {}
        self.max_payload_in_mb = max_payload_in_mb or 1
        self.batch_strategy = batch_strategy or "SingleRecord"
        self.environment = environment or {}
        self.transform_input = transform_input or {
            "DataSource": {
                "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}
            },
            "ContentType": "application/json",
            "CompressionType": "None",
            "SplitType": "None",
        }
        self.transform_output = transform_output or {
            "S3OutputPath": "some-bucket",
            "Accept": "application/json",
            "AssembleWith": "None",
            "KmsKeyId": "None",
        }
        self.data_capture_config = data_capture_config or {
            "DestinationS3Uri": "data_capture",
            "KmsKeyId": "None",
            "GenerateInferenceId": False,
        }
        self.transform_resources = transform_resources or {
            "InstanceType": "ml.m5.2xlarge",
            "InstanceCount": 1,
        }
        self.data_processing = data_processing or {}
        self.tags = tags or []
        self.experiment_config = experiment_config or {}

    def save(self):
        sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

        params = {
            "TransformJobName": self.transform_job_name,
            "ModelName": self.model_name,
            "MaxConcurrentTransforms": self.max_concurrent_transforms,
            "ModelClientConfig": self.model_client_config,
            "MaxPayloadInMB": self.max_payload_in_mb,
            "BatchStrategy": self.batch_strategy,
            "Environment": self.environment,
            "TransformInput": self.transform_input,
            "TransformOutput": self.transform_output,
            "DataCaptureConfig": self.data_capture_config,
            "TransformResources": self.transform_resources,
            "DataProcessing": self.data_processing,
            "Tags": self.tags,
            "ExperimentConfig": self.experiment_config,
        }

        return sagemaker.create_transform_job(**params)


@mock_sagemaker
def test_create_transform_job():
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    transform_job_name = "MyTransformJob"
    model_name = "MyModelName"
    bucket = "my-bucket"
    transform_input = {
        "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}},
        "ContentType": "application/json",
        "CompressionType": "None",
        "SplitType": "None",
    }
    transform_output = {
        "S3OutputPath": bucket,
        "Accept": "application/json",
        "AssembleWith": "None",
        "KmsKeyId": "None",
    }

    model_client_config = {
        "InvocationsTimeoutInSeconds": 60,
        "InvocationsMaxRetries": 1,
    }

    max_payload_in_mb = 1

    data_capture_config = {
        "DestinationS3Uri": "data_capture",
        "KmsKeyId": "None",
        "GenerateInferenceId": False,
    }

    transform_resources = {
        "InstanceType": "ml.m5.2xlarge",
        "InstanceCount": 1,
    }

    data_processing = {
        "InputFilter": "$.features",
        "OutputFilter": "$['id','SageMakerOutput']",
        "JoinSource": "None",
    }

    experiment_config = {
        "ExperimentName": "MyExperiment",
        "TrialName": "MyTrial",
        "TrialComponentDisplayName": "MyTrialDisplay",
        "RunName": "MyRun",
    }

    job = MyTransformJobModel(
        transform_job_name=transform_job_name,
        model_name=model_name,
        transform_output=transform_output,
        model_client_config=model_client_config,
        max_payload_in_mb=max_payload_in_mb,
        data_capture_config=data_capture_config,
        transform_resources=transform_resources,
        data_processing=data_processing,
        experiment_config=experiment_config,
    )
    resp = job.save()
    resp["TransformJobArn"].should.match(
        rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$"
    )
    resp = sagemaker.describe_transform_job(TransformJobName=transform_job_name)
    resp["TransformJobName"].should.equal(transform_job_name)
    resp["TransformJobStatus"].should.equal("Completed")
    resp["ModelName"].should.equal(model_name)
    resp["MaxConcurrentTransforms"].should.equal(1)
    resp["ModelClientConfig"].should.equal(model_client_config)
    resp["MaxPayloadInMB"].should.equal(max_payload_in_mb)
    resp["BatchStrategy"].should.equal("SingleRecord")
    resp["TransformInput"].should.equal(transform_input)
    resp["TransformOutput"].should.equal(transform_output)
    resp["DataCaptureConfig"].should.equal(data_capture_config)
    resp["TransformResources"].should.equal(transform_resources)
    resp["DataProcessing"].should.equal(data_processing)
    resp["ExperimentConfig"].should.equal(experiment_config)
    assert isinstance(resp["CreationTime"], datetime.datetime)
    assert isinstance(resp["TransformStartTime"], datetime.datetime)
    assert isinstance(resp["TransformEndTime"], datetime.datetime)


@mock_sagemaker
def test_list_transform_jobs():
    client = boto3.client("sagemaker", region_name="us-east-1")
    name = "blah"
    model_name = "blah_model"
    test_transform_job = MyTransformJobModel(
        transform_job_name=name, model_name=model_name
    )
    test_transform_job.save()
    transform_jobs = client.list_transform_jobs()
    assert len(transform_jobs["TransformJobSummaries"]).should.equal(1)
    assert transform_jobs["TransformJobSummaries"][0]["TransformJobName"].should.equal(
        name
    )

    assert transform_jobs["TransformJobSummaries"][0]["TransformJobArn"].should.match(
        rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$"
    )
    assert transform_jobs.get("NextToken") is None


@mock_sagemaker
def test_list_transform_jobs_multiple():
    client = boto3.client("sagemaker", region_name="us-east-1")
    name_job_1 = "blah"
    model_name1 = "blah_model"
    test_transform_job_1 = MyTransformJobModel(
        transform_job_name=name_job_1, model_name=model_name1
    )
    test_transform_job_1.save()

    name_job_2 = "blah2"
    model_name2 = "blah_model2"
    test_transform_job_2 = MyTransformJobModel(
        transform_job_name=name_job_2, model_name=model_name2
    )
    test_transform_job_2.save()
    transform_jobs_limit = client.list_transform_jobs(MaxResults=1)
    assert len(transform_jobs_limit["TransformJobSummaries"]).should.equal(1)

    transform_jobs = client.list_transform_jobs()
    assert len(transform_jobs["TransformJobSummaries"]).should.equal(2)
    assert transform_jobs.get("NextToken").should.be.none


@mock_sagemaker
def test_list_transform_jobs_none():
    client = boto3.client("sagemaker", region_name="us-east-1")
    transform_jobs = client.list_transform_jobs()
    assert len(transform_jobs["TransformJobSummaries"]).should.equal(0)


@mock_sagemaker
def test_list_transform_jobs_should_validate_input():
    client = boto3.client("sagemaker", region_name="us-east-1")
    junk_status_equals = "blah"
    with pytest.raises(ClientError) as ex:
        client.list_transform_jobs(StatusEquals=junk_status_equals)
    expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
    assert ex.value.response["Error"]["Code"] == "ValidationException"
    assert ex.value.response["Error"]["Message"] == expected_error

    junk_next_token = "asdf"
    with pytest.raises(ClientError) as ex:
        client.list_transform_jobs(NextToken=junk_next_token)
    assert ex.value.response["Error"]["Code"] == "ValidationException"
    assert (
        ex.value.response["Error"]["Message"]
        == 'Invalid pagination token because "{0}".'
    )


@mock_sagemaker
def test_list_transform_jobs_with_name_filters():
    client = boto3.client("sagemaker", region_name="us-east-1")
    for i in range(5):
        name = f"xgboost-{i}"
        model_name = f"blah_model-{i}"
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
    for i in range(5):
        name = f"vgg-{i}"
        model_name = f"blah_model-{i}"
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
    xgboost_transform_jobs = client.list_transform_jobs(NameContains="xgboost")
    assert len(xgboost_transform_jobs["TransformJobSummaries"]).should.equal(5)

    transform_jobs_with_2 = client.list_transform_jobs(NameContains="2")
    assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2)


@mock_sagemaker
def test_list_transform_jobs_paginated():
    client = boto3.client("sagemaker", region_name="us-east-1")
    for i in range(5):
        name = f"xgboost-{i}"
        model_name = f"my-model-{i}"
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
    xgboost_transform_job_1 = client.list_transform_jobs(
        NameContains="xgboost", MaxResults=1
    )
    assert len(xgboost_transform_job_1["TransformJobSummaries"]).should.equal(1)
    assert xgboost_transform_job_1["TransformJobSummaries"][0][
        "TransformJobName"
    ].should.equal("xgboost-0")
    assert xgboost_transform_job_1.get("NextToken").should_not.be.none

    xgboost_transform_job_next = client.list_transform_jobs(
        NameContains="xgboost",
        MaxResults=1,
        NextToken=xgboost_transform_job_1.get("NextToken"),
    )
    assert len(xgboost_transform_job_next["TransformJobSummaries"]).should.equal(1)
    assert xgboost_transform_job_next["TransformJobSummaries"][0][
        "TransformJobName"
    ].should.equal("xgboost-1")
    assert xgboost_transform_job_next.get("NextToken").should_not.be.none


@mock_sagemaker
def test_list_transform_jobs_paginated_with_target_in_middle():
    client = boto3.client("sagemaker", region_name="us-east-1")
    for i in range(5):
        name = f"xgboost-{i}"
        model_name = f"my-model-{i}"
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
    for i in range(5):
        name = f"vgg-{i}"
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()

    vgg_transform_job_1 = client.list_transform_jobs(NameContains="vgg", MaxResults=1)
    assert len(vgg_transform_job_1["TransformJobSummaries"]).should.equal(0)
    assert vgg_transform_job_1.get("NextToken").should_not.be.none

    vgg_transform_job_6 = client.list_transform_jobs(NameContains="vgg", MaxResults=6)

    assert len(vgg_transform_job_6["TransformJobSummaries"]).should.equal(1)
    assert vgg_transform_job_6["TransformJobSummaries"][0][
        "TransformJobName"
    ].should.equal("vgg-0")
    assert vgg_transform_job_6.get("NextToken").should_not.be.none

    vgg_transform_job_10 = client.list_transform_jobs(NameContains="vgg", MaxResults=10)

    assert len(vgg_transform_job_10["TransformJobSummaries"]).should.equal(5)
    assert vgg_transform_job_10["TransformJobSummaries"][-1][
        "TransformJobName"
    ].should.equal("vgg-4")
    assert vgg_transform_job_10.get("NextToken").should.be.none


@mock_sagemaker
def test_list_transform_jobs_paginated_with_fragmented_targets():
    client = boto3.client("sagemaker", region_name="us-east-1")
    for i in range(5):
        name = f"xgboost-{i}"
        model_name = f"my-model-{i}"
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
    for i in range(5):
        name = f"vgg-{i}"
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()

    transform_jobs_with_2 = client.list_transform_jobs(NameContains="2", MaxResults=8)
    assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2)
    assert transform_jobs_with_2.get("NextToken").should_not.be.none

    transform_jobs_with_2_next = client.list_transform_jobs(
        NameContains="2", MaxResults=1, NextToken=transform_jobs_with_2.get("NextToken")
    )
    assert len(transform_jobs_with_2_next["TransformJobSummaries"]).should.equal(0)
    assert transform_jobs_with_2_next.get("NextToken").should_not.be.none

    transform_jobs_with_2_next_next = client.list_transform_jobs(
        NameContains="2",
        MaxResults=1,
        NextToken=transform_jobs_with_2_next.get("NextToken"),
    )
    assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]).should.equal(0)
    assert transform_jobs_with_2_next_next.get("NextToken").should.be.none


@mock_sagemaker
def test_add_tags_to_transform_job():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "my-model"
    resource_arn = "arn:aws:sagemaker:us-east-1:123456789012:transform-job/blah"

    test_transform_job = MyTransformJobModel(
        transform_job_name=name, model_name=model_name
    )
    test_transform_job.save()
    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == tags


@mock_sagemaker
def test_delete_tags_from_transform_job():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "my-model"
    resource_arn = "arn:aws:sagemaker:us-east-1:123456789012:transform-job/blah"
    test_transform_job = MyTransformJobModel(
        transform_job_name=name, model_name=model_name
    )
    test_transform_job.save()

    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    tag_keys = [tag["Key"] for tag in tags]
    response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == []


@mock_sagemaker
def test_describe_unknown_transform_job():
    client = boto3.client("sagemaker", region_name="us-east-1")
    with pytest.raises(ClientError) as exc:
        client.describe_transform_job(TransformJobName="unknown")
    err = exc.value.response["Error"]
    err["Code"].should.equal("ValidationException")
    err["Message"].should.equal(
        f"Could not find transform job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:transform-job/unknown'."
    )