Sagemaker - Add Transform jobs (#6296)

This commit is contained in:
Brandon 2023-05-10 13:54:49 -04:00 committed by GitHub
parent 417561cec6
commit 6b2ee153e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 702 additions and 4 deletions

View File

@ -6009,7 +6009,7 @@
- [ ] create_space - [ ] create_space
- [ ] create_studio_lifecycle_config - [ ] create_studio_lifecycle_config
- [X] create_training_job - [X] create_training_job
- [ ] create_transform_job - [X] create_transform_job
- [X] create_trial - [X] create_trial
- [X] create_trial_component - [X] create_trial_component
- [ ] create_user_profile - [ ] create_user_profile
@ -6113,7 +6113,7 @@
- [ ] describe_studio_lifecycle_config - [ ] describe_studio_lifecycle_config
- [ ] describe_subscribed_workteam - [ ] describe_subscribed_workteam
- [X] describe_training_job - [X] describe_training_job
- [ ] describe_transform_job - [X] describe_transform_job
- [X] describe_trial - [X] describe_trial
- [X] describe_trial_component - [X] describe_trial_component
- [ ] describe_user_profile - [ ] describe_user_profile
@ -6193,7 +6193,7 @@
- [X] list_tags - [X] list_tags
- [X] list_training_jobs - [X] list_training_jobs
- [ ] list_training_jobs_for_hyper_parameter_tuning_job - [ ] list_training_jobs_for_hyper_parameter_tuning_job
- [ ] list_transform_jobs - [X] list_transform_jobs
- [X] list_trial_components - [X] list_trial_components
- [X] list_trials - [X] list_trials
- [ ] list_user_profiles - [ ] list_user_profiles

View File

@ -3,7 +3,7 @@ import os
import random import random
import string import string
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Iterable from typing import Any, Dict, List, Optional, Iterable, Union
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.sagemaker import validators from moto.sagemaker import validators
@ -717,6 +717,80 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel):
) )
class FakeTransformJob(BaseObject):
def __init__(
self,
account_id: str,
region_name: str,
transform_job_name: str,
model_name: str,
max_concurrent_transforms: int,
model_client_config: Dict[str, int],
max_payload_in_mb: int,
batch_strategy: str,
environment: Dict[str, str],
transform_input: Dict[str, Union[Dict[str, str], str]],
transform_output: Dict[str, str],
data_capture_config: Dict[str, Union[str, bool]],
transform_resources: Dict[str, Union[str, int]],
data_processing: Dict[str, str],
tags: Dict[str, str],
experiment_config: Dict[str, str],
):
self.transform_job_name = transform_job_name
self.model_name = model_name
self.max_concurrent_transforms = max_concurrent_transforms
self.model_client_config = model_client_config
self.max_payload_in_mb = max_payload_in_mb
self.batch_strategy = batch_strategy
self.environment = environment
self.transform_input = transform_input
self.transform_output = transform_output
self.data_capture_config = data_capture_config
self.transform_resources = transform_resources
self.data_processing = data_processing
self.tags = tags
self.experiment_config = experiment_config
self.transform_job_arn = FakeTransformJob.arn_formatter(
transform_job_name, account_id, region_name
)
self.transform_job_status = "Completed"
self.failure_reason = ""
self.labeling_job_arn = ""
self.auto_ml_job_arn = ""
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.creation_time = now_string
self.transform_start_time = now_string
self.transform_end_time = now_string
self.last_modified_time = now_string
# Override title case
def camelCase(self, key: str) -> str:
words = []
for word in key.split("_"):
if word == "mb":
words.append("MB")
else:
words.append(word.title())
return "".join(words)
@property
def response_object(self) -> Dict[str, Any]: # type: ignore[misc]
response_object = self.gen_response_object()
response = {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
return response
@property
def response_create(self) -> Dict[str, str]:
return {"TransformJobArn": self.transform_job_arn}
@staticmethod
def arn_formatter(name: str, account_id: str, region_name: str) -> str:
return arn_formatter("transform-job", name, account_id, region_name)
class Model(BaseObject, CloudFormationModel): class Model(BaseObject, CloudFormationModel):
def __init__( def __init__(
self, self,
@ -1199,6 +1273,7 @@ class SageMakerModelBackend(BaseBackend):
self.trials: Dict[str, FakeTrial] = {} self.trials: Dict[str, FakeTrial] = {}
self.trial_components: Dict[str, FakeTrialComponent] = {} self.trial_components: Dict[str, FakeTrialComponent] = {}
self.training_jobs: Dict[str, FakeTrainingJob] = {} self.training_jobs: Dict[str, FakeTrainingJob] = {}
self.transform_jobs: Dict[str, FakeTransformJob] = {}
self.notebook_instance_lifecycle_configurations: Dict[ self.notebook_instance_lifecycle_configurations: Dict[
str, FakeSageMakerNotebookInstanceLifecycleConfig str, FakeSageMakerNotebookInstanceLifecycleConfig
] = {} ] = {}
@ -1324,6 +1399,7 @@ class SageMakerModelBackend(BaseBackend):
"endpoint": self.endpoints, "endpoint": self.endpoints,
"endpoint-config": self.endpoint_configs, "endpoint-config": self.endpoint_configs,
"training-job": self.training_jobs, "training-job": self.training_jobs,
"transform-job": self.transform_jobs,
"experiment": self.experiments, "experiment": self.experiments,
"experiment-trial": self.trials, "experiment-trial": self.trials,
"experiment-trial-component": self.trial_components, "experiment-trial-component": self.trial_components,
@ -2271,6 +2347,137 @@ class SageMakerModelBackend(BaseBackend):
"NextToken": str(next_index) if next_index is not None else None, "NextToken": str(next_index) if next_index is not None else None,
} }
def create_transform_job(
self,
transform_job_name: str,
model_name: str,
max_concurrent_transforms: int,
model_client_config: Dict[str, int],
max_payload_in_mb: int,
batch_strategy: str,
environment: Dict[str, str],
transform_input: Dict[str, Union[Dict[str, str], str]],
transform_output: Dict[str, str],
data_capture_config: Dict[str, Union[str, bool]],
transform_resources: Dict[str, Union[str, int]],
data_processing: Dict[str, str],
tags: Dict[str, str],
experiment_config: Dict[str, str],
) -> FakeTransformJob:
transform_job = FakeTransformJob(
account_id=self.account_id,
region_name=self.region_name,
transform_job_name=transform_job_name,
model_name=model_name,
max_concurrent_transforms=max_concurrent_transforms,
model_client_config=model_client_config,
max_payload_in_mb=max_payload_in_mb,
batch_strategy=batch_strategy,
environment=environment,
transform_input=transform_input,
transform_output=transform_output,
data_capture_config=data_capture_config,
transform_resources=transform_resources,
data_processing=data_processing,
tags=tags,
experiment_config=experiment_config,
)
self.transform_jobs[transform_job_name] = transform_job
return transform_job
def list_transform_jobs(
self,
next_token: str,
max_results: int,
creation_time_after: str,
creation_time_before: str,
last_modified_time_after: str,
last_modified_time_before: str,
name_contains: str,
status_equals: str,
) -> Dict[str, Any]:
if next_token:
try:
starting_index = int(next_token)
if starting_index > len(self.transform_jobs):
raise ValueError # invalid next_token
except ValueError:
raise AWSValidationException('Invalid pagination token because "{0}".')
else:
starting_index = 0
if max_results:
end_index = max_results + starting_index
transform_jobs_fetched: Iterable[FakeTransformJob] = list(
self.transform_jobs.values()
)[starting_index:end_index]
if end_index >= len(self.transform_jobs):
next_index = None
else:
next_index = end_index
else:
transform_jobs_fetched = list(self.transform_jobs.values())
next_index = None
if name_contains is not None:
transform_jobs_fetched = filter(
lambda x: name_contains in x.transform_job_name, transform_jobs_fetched
)
if creation_time_after is not None:
transform_jobs_fetched = filter(
lambda x: x.creation_time > creation_time_after, transform_jobs_fetched
)
if creation_time_before is not None:
transform_jobs_fetched = filter(
lambda x: x.creation_time < creation_time_before, transform_jobs_fetched
)
if last_modified_time_after is not None:
transform_jobs_fetched = filter(
lambda x: x.last_modified_time > last_modified_time_after,
transform_jobs_fetched,
)
if last_modified_time_before is not None:
transform_jobs_fetched = filter(
lambda x: x.last_modified_time < last_modified_time_before,
transform_jobs_fetched,
)
if status_equals is not None:
transform_jobs_fetched = filter(
lambda x: x.transform_job_status == status_equals,
transform_jobs_fetched,
)
transform_job_summaries = [
{
"TransformJobName": transform_job_data.transform_job_name,
"TransformJobArn": transform_job_data.transform_job_arn,
"CreationTime": transform_job_data.creation_time,
"TransformEndTime": transform_job_data.transform_end_time,
"LastModifiedTime": transform_job_data.last_modified_time,
"TransformJobStatus": transform_job_data.transform_job_status,
}
for transform_job_data in transform_jobs_fetched
]
return {
"TransformJobSummaries": transform_job_summaries,
"NextToken": str(next_index) if next_index is not None else None,
}
def describe_transform_job(self, transform_job_name: str) -> Dict[str, Any]:
try:
return self.transform_jobs[transform_job_name].response_object
except KeyError:
arn = FakeTransformJob.arn_formatter(
transform_job_name, self.account_id, self.region_name
)
message = f"Could not find transform job '{arn}'."
raise ValidationError(message=message)
def create_training_job( def create_training_job(
self, self,
training_job_name: str, training_job_name: str,

View File

@ -216,6 +216,37 @@ class SageMakerResponse(BaseResponse):
response = self.sagemaker_backend.describe_processing_job(processing_job_name) response = self.sagemaker_backend.describe_processing_job(processing_job_name)
return json.dumps(response) return json.dumps(response)
@amzn_request_id
def create_transform_job(self) -> TYPE_RESPONSE:
transform_job = self.sagemaker_backend.create_transform_job(
transform_job_name=self._get_param("TransformJobName"),
model_name=self._get_param("ModelName"),
max_concurrent_transforms=self._get_param("MaxConcurrentTransforms"),
model_client_config=self._get_param("ModelClientConfig"),
max_payload_in_mb=self._get_param("MaxPayloadInMB"),
batch_strategy=self._get_param("BatchStrategy"),
environment=self._get_param("Environment"),
transform_input=self._get_param("TransformInput"),
transform_output=self._get_param("TransformOutput"),
data_capture_config=self._get_param("DataCaptureConfig"),
transform_resources=self._get_param("TransformResources"),
data_processing=self._get_param("DataProcessing"),
tags=self._get_param("Tags"),
experiment_config=self._get_param("ExperimentConfig"),
)
response = {
"TransformJobArn": transform_job.transform_job_arn,
}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_transform_job(self) -> str:
transform_job_name = self._get_param("TransformJobName")
response = self.sagemaker_backend.describe_transform_job(
transform_job_name=transform_job_name
)
return json.dumps(response)
@amzn_request_id @amzn_request_id
def create_training_job(self) -> TYPE_RESPONSE: def create_training_job(self) -> TYPE_RESPONSE:
training_job = self.sagemaker_backend.create_training_job( training_job = self.sagemaker_backend.create_training_job(
@ -651,6 +682,59 @@ class SageMakerResponse(BaseResponse):
) )
return 200, {}, json.dumps(response) return 200, {}, json.dumps(response)
@amzn_request_id
def list_transform_jobs(self) -> TYPE_RESPONSE:
max_results_range = range(1, 101)
allowed_sort_by = ["Name", "CreationTime", "Status"]
allowed_sort_order = ["Ascending", "Descending"]
allowed_status_equals = [
"Completed",
"Stopped",
"InProgress",
"Stopping",
"Failed",
]
max_results = self._get_int_param("MaxResults")
sort_by = self._get_param("SortBy", "CreationTime")
sort_order = self._get_param("SortOrder", "Ascending")
status_equals = self._get_param("StatusEquals")
next_token = self._get_param("NextToken")
errors = []
if max_results and max_results not in max_results_range:
errors.append(
f"Value '{max_results}' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to {max_results_range[-1]}"
)
if sort_by not in allowed_sort_by:
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
if sort_order not in allowed_sort_order:
errors.append(
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
)
if status_equals and status_equals not in allowed_status_equals:
errors.append(
format_enum_error(status_equals, "statusEquals", allowed_status_equals)
)
if errors != []:
raise AWSValidationException(
f"{len(errors)} validation errors detected: {';'.join(errors)}"
)
response = self.sagemaker_backend.list_transform_jobs(
next_token=next_token,
max_results=max_results,
creation_time_after=self._get_param("CreationTimeAfter"),
creation_time_before=self._get_param("CreationTimeBefore"),
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
name_contains=self._get_param("NameContains"),
status_equals=status_equals,
)
return 200, {}, json.dumps(response)
@amzn_request_id @amzn_request_id
def list_training_jobs(self) -> TYPE_RESPONSE: def list_training_jobs(self) -> TYPE_RESPONSE:
max_results_range = range(1, 101) max_results_range = range(1, 101)

View File

@ -0,0 +1,407 @@
import boto3
from botocore.exceptions import ClientError
import datetime
import sure # noqa # pylint: disable=unused-import
import pytest
from moto import mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
TEST_REGION_NAME = "us-east-1"
class MyTransformJobModel(object):
def __init__(
self,
transform_job_name,
model_name,
max_concurrent_transforms=None,
model_client_config=None,
max_payload_in_mb=None,
batch_strategy=None,
environment=None,
transform_input=None,
transform_output=None,
data_capture_config=None,
transform_resources=None,
data_processing=None,
tags=None,
experiment_config=None,
):
self.transform_job_name = transform_job_name
self.model_name = model_name
self.max_concurrent_transforms = max_concurrent_transforms or 1
self.model_client_config = model_client_config or {}
self.max_payload_in_mb = max_payload_in_mb or 1
self.batch_strategy = batch_strategy or "SingleRecord"
self.environment = environment or {}
self.transform_input = transform_input or {
"DataSource": {
"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}
},
"ContentType": "application/json",
"CompressionType": "None",
"SplitType": "None",
}
self.transform_output = transform_output or {
"S3OutputPath": "some-bucket",
"Accept": "application/json",
"AssembleWith": "None",
"KmsKeyId": "None",
}
self.data_capture_config = data_capture_config or {
"DestinationS3Uri": "data_capture",
"KmsKeyId": "None",
"GenerateInferenceId": False,
}
self.transform_resources = transform_resources or {
"InstanceType": "ml.m5.2xlarge",
"InstanceCount": 1,
}
self.data_processing = data_processing or {}
self.tags = tags or []
self.experiment_config = experiment_config or {}
def save(self):
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
params = {
"TransformJobName": self.transform_job_name,
"ModelName": self.model_name,
"MaxConcurrentTransforms": self.max_concurrent_transforms,
"ModelClientConfig": self.model_client_config,
"MaxPayloadInMB": self.max_payload_in_mb,
"BatchStrategy": self.batch_strategy,
"Environment": self.environment,
"TransformInput": self.transform_input,
"TransformOutput": self.transform_output,
"DataCaptureConfig": self.data_capture_config,
"TransformResources": self.transform_resources,
"DataProcessing": self.data_processing,
"Tags": self.tags,
"ExperimentConfig": self.experiment_config,
}
return sagemaker.create_transform_job(**params)
@mock_sagemaker
def test_create_transform_job():
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
transform_job_name = "MyTransformJob"
model_name = "MyModelName"
bucket = "my-bucket"
transform_input = {
"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}},
"ContentType": "application/json",
"CompressionType": "None",
"SplitType": "None",
}
transform_output = {
"S3OutputPath": bucket,
"Accept": "application/json",
"AssembleWith": "None",
"KmsKeyId": "None",
}
model_client_config = {
"InvocationsTimeoutInSeconds": 60,
"InvocationsMaxRetries": 1,
}
max_payload_in_mb = 1
data_capture_config = {
"DestinationS3Uri": "data_capture",
"KmsKeyId": "None",
"GenerateInferenceId": False,
}
transform_resources = {
"InstanceType": "ml.m5.2xlarge",
"InstanceCount": 1,
}
data_processing = {
"InputFilter": "$.features",
"OutputFilter": "$['id','SageMakerOutput']",
"JoinSource": "None",
}
experiment_config = {
"ExperimentName": "MyExperiment",
"TrialName": "MyTrial",
"TrialComponentDisplayName": "MyTrialDisplay",
"RunName": "MyRun",
}
job = MyTransformJobModel(
transform_job_name=transform_job_name,
model_name=model_name,
transform_output=transform_output,
model_client_config=model_client_config,
max_payload_in_mb=max_payload_in_mb,
data_capture_config=data_capture_config,
transform_resources=transform_resources,
data_processing=data_processing,
experiment_config=experiment_config,
)
resp = job.save()
resp["TransformJobArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$"
)
resp = sagemaker.describe_transform_job(TransformJobName=transform_job_name)
resp["TransformJobName"].should.equal(transform_job_name)
resp["TransformJobStatus"].should.equal("Completed")
resp["ModelName"].should.equal(model_name)
resp["MaxConcurrentTransforms"].should.equal(1)
resp["ModelClientConfig"].should.equal(model_client_config)
resp["MaxPayloadInMB"].should.equal(max_payload_in_mb)
resp["BatchStrategy"].should.equal("SingleRecord")
resp["TransformInput"].should.equal(transform_input)
resp["TransformOutput"].should.equal(transform_output)
resp["DataCaptureConfig"].should.equal(data_capture_config)
resp["TransformResources"].should.equal(transform_resources)
resp["DataProcessing"].should.equal(data_processing)
resp["ExperimentConfig"].should.equal(experiment_config)
assert isinstance(resp["CreationTime"], datetime.datetime)
assert isinstance(resp["TransformStartTime"], datetime.datetime)
assert isinstance(resp["TransformEndTime"], datetime.datetime)
@mock_sagemaker
def test_list_transform_jobs():
client = boto3.client("sagemaker", region_name="us-east-1")
name = "blah"
model_name = "blah_model"
test_transform_job = MyTransformJobModel(
transform_job_name=name, model_name=model_name
)
test_transform_job.save()
transform_jobs = client.list_transform_jobs()
assert len(transform_jobs["TransformJobSummaries"]).should.equal(1)
assert transform_jobs["TransformJobSummaries"][0]["TransformJobName"].should.equal(
name
)
assert transform_jobs["TransformJobSummaries"][0]["TransformJobArn"].should.match(
rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$"
)
assert transform_jobs.get("NextToken") is None
@mock_sagemaker
def test_list_transform_jobs_multiple():
client = boto3.client("sagemaker", region_name="us-east-1")
name_job_1 = "blah"
model_name1 = "blah_model"
test_transform_job_1 = MyTransformJobModel(
transform_job_name=name_job_1, model_name=model_name1
)
test_transform_job_1.save()
name_job_2 = "blah2"
model_name2 = "blah_model2"
test_transform_job_2 = MyTransformJobModel(
transform_job_name=name_job_2, model_name=model_name2
)
test_transform_job_2.save()
transform_jobs_limit = client.list_transform_jobs(MaxResults=1)
assert len(transform_jobs_limit["TransformJobSummaries"]).should.equal(1)
transform_jobs = client.list_transform_jobs()
assert len(transform_jobs["TransformJobSummaries"]).should.equal(2)
assert transform_jobs.get("NextToken").should.be.none
@mock_sagemaker
def test_list_transform_jobs_none():
client = boto3.client("sagemaker", region_name="us-east-1")
transform_jobs = client.list_transform_jobs()
assert len(transform_jobs["TransformJobSummaries"]).should.equal(0)
@mock_sagemaker
def test_list_transform_jobs_should_validate_input():
client = boto3.client("sagemaker", region_name="us-east-1")
junk_status_equals = "blah"
with pytest.raises(ClientError) as ex:
client.list_transform_jobs(StatusEquals=junk_status_equals)
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
assert ex.value.response["Error"]["Code"] == "ValidationException"
assert ex.value.response["Error"]["Message"] == expected_error
junk_next_token = "asdf"
with pytest.raises(ClientError) as ex:
client.list_transform_jobs(NextToken=junk_next_token)
assert ex.value.response["Error"]["Code"] == "ValidationException"
assert (
ex.value.response["Error"]["Message"]
== 'Invalid pagination token because "{0}".'
)
@mock_sagemaker
def test_list_transform_jobs_with_name_filters():
client = boto3.client("sagemaker", region_name="us-east-1")
for i in range(5):
name = f"xgboost-{i}"
model_name = f"blah_model-{i}"
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
for i in range(5):
name = f"vgg-{i}"
model_name = f"blah_model-{i}"
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
xgboost_transform_jobs = client.list_transform_jobs(NameContains="xgboost")
assert len(xgboost_transform_jobs["TransformJobSummaries"]).should.equal(5)
transform_jobs_with_2 = client.list_transform_jobs(NameContains="2")
assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2)
@mock_sagemaker
def test_list_transform_jobs_paginated():
client = boto3.client("sagemaker", region_name="us-east-1")
for i in range(5):
name = f"xgboost-{i}"
model_name = f"my-model-{i}"
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
xgboost_transform_job_1 = client.list_transform_jobs(
NameContains="xgboost", MaxResults=1
)
assert len(xgboost_transform_job_1["TransformJobSummaries"]).should.equal(1)
assert xgboost_transform_job_1["TransformJobSummaries"][0][
"TransformJobName"
].should.equal("xgboost-0")
assert xgboost_transform_job_1.get("NextToken").should_not.be.none
xgboost_transform_job_next = client.list_transform_jobs(
NameContains="xgboost",
MaxResults=1,
NextToken=xgboost_transform_job_1.get("NextToken"),
)
assert len(xgboost_transform_job_next["TransformJobSummaries"]).should.equal(1)
assert xgboost_transform_job_next["TransformJobSummaries"][0][
"TransformJobName"
].should.equal("xgboost-1")
assert xgboost_transform_job_next.get("NextToken").should_not.be.none
@mock_sagemaker
def test_list_transform_jobs_paginated_with_target_in_middle():
client = boto3.client("sagemaker", region_name="us-east-1")
for i in range(5):
name = f"xgboost-{i}"
model_name = f"my-model-{i}"
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
for i in range(5):
name = f"vgg-{i}"
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
vgg_transform_job_1 = client.list_transform_jobs(NameContains="vgg", MaxResults=1)
assert len(vgg_transform_job_1["TransformJobSummaries"]).should.equal(0)
assert vgg_transform_job_1.get("NextToken").should_not.be.none
vgg_transform_job_6 = client.list_transform_jobs(NameContains="vgg", MaxResults=6)
assert len(vgg_transform_job_6["TransformJobSummaries"]).should.equal(1)
assert vgg_transform_job_6["TransformJobSummaries"][0][
"TransformJobName"
].should.equal("vgg-0")
assert vgg_transform_job_6.get("NextToken").should_not.be.none
vgg_transform_job_10 = client.list_transform_jobs(NameContains="vgg", MaxResults=10)
assert len(vgg_transform_job_10["TransformJobSummaries"]).should.equal(5)
assert vgg_transform_job_10["TransformJobSummaries"][-1][
"TransformJobName"
].should.equal("vgg-4")
assert vgg_transform_job_10.get("NextToken").should.be.none
@mock_sagemaker
def test_list_transform_jobs_paginated_with_fragmented_targets():
client = boto3.client("sagemaker", region_name="us-east-1")
for i in range(5):
name = f"xgboost-{i}"
model_name = f"my-model-{i}"
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
for i in range(5):
name = f"vgg-{i}"
MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
transform_jobs_with_2 = client.list_transform_jobs(NameContains="2", MaxResults=8)
assert len(transform_jobs_with_2["TransformJobSummaries"]).should.equal(2)
assert transform_jobs_with_2.get("NextToken").should_not.be.none
transform_jobs_with_2_next = client.list_transform_jobs(
NameContains="2", MaxResults=1, NextToken=transform_jobs_with_2.get("NextToken")
)
assert len(transform_jobs_with_2_next["TransformJobSummaries"]).should.equal(0)
assert transform_jobs_with_2_next.get("NextToken").should_not.be.none
transform_jobs_with_2_next_next = client.list_transform_jobs(
NameContains="2",
MaxResults=1,
NextToken=transform_jobs_with_2_next.get("NextToken"),
)
assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]).should.equal(0)
assert transform_jobs_with_2_next_next.get("NextToken").should.be.none
@mock_sagemaker
def test_add_tags_to_transform_job():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
name = "blah"
model_name = "my-model"
resource_arn = "arn:aws:sagemaker:us-east-1:123456789012:transform-job/blah"
test_transform_job = MyTransformJobModel(
transform_job_name=name, model_name=model_name
)
test_transform_job.save()
tags = [
{"Key": "myKey", "Value": "myValue"},
]
response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = client.list_tags(ResourceArn=resource_arn)
assert response["Tags"] == tags
@mock_sagemaker
def test_delete_tags_from_transform_job():
client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
name = "blah"
model_name = "my-model"
resource_arn = "arn:aws:sagemaker:us-east-1:123456789012:transform-job/blah"
test_transform_job = MyTransformJobModel(
transform_job_name=name, model_name=model_name
)
test_transform_job.save()
tags = [
{"Key": "myKey", "Value": "myValue"},
]
response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
tag_keys = [tag["Key"] for tag in tags]
response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = client.list_tags(ResourceArn=resource_arn)
assert response["Tags"] == []
@mock_sagemaker
def test_describe_unknown_transform_job():
client = boto3.client("sagemaker", region_name="us-east-1")
with pytest.raises(ClientError) as exc:
client.describe_transform_job(TransformJobName="unknown")
err = exc.value.response["Error"]
err["Code"].should.equal("ValidationException")
err["Message"].should.equal(
f"Could not find transform job 'arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:transform-job/unknown'."
)