2023-08-08 10:06:51 +00:00
|
|
|
import datetime
|
|
|
|
import re
|
|
|
|
|
2023-05-10 17:54:49 +00:00
|
|
|
import boto3
|
|
|
|
from botocore.exceptions import ClientError
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from moto import mock_sagemaker
|
|
|
|
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
|
|
|
|
|
|
|
# Region passed to the boto3 SageMaker clients created in this module.
TEST_REGION_NAME = "us-east-1"
# Placeholder IAM role ARN built from the imported ACCOUNT_ID.
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
|
|
|
|
|
|
|
|
|
2023-08-08 10:06:51 +00:00
|
|
|
class MyTransformJobModel:
    """Test helper holding the arguments of a ``create_transform_job`` call.

    Any optional argument that is omitted (or otherwise falsy) is replaced by
    a small built-in default, so ``save`` always submits every request field.
    """

    def __init__(
        self,
        transform_job_name,
        model_name,
        max_concurrent_transforms=None,
        model_client_config=None,
        max_payload_in_mb=None,
        batch_strategy=None,
        environment=None,
        transform_input=None,
        transform_output=None,
        data_capture_config=None,
        transform_resources=None,
        data_processing=None,
        tags=None,
        experiment_config=None,
    ):
        """Store the given values, substituting defaults for falsy ones."""
        # Structured defaults are built fresh on every call so instances
        # never share a mutable dict.
        if not transform_input:
            transform_input = {
                "DataSource": {
                    "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}
                },
                "ContentType": "application/json",
                "CompressionType": "None",
                "SplitType": "None",
            }
        if not transform_output:
            transform_output = {
                "S3OutputPath": "some-bucket",
                "Accept": "application/json",
                "AssembleWith": "None",
                "KmsKeyId": "None",
            }
        if not data_capture_config:
            data_capture_config = {
                "DestinationS3Uri": "data_capture",
                "KmsKeyId": "None",
                "GenerateInferenceId": False,
            }
        if not transform_resources:
            transform_resources = {
                "InstanceType": "ml.m5.2xlarge",
                "InstanceCount": 1,
            }

        # Required identifiers are stored untouched.
        self.transform_job_name = transform_job_name
        self.model_name = model_name

        # Scalar options: any falsy value (including 0 or "") becomes the default.
        self.max_concurrent_transforms = (
            max_concurrent_transforms if max_concurrent_transforms else 1
        )
        self.max_payload_in_mb = max_payload_in_mb if max_payload_in_mb else 1
        self.batch_strategy = batch_strategy if batch_strategy else "SingleRecord"

        # Container options default to empty containers.
        self.model_client_config = model_client_config if model_client_config else {}
        self.environment = environment if environment else {}
        self.data_processing = data_processing if data_processing else {}
        self.tags = tags if tags else []
        self.experiment_config = experiment_config if experiment_config else {}

        self.transform_input = transform_input
        self.transform_output = transform_output
        self.data_capture_config = data_capture_config
        self.transform_resources = transform_resources

    def save(self):
        """Create the transform job via boto3 and return the client's response."""
        client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        return client.create_transform_job(
            TransformJobName=self.transform_job_name,
            ModelName=self.model_name,
            MaxConcurrentTransforms=self.max_concurrent_transforms,
            ModelClientConfig=self.model_client_config,
            MaxPayloadInMB=self.max_payload_in_mb,
            BatchStrategy=self.batch_strategy,
            Environment=self.environment,
            TransformInput=self.transform_input,
            TransformOutput=self.transform_output,
            DataCaptureConfig=self.data_capture_config,
            TransformResources=self.transform_resources,
            DataProcessing=self.data_processing,
            Tags=self.tags,
            ExperimentConfig=self.experiment_config,
        )
|
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_create_transform_job():
    """Create a transform job and check describe_transform_job echoes it back."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    transform_job_name = "MyTransformJob"
    model_name = "MyModelName"
    bucket = "my-bucket"
    transform_input = {
        "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}},
        "ContentType": "application/json",
        "CompressionType": "None",
        "SplitType": "None",
    }
    transform_output = {
        "S3OutputPath": bucket,
        "Accept": "application/json",
        "AssembleWith": "None",
        "KmsKeyId": "None",
    }

    model_client_config = {
        "InvocationsTimeoutInSeconds": 60,
        "InvocationsMaxRetries": 1,
    }

    max_payload_in_mb = 1

    data_capture_config = {
        "DestinationS3Uri": "data_capture",
        "KmsKeyId": "None",
        "GenerateInferenceId": False,
    }

    transform_resources = {
        "InstanceType": "ml.m5.2xlarge",
        "InstanceCount": 1,
    }

    data_processing = {
        "InputFilter": "$.features",
        "OutputFilter": "$['id','SageMakerOutput']",
        "JoinSource": "None",
    }

    experiment_config = {
        "ExperimentName": "MyExperiment",
        "TrialName": "MyTrial",
        "TrialComponentDisplayName": "MyTrialDisplay",
        "RunName": "MyRun",
    }

    job = MyTransformJobModel(
        transform_job_name=transform_job_name,
        model_name=model_name,
        # Passed explicitly so the TransformInput assertion below checks the
        # value declared in this test rather than the helper's own default.
        transform_input=transform_input,
        transform_output=transform_output,
        model_client_config=model_client_config,
        max_payload_in_mb=max_payload_in_mb,
        data_capture_config=data_capture_config,
        transform_resources=transform_resources,
        data_processing=data_processing,
        experiment_config=experiment_config,
    )
    resp = job.save()
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$",
        resp["TransformJobArn"],
    )

    resp = sagemaker.describe_transform_job(TransformJobName=transform_job_name)
    assert resp["TransformJobName"] == transform_job_name
    assert resp["TransformJobStatus"] == "Completed"
    assert resp["ModelName"] == model_name
    # Not passed above, so these two reflect the helper's defaults.
    assert resp["MaxConcurrentTransforms"] == 1
    assert resp["BatchStrategy"] == "SingleRecord"
    assert resp["ModelClientConfig"] == model_client_config
    assert resp["MaxPayloadInMB"] == max_payload_in_mb
    assert resp["TransformInput"] == transform_input
    assert resp["TransformOutput"] == transform_output
    assert resp["DataCaptureConfig"] == data_capture_config
    assert resp["TransformResources"] == transform_resources
    assert resp["DataProcessing"] == data_processing
    assert resp["ExperimentConfig"] == experiment_config
    assert isinstance(resp["CreationTime"], datetime.datetime)
    assert isinstance(resp["TransformStartTime"], datetime.datetime)
    assert isinstance(resp["TransformEndTime"], datetime.datetime)
|
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_list_transform_jobs():
    """A single saved job is listed by name and ARN, with no NextToken."""
    # Same region constant as MyTransformJobModel.save(), so the job is listed
    # from the region it was created in.
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "blah_model"
    MyTransformJobModel(transform_job_name=name, model_name=model_name).save()

    transform_jobs = client.list_transform_jobs()
    summaries = transform_jobs["TransformJobSummaries"]
    assert len(summaries) == 1
    assert summaries[0]["TransformJobName"] == name
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$",
        summaries[0]["TransformJobArn"],
    )
    assert transform_jobs.get("NextToken") is None
|
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_list_transform_jobs_multiple():
    """With two saved jobs, MaxResults=1 yields one summary; no limit yields both."""
    # Same region constant as MyTransformJobModel.save().
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for job_name, model_name in (("blah", "blah_model"), ("blah2", "blah_model2")):
        MyTransformJobModel(
            transform_job_name=job_name, model_name=model_name
        ).save()

    transform_jobs_limit = client.list_transform_jobs(MaxResults=1)
    assert len(transform_jobs_limit["TransformJobSummaries"]) == 1

    transform_jobs = client.list_transform_jobs()
    assert len(transform_jobs["TransformJobSummaries"]) == 2
    assert transform_jobs.get("NextToken") is None
|
2023-05-10 17:54:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_list_transform_jobs_none():
    """Listing without creating any job returns no summaries."""
    sagemaker_client = boto3.client("sagemaker", region_name="us-east-1")
    listing = sagemaker_client.list_transform_jobs()
    assert len(listing["TransformJobSummaries"]) == 0
|
2023-05-10 17:54:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_list_transform_jobs_should_validate_input():
    """Invalid StatusEquals and NextToken values raise ValidationException."""
    client = boto3.client("sagemaker", region_name="us-east-1")

    # Unknown status value.
    junk_status_equals = "blah"
    expected_error = (
        f"1 validation errors detected: Value '{junk_status_equals}' at "
        "'statusEquals' failed to satisfy constraint: Member must satisfy "
        "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
        "'Failed']"
    )
    with pytest.raises(ClientError) as ex:
        client.list_transform_jobs(StatusEquals=junk_status_equals)
    status_error = ex.value.response["Error"]
    assert status_error["Code"] == "ValidationException"
    assert status_error["Message"] == expected_error

    # Unparseable pagination token.
    junk_next_token = "asdf"
    with pytest.raises(ClientError) as ex:
        client.list_transform_jobs(NextToken=junk_next_token)
    token_error = ex.value.response["Error"]
    assert token_error["Code"] == "ValidationException"
    # The expected message contains the literal placeholder "{0}".
    assert token_error["Message"] == 'Invalid pagination token because "{0}".'
|
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_list_transform_jobs_with_name_filters():
    """NameContains restricts the listing to jobs whose name contains the value."""
    # Same region constant as MyTransformJobModel.save().
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    # Five "xgboost-*" jobs first, then five "vgg-*" jobs.
    for prefix in ("xgboost", "vgg"):
        for i in range(5):
            MyTransformJobModel(
                transform_job_name=f"{prefix}-{i}", model_name=f"blah_model-{i}"
            ).save()

    xgboost_transform_jobs = client.list_transform_jobs(NameContains="xgboost")
    assert len(xgboost_transform_jobs["TransformJobSummaries"]) == 5

    # "2" matches exactly one job per prefix: xgboost-2 and vgg-2.
    transform_jobs_with_2 = client.list_transform_jobs(NameContains="2")
    assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2
|
2023-05-10 17:54:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_list_transform_jobs_paginated():
    """Paging with MaxResults=1 returns jobs one at a time in creation order."""
    # Same region constant as MyTransformJobModel.save().
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for i in range(5):
        MyTransformJobModel(
            transform_job_name=f"xgboost-{i}", model_name=f"my-model-{i}"
        ).save()

    xgboost_transform_job_1 = client.list_transform_jobs(
        NameContains="xgboost", MaxResults=1
    )
    first_page = xgboost_transform_job_1["TransformJobSummaries"]
    assert len(first_page) == 1
    assert first_page[0]["TransformJobName"] == "xgboost-0"
    assert xgboost_transform_job_1.get("NextToken") is not None

    # Continue from the token handed back with the first page.
    xgboost_transform_job_next = client.list_transform_jobs(
        NameContains="xgboost",
        MaxResults=1,
        NextToken=xgboost_transform_job_1.get("NextToken"),
    )
    second_page = xgboost_transform_job_next["TransformJobSummaries"]
    assert len(second_page) == 1
    assert second_page[0]["TransformJobName"] == "xgboost-1"
    assert xgboost_transform_job_next.get("NextToken") is not None
|
2023-05-10 17:54:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_list_transform_jobs_paginated_with_target_in_middle():
    """Check page contents when the matching jobs come after five non-matching ones."""
    # Same region constant as MyTransformJobModel.save().
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    # Five "xgboost-*" jobs are created before the five "vgg-*" jobs. Each
    # iteration sets its own model name instead of reusing a variable left
    # over from an earlier loop.
    for prefix in ("xgboost", "vgg"):
        for i in range(5):
            MyTransformJobModel(
                transform_job_name=f"{prefix}-{i}", model_name=f"my-model-{i}"
            ).save()

    # A one-item page has no "vgg" job but still carries a NextToken.
    vgg_transform_job_1 = client.list_transform_jobs(NameContains="vgg", MaxResults=1)
    assert len(vgg_transform_job_1["TransformJobSummaries"]) == 0
    assert vgg_transform_job_1.get("NextToken") is not None

    # A six-item page holds only the first "vgg" job.
    vgg_transform_job_6 = client.list_transform_jobs(NameContains="vgg", MaxResults=6)
    assert len(vgg_transform_job_6["TransformJobSummaries"]) == 1
    assert (
        vgg_transform_job_6["TransformJobSummaries"][0]["TransformJobName"] == "vgg-0"
    )
    assert vgg_transform_job_6.get("NextToken") is not None

    # A ten-item page holds all five "vgg" jobs and ends the listing.
    vgg_transform_job_10 = client.list_transform_jobs(NameContains="vgg", MaxResults=10)
    assert len(vgg_transform_job_10["TransformJobSummaries"]) == 5
    assert (
        vgg_transform_job_10["TransformJobSummaries"][-1]["TransformJobName"] == "vgg-4"
    )
    assert vgg_transform_job_10.get("NextToken") is None
|
2023-05-10 17:54:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_list_transform_jobs_paginated_with_fragmented_targets():
    """Check paging when the matching jobs are scattered among non-matching ones."""
    # Same region constant as MyTransformJobModel.save().
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    # Five "xgboost-*" jobs followed by five "vgg-*" jobs. Each iteration sets
    # its own model name instead of reusing a variable left over from an
    # earlier loop.
    for prefix in ("xgboost", "vgg"):
        for i in range(5):
            MyTransformJobModel(
                transform_job_name=f"{prefix}-{i}", model_name=f"my-model-{i}"
            ).save()

    # The first eight jobs include both names containing "2".
    transform_jobs_with_2 = client.list_transform_jobs(NameContains="2", MaxResults=8)
    assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2
    assert transform_jobs_with_2.get("NextToken") is not None

    # The two remaining one-item pages match nothing; only the last has no token.
    transform_jobs_with_2_next = client.list_transform_jobs(
        NameContains="2", MaxResults=1, NextToken=transform_jobs_with_2.get("NextToken")
    )
    assert len(transform_jobs_with_2_next["TransformJobSummaries"]) == 0
    assert transform_jobs_with_2_next.get("NextToken") is not None

    transform_jobs_with_2_next_next = client.list_transform_jobs(
        NameContains="2",
        MaxResults=1,
        NextToken=transform_jobs_with_2_next.get("NextToken"),
    )
    assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]) == 0
    assert transform_jobs_with_2_next_next.get("NextToken") is None
|
2023-05-10 17:54:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_add_tags_to_transform_job():
    """Tags added to a transform job's ARN are returned by list_tags."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "my-model"
    # Built from the module's region and account constants rather than
    # hard-coded, so it follows them if they change.
    resource_arn = (
        f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:transform-job/{name}"
    )

    test_transform_job = MyTransformJobModel(
        transform_job_name=name, model_name=model_name
    )
    test_transform_job.save()
    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == tags
|
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_delete_tags_from_transform_job():
    """Deleting every tag key from a transform job leaves list_tags empty."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "my-model"
    # Built from the module's region and account constants rather than
    # hard-coded, so it follows them if they change.
    resource_arn = (
        f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:transform-job/{name}"
    )
    test_transform_job = MyTransformJobModel(
        transform_job_name=name, model_name=model_name
    )
    test_transform_job.save()

    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    tag_keys = [tag["Key"] for tag in tags]
    response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == []
|
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_describe_unknown_transform_job():
    """Describing a job that was never created raises ValidationException."""
    sagemaker_client = boto3.client("sagemaker", region_name="us-east-1")
    expected_message = (
        "Could not find transform job 'arn:aws:sagemaker:us-east-1:"
        f"{ACCOUNT_ID}:transform-job/unknown'."
    )
    with pytest.raises(ClientError) as exc:
        sagemaker_client.describe_transform_job(TransformJobName="unknown")
    error = exc.value.response["Error"]
    assert error["Code"] == "ValidationException"
    assert error["Message"] == expected_message
|