moto/tests/test_sagemaker/test_sagemaker_processing.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

408 lines
14 KiB
Python
Raw Normal View History

import datetime
import re
import boto3
import pytest
from botocore.exceptions import ClientError
2024-01-07 12:03:33 +00:00
from moto import mock_aws
2022-08-13 09:49:43 +00:00
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
# Region every mocked client in this module is created in.
TEST_REGION_NAME = "us-east-1"
# Identifiers shared by the tests below.
FAKE_PROCESSING_JOB_NAME = "MyProcessingJob"
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
FAKE_CONTAINER = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
    """Yield a SageMaker client for TEST_REGION_NAME while ``mock_aws`` is active."""
    with mock_aws():
        client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        yield client
class MyProcessingJobModel:
    """Test helper that assembles and submits a ``CreateProcessingJob`` request.

    Only ``processing_job_name`` and ``role_arn`` are required. Every other
    argument falls back to a small default when it is omitted (``None``). An
    explicitly supplied value -- including an empty string, list or dict -- is
    kept exactly as given. Default containers are built fresh in each
    ``__init__`` call, so no mutable state is shared between instances.
    """

    def __init__(
        self,
        processing_job_name,
        role_arn,
        container=None,
        bucket=None,
        prefix=None,
        app_specification=None,
        network_config=None,
        processing_inputs=None,
        processing_output_config=None,
        processing_resources=None,
        stopping_condition=None,
    ):
        self.processing_job_name = processing_job_name
        self.role_arn = role_arn
        self.container = (
            "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3"
            if container is None
            else container
        )
        self.bucket = "my-bucket" if bucket is None else bucket
        self.prefix = "sagemaker" if prefix is None else prefix

        # The default input and output locations share one S3 prefix, derived
        # from the already-resolved bucket and prefix.
        s3_uri = f"s3://{self.bucket}/{self.prefix}/processing/"

        if app_specification is None:
            # Default image is whatever container was resolved above.
            app_specification = {
                "ImageUri": self.container,
                "ContainerEntrypoint": ["python3"],
            }
        self.app_specification = app_specification

        if network_config is None:
            network_config = {
                "EnableInterContainerTrafficEncryption": False,
                "EnableNetworkIsolation": False,
            }
        self.network_config = network_config

        if processing_inputs is None:
            processing_inputs = [
                {
                    "InputName": "input",
                    "AppManaged": False,
                    "S3Input": {
                        "S3Uri": s3_uri,
                        "LocalPath": "/opt/ml/processing/input",
                        "S3DataType": "S3Prefix",
                        "S3InputMode": "File",
                        "S3DataDistributionType": "FullyReplicated",
                        "S3CompressionType": "None",
                    },
                }
            ]
        self.processing_inputs = processing_inputs

        if processing_output_config is None:
            processing_output_config = {
                "Outputs": [
                    {
                        "OutputName": "output",
                        "S3Output": {
                            "S3Uri": s3_uri,
                            "LocalPath": "/opt/ml/processing/output",
                            "S3UploadMode": "EndOfJob",
                        },
                        "AppManaged": False,
                    }
                ]
            }
        self.processing_output_config = processing_output_config

        if processing_resources is None:
            processing_resources = {
                "ClusterConfig": {
                    "InstanceCount": 1,
                    "InstanceType": "ml.m5.large",
                    "VolumeSizeInGB": 10,
                },
            }
        self.processing_resources = processing_resources

        if stopping_condition is None:
            stopping_condition = {
                "MaxRuntimeInSeconds": 3600,
            }
        self.stopping_condition = stopping_condition

    def save(self, sagemaker_client):
        """Submit this model via ``sagemaker_client.create_processing_job``.

        Args:
            sagemaker_client: any object exposing ``create_processing_job``
                (a boto3 SageMaker client in these tests).

        Returns:
            Whatever ``create_processing_job`` returns.
        """
        params = {
            "AppSpecification": self.app_specification,
            "NetworkConfig": self.network_config,
            "ProcessingInputs": self.processing_inputs,
            "ProcessingJobName": self.processing_job_name,
            "ProcessingOutputConfig": self.processing_output_config,
            "ProcessingResources": self.processing_resources,
            "RoleArn": self.role_arn,
            "StoppingCondition": self.stopping_condition,
        }
        return sagemaker_client.create_processing_job(**params)
def test_create_processing_job(sagemaker_client):
    """Create a fully specified processing job and check what describe returns."""
    job_pattern = rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$"
    model = MyProcessingJobModel(
        processing_job_name=FAKE_PROCESSING_JOB_NAME,
        role_arn=FAKE_ROLE_ARN,
        container=FAKE_CONTAINER,
        bucket="my-bucket",
        prefix="my-prefix",
        app_specification={
            "ImageUri": FAKE_CONTAINER,
            "ContainerEntrypoint": ["python3", "app.py"],
        },
        processing_resources={
            "ClusterConfig": {
                "InstanceCount": 2,
                "InstanceType": "ml.m5.xlarge",
                "VolumeSizeInGB": 20,
            },
        },
        stopping_condition={"MaxRuntimeInSeconds": 60 * 60},
    )

    created = model.save(sagemaker_client)
    assert re.match(job_pattern, created["ProcessingJobArn"])

    described = sagemaker_client.describe_processing_job(
        ProcessingJobName=FAKE_PROCESSING_JOB_NAME
    )
    assert described["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME
    assert re.match(job_pattern, described["ProcessingJobArn"])
    entrypoint = described["AppSpecification"]["ContainerEntrypoint"]
    assert "python3" in entrypoint
    assert "app.py" in entrypoint
    assert described["RoleArn"] == FAKE_ROLE_ARN
    assert described["ProcessingJobStatus"] == "Completed"
    for timestamp_key in ("CreationTime", "LastModifiedTime"):
        assert isinstance(described[timestamp_key], datetime.datetime)
def test_list_processing_jobs(sagemaker_client):
    """A single saved job is listed once, with no continuation token."""
    MyProcessingJobModel(
        processing_job_name=FAKE_PROCESSING_JOB_NAME, role_arn=FAKE_ROLE_ARN
    ).save(sagemaker_client)

    response = sagemaker_client.list_processing_jobs()
    summaries = response["ProcessingJobSummaries"]
    assert len(summaries) == 1

    only_job = summaries[0]
    assert only_job["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$",
        only_job["ProcessingJobArn"],
    )
    assert response.get("NextToken") is None
def test_list_processing_jobs_multiple(sagemaker_client):
    """MaxResults caps a page; an unbounded listing returns every job."""
    # NOTE(review): the role ARNs below are SageMaker-style rather than IAM role
    # ARNs; the test relies on the mock accepting them as-is.
    for job_name, role in (
        ("blah", "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"),
        ("blah2", "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2"),
    ):
        MyProcessingJobModel(processing_job_name=job_name, role_arn=role).save(
            sagemaker_client
        )

    first_page = sagemaker_client.list_processing_jobs(MaxResults=1)
    assert len(first_page["ProcessingJobSummaries"]) == 1

    everything = sagemaker_client.list_processing_jobs()
    assert len(everything["ProcessingJobSummaries"]) == 2
    assert everything.get("NextToken") is None
def test_list_processing_jobs_none(sagemaker_client):
    """Listing before any job is created yields an empty summary list."""
    summaries = sagemaker_client.list_processing_jobs()["ProcessingJobSummaries"]
    assert len(summaries) == 0
def test_list_processing_jobs_should_validate_input(sagemaker_client):
    """An unknown StatusEquals value or a junk NextToken raises ValidationException."""
    bad_status = "blah"
    with pytest.raises(ClientError) as status_err:
        sagemaker_client.list_processing_jobs(StatusEquals=bad_status)
    status_error = status_err.value.response["Error"]
    assert status_error["Code"] == "ValidationException"
    assert status_error["Message"] == (
        f"1 validation errors detected: Value '{bad_status}' at "
        "'statusEquals' failed to satisfy constraint: Member must satisfy "
        "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
        "'Failed']"
    )

    bad_token = "asdf"
    with pytest.raises(ClientError) as token_err:
        sagemaker_client.list_processing_jobs(NextToken=bad_token)
    token_error = token_err.value.response["Error"]
    assert token_error["Code"] == "ValidationException"
    # The expected message contains a literal, unformatted "{0}" placeholder.
    assert token_error["Message"] == 'Invalid pagination token because "{0}".'
def test_list_processing_jobs_with_name_filters(sagemaker_client):
    """NameContains selects jobs whose name contains the given substring."""
    # Creates xgboost-0..4 first, then vgg-0..4.
    for stem, role_suffix in (("xgboost", "foobar"), ("vgg", "barfoo")):
        for idx in range(5):
            MyProcessingJobModel(
                processing_job_name=f"{stem}-{idx}",
                role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/{role_suffix}-{idx}",
            ).save(sagemaker_client)

    xgboost_only = sagemaker_client.list_processing_jobs(NameContains="xgboost")
    assert len(xgboost_only["ProcessingJobSummaries"]) == 5

    # Only "xgboost-2" and "vgg-2" contain the character "2".
    containing_two = sagemaker_client.list_processing_jobs(NameContains="2")
    assert len(containing_two["ProcessingJobSummaries"]) == 2
def test_list_processing_jobs_paginated(sagemaker_client):
    """Walk the first two single-item pages of the xgboost jobs."""
    for idx in range(5):
        MyProcessingJobModel(
            processing_job_name=f"xgboost-{idx}",
            role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{idx}",
        ).save(sagemaker_client)

    page_args = {"NameContains": "xgboost", "MaxResults": 1}
    for expected_name in ("xgboost-0", "xgboost-1"):
        page = sagemaker_client.list_processing_jobs(**page_args)
        summaries = page["ProcessingJobSummaries"]
        assert len(summaries) == 1
        assert summaries[0]["ProcessingJobName"] == expected_name
        assert page.get("NextToken") is not None
        # Feed the token into the request for the following page.
        page_args["NextToken"] = page.get("NextToken")
def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
    """Page through a filter whose matches sit after the non-matching jobs.

    The expected counts show that each page is cut to MaxResults jobs before
    NameContains is applied. One job (xgboost-0) yields no vgg match. Six jobs
    (xgboost-0..4 plus vgg-0) yield one match. Ten jobs yield all five vgg jobs.
    """
    for stem, role_suffix in (("xgboost", "foobar"), ("vgg", "barfoo")):
        for idx in range(5):
            MyProcessingJobModel(
                processing_job_name=f"{stem}-{idx}",
                role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/{role_suffix}-{idx}",
            ).save(sagemaker_client)

    page_of_1 = sagemaker_client.list_processing_jobs(NameContains="vgg", MaxResults=1)
    assert len(page_of_1["ProcessingJobSummaries"]) == 0
    assert page_of_1.get("NextToken") is not None

    page_of_6 = sagemaker_client.list_processing_jobs(NameContains="vgg", MaxResults=6)
    summaries_6 = page_of_6["ProcessingJobSummaries"]
    assert len(summaries_6) == 1
    assert summaries_6[0]["ProcessingJobName"] == "vgg-0"
    assert page_of_6.get("NextToken") is not None

    page_of_10 = sagemaker_client.list_processing_jobs(
        NameContains="vgg", MaxResults=10
    )
    summaries_10 = page_of_10["ProcessingJobSummaries"]
    assert len(summaries_10) == 5
    assert summaries_10[-1]["ProcessingJobName"] == "vgg-4"
    assert page_of_10.get("NextToken") is None
def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client):
    """Page through a filter whose matches are spread across the job list."""
    for stem, role_suffix in (("xgboost", "foobar"), ("vgg", "barfoo")):
        for idx in range(5):
            MyProcessingJobModel(
                processing_job_name=f"{stem}-{idx}",
                role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/{role_suffix}-{idx}",
            ).save(sagemaker_client)

    # The first 8 jobs (xgboost-0..4, vgg-0..2) hold both names containing "2".
    first = sagemaker_client.list_processing_jobs(NameContains="2", MaxResults=8)
    assert len(first["ProcessingJobSummaries"]) == 2
    assert first.get("NextToken") is not None

    # The two remaining single-item pages contain no match; only the last one
    # ends the pagination.
    token = first.get("NextToken")
    for more_expected in (True, False):
        page = sagemaker_client.list_processing_jobs(
            NameContains="2",
            MaxResults=1,
            NextToken=token,
        )
        assert len(page["ProcessingJobSummaries"]) == 0
        token = page.get("NextToken")
        assert (token is not None) is more_expected
def test_add_and_delete_tags_in_training_job(sagemaker_client):
    """Add, list and delete a tag on a processing job.

    Despite "training_job" in the name, the resource tagged here is the
    processing job created via ``MyProcessingJobModel``. The name is kept so the
    pytest node id stays stable.
    """
    # Reuse the module-level constants rather than re-spelling identical literals.
    job = MyProcessingJobModel(
        FAKE_PROCESSING_JOB_NAME,
        FAKE_ROLE_ARN,
        container=FAKE_CONTAINER,
        bucket="my-bucket",
        prefix="my-prefix",
        app_specification={
            "ImageUri": FAKE_CONTAINER,
            "ContainerEntrypoint": ["python3", "app.py"],
        },
        processing_resources={
            "ClusterConfig": {
                "InstanceCount": 2,
                "InstanceType": "ml.m5.xlarge",
                "VolumeSizeInGB": 20,
            },
        },
        stopping_condition={"MaxRuntimeInSeconds": 60 * 60},
    )
    resource_arn = job.save(sagemaker_client)["ProcessingJobArn"]

    tags = [{"Key": "myKey", "Value": "myValue"}]
    response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
    assert sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"] == tags

    tag_keys = [tag["Key"] for tag in tags]
    response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
    # Deleting every key leaves the resource untagged.
    assert sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"] == []