408 lines
15 KiB
Python
408 lines
15 KiB
Python
|
import boto3
|
||
|
from botocore.exceptions import ClientError
|
||
|
import datetime
|
||
|
import sure # noqa # pylint: disable=unused-import
|
||
|
import pytest
|
||
|
|
||
|
from moto import mock_sagemaker
|
||
|
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
|
||
|
|
||
|
# Region used by every boto3 client and by MyTransformJobModel.save().
TEST_REGION_NAME = "us-east-1"

# IAM role ARN in the mocked default account.
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||
|
|
||
|
|
||
|
class MyTransformJobModel:
    """Describe a SageMaker transform job for tests and create it on demand.

    Only ``transform_job_name`` and ``model_name`` are required. Every other
    argument is combined with a small working default via ``or``, so it falls
    back to that default whenever it is omitted or falsy (``None``, ``0``,
    ``{}``, ``[]``, ``""``).
    """

    # (CreateTransformJob request key, attribute name) pairs, in the order the
    # request dict is assembled by ``save``.
    _REQUEST_FIELDS = (
        ("TransformJobName", "transform_job_name"),
        ("ModelName", "model_name"),
        ("MaxConcurrentTransforms", "max_concurrent_transforms"),
        ("ModelClientConfig", "model_client_config"),
        ("MaxPayloadInMB", "max_payload_in_mb"),
        ("BatchStrategy", "batch_strategy"),
        ("Environment", "environment"),
        ("TransformInput", "transform_input"),
        ("TransformOutput", "transform_output"),
        ("DataCaptureConfig", "data_capture_config"),
        ("TransformResources", "transform_resources"),
        ("DataProcessing", "data_processing"),
        ("Tags", "tags"),
        ("ExperimentConfig", "experiment_config"),
    )

    def __init__(
        self,
        transform_job_name,
        model_name,
        max_concurrent_transforms=None,
        model_client_config=None,
        max_payload_in_mb=None,
        batch_strategy=None,
        environment=None,
        transform_input=None,
        transform_output=None,
        data_capture_config=None,
        transform_resources=None,
        data_processing=None,
        tags=None,
        experiment_config=None,
    ):
        # Default containers are rebuilt on every call, so two instances never
        # share the same dict/list object.
        default_input = {
            "DataSource": {
                "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}
            },
            "ContentType": "application/json",
            "CompressionType": "None",
            "SplitType": "None",
        }
        default_output = {
            "S3OutputPath": "some-bucket",
            "Accept": "application/json",
            "AssembleWith": "None",
            "KmsKeyId": "None",
        }
        default_capture = {
            "DestinationS3Uri": "data_capture",
            "KmsKeyId": "None",
            "GenerateInferenceId": False,
        }
        default_resources = {"InstanceType": "ml.m5.2xlarge", "InstanceCount": 1}

        self.transform_job_name = transform_job_name
        self.model_name = model_name
        self.max_concurrent_transforms = max_concurrent_transforms or 1
        self.model_client_config = model_client_config or {}
        self.max_payload_in_mb = max_payload_in_mb or 1
        self.batch_strategy = batch_strategy or "SingleRecord"
        self.environment = environment or {}
        self.transform_input = transform_input or default_input
        self.transform_output = transform_output or default_output
        self.data_capture_config = data_capture_config or default_capture
        self.transform_resources = transform_resources or default_resources
        self.data_processing = data_processing or {}
        self.tags = tags or []
        self.experiment_config = experiment_config or {}

    def save(self):
        """Submit this job via CreateTransformJob and return the raw response.

        A new boto3 SageMaker client for ``TEST_REGION_NAME`` is created on
        each call, and every attribute listed in ``_REQUEST_FIELDS`` is sent.
        """
        sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        request = {key: getattr(self, attr) for key, attr in self._REQUEST_FIELDS}
        return sagemaker.create_transform_job(**request)
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_create_transform_job():
    """Create a fully specified transform job and check that
    DescribeTransformJob reports every supplied field back unchanged."""
    import re

    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    transform_job_name = "MyTransformJob"
    model_name = "MyModelName"
    bucket = "my-bucket"

    transform_input = {
        "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}},
        "ContentType": "application/json",
        "CompressionType": "None",
        "SplitType": "None",
    }
    transform_output = {
        "S3OutputPath": bucket,
        "Accept": "application/json",
        "AssembleWith": "None",
        "KmsKeyId": "None",
    }
    model_client_config = {
        "InvocationsTimeoutInSeconds": 60,
        "InvocationsMaxRetries": 1,
    }
    max_payload_in_mb = 1
    data_capture_config = {
        "DestinationS3Uri": "data_capture",
        "KmsKeyId": "None",
        "GenerateInferenceId": False,
    }
    transform_resources = {
        "InstanceType": "ml.m5.2xlarge",
        "InstanceCount": 1,
    }
    data_processing = {
        "InputFilter": "$.features",
        "OutputFilter": "$['id','SageMakerOutput']",
        "JoinSource": "None",
    }
    experiment_config = {
        "ExperimentName": "MyExperiment",
        "TrialName": "MyTrial",
        "TrialComponentDisplayName": "MyTrialDisplay",
        "RunName": "MyRun",
    }

    job = MyTransformJobModel(
        transform_job_name=transform_job_name,
        model_name=model_name,
        # Passed explicitly so the TransformInput assertion below does not
        # depend on the helper's default happening to equal this dict.
        transform_input=transform_input,
        transform_output=transform_output,
        model_client_config=model_client_config,
        max_payload_in_mb=max_payload_in_mb,
        data_capture_config=data_capture_config,
        transform_resources=transform_resources,
        data_processing=data_processing,
        experiment_config=experiment_config,
    )
    resp = job.save()
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$",
        resp["TransformJobArn"],
    )

    resp = sagemaker.describe_transform_job(TransformJobName=transform_job_name)
    assert resp["TransformJobName"] == transform_job_name
    assert resp["TransformJobStatus"] == "Completed"
    assert resp["ModelName"] == model_name
    # Neither value is passed to the helper, so these are its defaults.
    assert resp["MaxConcurrentTransforms"] == 1
    assert resp["BatchStrategy"] == "SingleRecord"
    assert resp["ModelClientConfig"] == model_client_config
    assert resp["MaxPayloadInMB"] == max_payload_in_mb
    assert resp["TransformInput"] == transform_input
    assert resp["TransformOutput"] == transform_output
    assert resp["DataCaptureConfig"] == data_capture_config
    assert resp["TransformResources"] == transform_resources
    assert resp["DataProcessing"] == data_processing
    assert resp["ExperimentConfig"] == experiment_config
    assert isinstance(resp["CreationTime"], datetime.datetime)
    assert isinstance(resp["TransformStartTime"], datetime.datetime)
    assert isinstance(resp["TransformEndTime"], datetime.datetime)
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_list_transform_jobs():
    """A single saved job is listed with its name and ARN and no NextToken."""
    import re

    # Same region the helper saves into; a different literal would list nothing.
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "blah_model"
    MyTransformJobModel(transform_job_name=name, model_name=model_name).save()

    transform_jobs = client.list_transform_jobs()
    summaries = transform_jobs["TransformJobSummaries"]
    assert len(summaries) == 1
    assert summaries[0]["TransformJobName"] == name
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$",
        summaries[0]["TransformJobArn"],
    )
    assert transform_jobs.get("NextToken") is None
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_list_transform_jobs_multiple():
    """MaxResults caps a page; an unbounded listing returns every job."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    name_job_1 = "blah"
    model_name1 = "blah_model"
    MyTransformJobModel(transform_job_name=name_job_1, model_name=model_name1).save()

    name_job_2 = "blah2"
    model_name2 = "blah_model2"
    MyTransformJobModel(transform_job_name=name_job_2, model_name=model_name2).save()

    transform_jobs_limit = client.list_transform_jobs(MaxResults=1)
    assert len(transform_jobs_limit["TransformJobSummaries"]) == 1

    transform_jobs = client.list_transform_jobs()
    assert len(transform_jobs["TransformJobSummaries"]) == 2
    assert transform_jobs.get("NextToken") is None
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_list_transform_jobs_none():
    """Listing with no jobs created returns an empty summary list."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    transform_jobs = client.list_transform_jobs()
    assert len(transform_jobs["TransformJobSummaries"]) == 0
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_list_transform_jobs_should_validate_input():
    """An unknown StatusEquals value and a malformed NextToken are both
    rejected with a ValidationException."""
    client = boto3.client("sagemaker", region_name="us-east-1")

    bad_status = "blah"
    with pytest.raises(ClientError) as status_exc:
        client.list_transform_jobs(StatusEquals=bad_status)
    status_error = status_exc.value.response["Error"]
    assert status_error["Code"] == "ValidationException"
    assert status_error["Message"] == (
        f"1 validation errors detected: Value '{bad_status}' at 'statusEquals' "
        "failed to satisfy constraint: Member must satisfy enum value set: "
        "['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
    )

    with pytest.raises(ClientError) as token_exc:
        client.list_transform_jobs(NextToken="asdf")
    token_error = token_exc.value.response["Error"]
    assert token_error["Code"] == "ValidationException"
    assert token_error["Message"] == 'Invalid pagination token because "{0}".'
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_list_transform_jobs_with_name_filters():
    """NameContains keeps only jobs whose name contains the given substring."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for i in range(5):
        MyTransformJobModel(
            transform_job_name=f"xgboost-{i}", model_name=f"blah_model-{i}"
        ).save()
    for i in range(5):
        MyTransformJobModel(
            transform_job_name=f"vgg-{i}", model_name=f"blah_model-{i}"
        ).save()

    xgboost_transform_jobs = client.list_transform_jobs(NameContains="xgboost")
    xgboost_summaries = xgboost_transform_jobs["TransformJobSummaries"]
    assert len(xgboost_summaries) == 5
    assert all("xgboost" in s["TransformJobName"] for s in xgboost_summaries)

    transform_jobs_with_2 = client.list_transform_jobs(NameContains="2")
    summaries_with_2 = transform_jobs_with_2["TransformJobSummaries"]
    assert len(summaries_with_2) == 2
    assert {s["TransformJobName"] for s in summaries_with_2} == {"xgboost-2", "vgg-2"}
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_list_transform_jobs_paginated():
    """Walk the first two single-item pages of jobs matching NameContains."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for i in range(5):
        MyTransformJobModel(
            transform_job_name=f"xgboost-{i}", model_name=f"my-model-{i}"
        ).save()

    xgboost_transform_job_1 = client.list_transform_jobs(
        NameContains="xgboost", MaxResults=1
    )
    assert len(xgboost_transform_job_1["TransformJobSummaries"]) == 1
    assert (
        xgboost_transform_job_1["TransformJobSummaries"][0]["TransformJobName"]
        == "xgboost-0"
    )
    assert xgboost_transform_job_1.get("NextToken") is not None

    xgboost_transform_job_next = client.list_transform_jobs(
        NameContains="xgboost",
        MaxResults=1,
        NextToken=xgboost_transform_job_1.get("NextToken"),
    )
    assert len(xgboost_transform_job_next["TransformJobSummaries"]) == 1
    assert (
        xgboost_transform_job_next["TransformJobSummaries"][0]["TransformJobName"]
        == "xgboost-1"
    )
    assert xgboost_transform_job_next.get("NextToken") is not None
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_list_transform_jobs_paginated_with_target_in_middle():
    """Matching jobs sit after non-matching ones in creation order.

    The assertions below encode that MaxResults bounds the window scanned
    before NameContains is applied, so a page may be empty yet still carry a
    NextToken.
    """
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for i in range(5):
        MyTransformJobModel(
            transform_job_name=f"xgboost-{i}", model_name=f"my-model-{i}"
        ).save()
    for i in range(5):
        # Explicit model name; previously this reused the last value left
        # over from the loop above.
        MyTransformJobModel(
            transform_job_name=f"vgg-{i}", model_name=f"my-model-{i}"
        ).save()

    vgg_transform_job_1 = client.list_transform_jobs(NameContains="vgg", MaxResults=1)
    assert len(vgg_transform_job_1["TransformJobSummaries"]) == 0
    assert vgg_transform_job_1.get("NextToken") is not None

    vgg_transform_job_6 = client.list_transform_jobs(NameContains="vgg", MaxResults=6)
    assert len(vgg_transform_job_6["TransformJobSummaries"]) == 1
    assert vgg_transform_job_6["TransformJobSummaries"][0]["TransformJobName"] == "vgg-0"
    assert vgg_transform_job_6.get("NextToken") is not None

    vgg_transform_job_10 = client.list_transform_jobs(NameContains="vgg", MaxResults=10)
    assert len(vgg_transform_job_10["TransformJobSummaries"]) == 5
    assert (
        vgg_transform_job_10["TransformJobSummaries"][-1]["TransformJobName"]
        == "vgg-4"
    )
    assert vgg_transform_job_10.get("NextToken") is None
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_list_transform_jobs_paginated_with_fragmented_targets():
    """Matches for NameContains="2" are spread across the job list; paging
    continues (possibly with empty pages) until the list is exhausted."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for i in range(5):
        MyTransformJobModel(
            transform_job_name=f"xgboost-{i}", model_name=f"my-model-{i}"
        ).save()
    for i in range(5):
        # Explicit model name; previously this reused the last value left
        # over from the loop above.
        MyTransformJobModel(
            transform_job_name=f"vgg-{i}", model_name=f"my-model-{i}"
        ).save()

    transform_jobs_with_2 = client.list_transform_jobs(NameContains="2", MaxResults=8)
    assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2
    assert transform_jobs_with_2.get("NextToken") is not None

    transform_jobs_with_2_next = client.list_transform_jobs(
        NameContains="2",
        MaxResults=1,
        NextToken=transform_jobs_with_2.get("NextToken"),
    )
    assert len(transform_jobs_with_2_next["TransformJobSummaries"]) == 0
    assert transform_jobs_with_2_next.get("NextToken") is not None

    transform_jobs_with_2_next_next = client.list_transform_jobs(
        NameContains="2",
        MaxResults=1,
        NextToken=transform_jobs_with_2_next.get("NextToken"),
    )
    assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]) == 0
    assert transform_jobs_with_2_next_next.get("NextToken") is None
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_add_tags_to_transform_job():
    """Tags added with AddTags are returned by ListTags."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "my-model"
    # Use the ARN moto actually assigned rather than a hand-built one with a
    # hard-coded region and account id.
    resource_arn = MyTransformJobModel(
        transform_job_name=name, model_name=model_name
    ).save()["TransformJobArn"]

    tags = [{"Key": "myKey", "Value": "myValue"}]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == tags
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_delete_tags_from_transform_job():
    """Tags removed with DeleteTags no longer appear in ListTags."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "my-model"
    # Use the ARN moto actually assigned rather than a hand-built one with a
    # hard-coded region and account id.
    resource_arn = MyTransformJobModel(
        transform_job_name=name, model_name=model_name
    ).save()["TransformJobArn"]

    tags = [{"Key": "myKey", "Value": "myValue"}]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    tag_keys = [tag["Key"] for tag in tags]
    response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == []
|
||
|
|
||
|
|
||
|
@mock_sagemaker
def test_describe_unknown_transform_job():
    """DescribeTransformJob on a name that was never created raises a
    ValidationException naming the job's ARN."""
    client = boto3.client("sagemaker", region_name="us-east-1")
    missing_arn = f"arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:transform-job/unknown"

    with pytest.raises(ClientError) as raised:
        client.describe_transform_job(TransformJobName="unknown")

    error = raised.value.response["Error"]
    assert error["Code"] == "ValidationException"
    assert error["Message"] == f"Could not find transform job '{missing_arn}'."
|