moto/tests/test_sagemaker/test_sagemaker_training.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

458 lines
17 KiB
Python
Raw Normal View History

import datetime
import re
import boto3
2021-10-18 19:44:29 +00:00
import pytest
from botocore.exceptions import ClientError
2024-01-07 12:03:33 +00:00
from moto import mock_aws
2022-08-13 09:49:43 +00:00
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID
# Placeholder IAM role ARN built from moto's default account id (ACCOUNT_ID).
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
# Region passed to boto3.client(...) by the helper model and by several tests.
TEST_REGION_NAME = "us-east-1"
class MyTrainingJobModel:
    """Test helper bundling the arguments of a SageMaker training job.

    Only the job name and role ARN are required. Every other constructor
    argument is replaced by a canned default when it is omitted or falsy.
    """

    def __init__(
        self,
        training_job_name,
        role_arn,
        container=None,
        bucket=None,
        prefix=None,
        algorithm_specification=None,
        resource_config=None,
        input_data_config=None,
        output_data_config=None,
        hyper_parameters=None,
        stopping_condition=None,
    ):
        """Store the job settings, substituting defaults for falsy arguments."""
        self.training_job_name = training_job_name
        self.role_arn = role_arn

        # Container, bucket and prefix are resolved first: the default
        # dictionaries further down are derived from them.
        if not container:
            container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
        self.container = container

        if not bucket:
            bucket = "my-bucket"
        self.bucket = bucket

        if not prefix:
            prefix = "sagemaker/DEMO-breast-cancer-prediction/"
        self.prefix = prefix

        if not algorithm_specification:
            algorithm_specification = {
                "TrainingImage": self.container,
                "TrainingInputMode": "File",
            }
        self.algorithm_specification = algorithm_specification

        if not resource_config:
            resource_config = {
                "InstanceCount": 1,
                "InstanceType": "ml.c4.2xlarge",
                "VolumeSizeInGB": 10,
            }
        self.resource_config = resource_config

        if not input_data_config:
            input_data_config = self._default_input_data_config()
        self.input_data_config = input_data_config

        if not output_data_config:
            output_data_config = {"S3OutputPath": f"s3://{self.bucket}/{self.prefix}/"}
        self.output_data_config = output_data_config

        if not hyper_parameters:
            hyper_parameters = {
                "feature_dim": "30",
                "mini_batch_size": "100",
                "predictor_type": "regressor",
                "epochs": "10",
                "num_models": "32",
                "loss": "absolute_loss",
            }
        self.hyper_parameters = hyper_parameters

        if not stopping_condition:
            stopping_condition = {"MaxRuntimeInSeconds": 60 * 60}
        self.stopping_condition = stopping_condition

    def _default_input_data_config(self):
        """Return the default train/validation channels under bucket/prefix."""
        train_channel = {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": f"s3://{self.bucket}/{self.prefix}/train/",
                    "S3DataDistributionType": "ShardedByS3Key",
                }
            },
            "CompressionType": "None",
            "RecordWrapperType": "None",
        }
        validation_channel = {
            "ChannelName": "validation",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": f"s3://{self.bucket}/{self.prefix}/validation/",
                    "S3DataDistributionType": "FullyReplicated",
                }
            },
            "CompressionType": "None",
            "RecordWrapperType": "None",
        }
        return [train_channel, validation_channel]

    def save(self):
        """Create the training job through boto3 and return the response."""
        client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        return client.create_training_job(
            RoleArn=self.role_arn,
            TrainingJobName=self.training_job_name,
            AlgorithmSpecification=self.algorithm_specification,
            ResourceConfig=self.resource_config,
            InputDataConfig=self.input_data_config,
            OutputDataConfig=self.output_data_config,
            HyperParameters=self.hyper_parameters,
            StoppingCondition=self.stopping_condition,
        )
2024-01-07 12:03:33 +00:00
@mock_aws
def test_create_training_job():
    """Create a fully specified training job and check the describe output."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    training_job_name = "MyTrainingJob"
    role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
    container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
    bucket = "my-bucket"
    prefix = "sagemaker/DEMO-breast-cancer-prediction/"

    algorithm_specification = {
        "TrainingImage": container,
        "TrainingInputMode": "File",
    }
    resource_config = {
        "InstanceCount": 1,
        "InstanceType": "ml.c4.2xlarge",
        "VolumeSizeInGB": 10,
    }
    train_channel = {
        "ChannelName": "train",
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": f"s3://{bucket}/{prefix}/train/",
                "S3DataDistributionType": "ShardedByS3Key",
            }
        },
        "CompressionType": "None",
        "RecordWrapperType": "None",
    }
    validation_channel = {
        "ChannelName": "validation",
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": f"s3://{bucket}/{prefix}/validation/",
                "S3DataDistributionType": "FullyReplicated",
            }
        },
        "CompressionType": "None",
        "RecordWrapperType": "None",
    }
    input_data_config = [train_channel, validation_channel]
    output_data_config = {"S3OutputPath": f"s3://{bucket}/{prefix}/"}
    hyper_parameters = {
        "feature_dim": "30",
        "mini_batch_size": "100",
        "predictor_type": "regressor",
        "epochs": "10",
        "num_models": "32",
        "loss": "absolute_loss",
    }
    stopping_condition = {"MaxRuntimeInSeconds": 60 * 60}

    job = MyTrainingJobModel(
        training_job_name,
        role_arn,
        container=container,
        bucket=bucket,
        prefix=prefix,
        algorithm_specification=algorithm_specification,
        resource_config=resource_config,
        input_data_config=input_data_config,
        output_data_config=output_data_config,
        hyper_parameters=hyper_parameters,
        stopping_condition=stopping_condition,
    )

    # The same ARN shape is expected from both the create and describe calls.
    arn_pattern = rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$"

    created = job.save()
    assert re.match(arn_pattern, created["TrainingJobArn"])

    described = sagemaker.describe_training_job(TrainingJobName=training_job_name)
    assert described["TrainingJobName"] == training_job_name
    assert re.match(arn_pattern, described["TrainingJobArn"])

    # Model artifacts live under the output path and mention the job name.
    model_artifacts = described["ModelArtifacts"]["S3ModelArtifacts"]
    assert model_artifacts.startswith(output_data_config["S3OutputPath"])
    assert training_job_name in model_artifacts
    assert model_artifacts.endswith("output/model.tar.gz")

    assert described["TrainingJobStatus"] == "Completed"
    assert described["SecondaryStatus"] == "Completed"
    assert described["HyperParameters"] == hyper_parameters

    described_algorithm = described["AlgorithmSpecification"]
    for key in ("TrainingImage", "TrainingInputMode"):
        assert described_algorithm[key] == algorithm_specification[key]
    assert "MetricDefinitions" in described_algorithm
    first_metric_definition = described_algorithm["MetricDefinitions"][0]
    for key in ("Name", "Regex"):
        assert key in first_metric_definition

    # Request parameters must be echoed back unchanged.
    assert described["RoleArn"] == role_arn
    assert described["InputDataConfig"] == input_data_config
    assert described["OutputDataConfig"] == output_data_config
    assert described["ResourceConfig"] == resource_config
    assert described["StoppingCondition"] == stopping_condition

    for key in (
        "CreationTime",
        "TrainingStartTime",
        "TrainingEndTime",
        "LastModifiedTime",
    ):
        assert isinstance(described[key], datetime.datetime)

    assert "SecondaryStatusTransitions" in described
    first_transition = described["SecondaryStatusTransitions"][0]
    for key in ("Status", "StartTime", "EndTime", "StatusMessage"):
        assert key in first_transition

    assert "FinalMetricDataList" in described
    first_metric = described["FinalMetricDataList"][0]
    for key in ("MetricName", "Value", "Timestamp"):
        assert key in first_metric
2024-01-07 12:03:33 +00:00
@mock_aws
def test_list_training_jobs():
    """A single saved job is listed by name and ARN, without a NextToken."""
    client = boto3.client("sagemaker", region_name="us-east-1")
    name = "blah"
    arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
    MyTrainingJobModel(training_job_name=name, role_arn=arn).save()

    training_jobs = client.list_training_jobs()
    summaries = training_jobs["TrainingJobSummaries"]
    assert len(summaries) == 1

    only_summary = summaries[0]
    assert only_summary["TrainingJobName"] == name
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:training-job/{name}$",
        only_summary["TrainingJobArn"],
    )
    assert training_jobs.get("NextToken") is None
2024-01-07 12:03:33 +00:00
@mock_aws
def test_list_training_jobs_multiple():
    """With two jobs saved, MaxResults=1 yields one summary and no limit two."""
    client = boto3.client("sagemaker", region_name="us-east-1")

    # (job name, role ARN) pairs, created in this order.
    fixtures = (
        ("blah", "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"),
        ("blah2", "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2"),
    )
    for job_name, job_role_arn in fixtures:
        MyTrainingJobModel(training_job_name=job_name, role_arn=job_role_arn).save()

    limited = client.list_training_jobs(MaxResults=1)
    assert len(limited["TrainingJobSummaries"]) == 1

    unlimited = client.list_training_jobs()
    assert len(unlimited["TrainingJobSummaries"]) == 2
    assert unlimited.get("NextToken") is None
2024-01-07 12:03:33 +00:00
@mock_aws
def test_list_training_jobs_none():
    """Listing before any job has been created returns no summaries."""
    sagemaker_client = boto3.client("sagemaker", region_name="us-east-1")
    summaries = sagemaker_client.list_training_jobs()["TrainingJobSummaries"]
    assert len(summaries) == 0
2024-01-07 12:03:33 +00:00
@mock_aws
def test_list_training_jobs_should_validate_input():
    """Invalid StatusEquals and NextToken values raise ValidationException."""
    client = boto3.client("sagemaker", region_name="us-east-1")

    # Case 1: a status outside the allowed enum.
    junk_status_equals = "blah"
    with pytest.raises(ClientError) as ex:
        client.list_training_jobs(StatusEquals=junk_status_equals)
    status_error = ex.value.response["Error"]
    expected_error = (
        f"1 validation errors detected: Value '{junk_status_equals}' at "
        "'statusEquals' failed to satisfy constraint: Member must satisfy "
        "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
        "'Failed']"
    )
    assert status_error["Code"] == "ValidationException"
    assert status_error["Message"] == expected_error

    # Case 2: a pagination token the service did not issue.
    junk_next_token = "asdf"
    with pytest.raises(ClientError) as ex:
        client.list_training_jobs(NextToken=junk_next_token)
    token_error = ex.value.response["Error"]
    assert token_error["Code"] == "ValidationException"
    assert token_error["Message"] == 'Invalid pagination token because "{0}".'
2024-01-07 12:03:33 +00:00
@mock_aws
def test_list_training_jobs_with_name_filters():
    """NameContains restricts the listing to jobs whose name has the substring."""
    client = boto3.client("sagemaker", region_name="us-east-1")

    # Five "xgboost-*" jobs first, then five "vgg-*" jobs.
    for i in range(5):
        MyTrainingJobModel(
            training_job_name=f"xgboost-{i}",
            role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}",
        ).save()
    for i in range(5):
        MyTrainingJobModel(
            training_job_name=f"vgg-{i}",
            role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}",
        ).save()

    xgboost_summaries = client.list_training_jobs(NameContains="xgboost")[
        "TrainingJobSummaries"
    ]
    assert len(xgboost_summaries) == 5

    # "2" matches one job from each family (xgboost-2 and vgg-2).
    summaries_with_2 = client.list_training_jobs(NameContains="2")[
        "TrainingJobSummaries"
    ]
    assert len(summaries_with_2) == 2
2024-01-07 12:03:33 +00:00
@mock_aws
def test_list_training_jobs_paginated():
    """Paging one result at a time walks the jobs in creation order."""
    client = boto3.client("sagemaker", region_name="us-east-1")
    for i in range(5):
        MyTrainingJobModel(
            training_job_name=f"xgboost-{i}",
            role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}",
        ).save()

    first_page = client.list_training_jobs(NameContains="xgboost", MaxResults=1)
    first_summaries = first_page["TrainingJobSummaries"]
    assert len(first_summaries) == 1
    assert first_summaries[0]["TrainingJobName"] == "xgboost-0"
    assert first_page.get("NextToken") is not None

    # Continue from the token handed back with the first page.
    second_page = client.list_training_jobs(
        NameContains="xgboost",
        MaxResults=1,
        NextToken=first_page.get("NextToken"),
    )
    second_summaries = second_page["TrainingJobSummaries"]
    assert len(second_summaries) == 1
    assert second_summaries[0]["TrainingJobName"] == "xgboost-1"
    assert second_page.get("NextToken") is not None
2024-01-07 12:03:33 +00:00
@mock_aws
def test_list_training_jobs_paginated_with_target_in_middle():
    """Check page contents when the matching jobs come after non-matching ones.

    Five "xgboost-*" jobs are created before five "vgg-*" jobs, and the
    listing is filtered on "vgg" with page sizes of 1, 6 and 10.
    """
    client = boto3.client("sagemaker", region_name="us-east-1")
    for i in range(5):
        MyTrainingJobModel(
            training_job_name=f"xgboost-{i}",
            role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}",
        ).save()
    for i in range(5):
        MyTrainingJobModel(
            training_job_name=f"vgg-{i}",
            role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}",
        ).save()

    # A page of one holds no match but still offers a continuation token.
    page_of_1 = client.list_training_jobs(NameContains="vgg", MaxResults=1)
    assert len(page_of_1["TrainingJobSummaries"]) == 0
    assert page_of_1.get("NextToken") is not None

    # A page of six reaches exactly the first "vgg" job.
    page_of_6 = client.list_training_jobs(NameContains="vgg", MaxResults=6)
    summaries_of_6 = page_of_6["TrainingJobSummaries"]
    assert len(summaries_of_6) == 1
    assert summaries_of_6[0]["TrainingJobName"] == "vgg-0"
    assert page_of_6.get("NextToken") is not None

    # A page of ten covers every job, so all matches arrive and paging ends.
    page_of_10 = client.list_training_jobs(NameContains="vgg", MaxResults=10)
    summaries_of_10 = page_of_10["TrainingJobSummaries"]
    assert len(summaries_of_10) == 5
    assert summaries_of_10[-1]["TrainingJobName"] == "vgg-4"
    assert page_of_10.get("NextToken") is None
2024-01-07 12:03:33 +00:00
@mock_aws
def test_list_training_jobs_paginated_with_fragmented_targets():
    """Check paging when the matching jobs are scattered across the listing.

    The filter "2" matches one "xgboost-*" and one "vgg-*" job. After a first
    page of eight, two further single-item pages are requested.
    """
    client = boto3.client("sagemaker", region_name="us-east-1")
    for i in range(5):
        MyTrainingJobModel(
            training_job_name=f"xgboost-{i}",
            role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}",
        ).save()
    for i in range(5):
        MyTrainingJobModel(
            training_job_name=f"vgg-{i}",
            role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}",
        ).save()

    # Both matches are found within the first eight jobs.
    first_page = client.list_training_jobs(NameContains="2", MaxResults=8)
    assert len(first_page["TrainingJobSummaries"]) == 2
    assert first_page.get("NextToken") is not None

    # The next single-item page has no match but is not the last page.
    second_page = client.list_training_jobs(
        NameContains="2", MaxResults=1, NextToken=first_page.get("NextToken")
    )
    assert len(second_page["TrainingJobSummaries"]) == 0
    assert second_page.get("NextToken") is not None

    # The final page is empty as well and ends the pagination.
    third_page = client.list_training_jobs(
        NameContains="2",
        MaxResults=1,
        NextToken=second_page.get("NextToken"),
    )
    assert len(third_page["TrainingJobSummaries"]) == 0
    assert third_page.get("NextToken") is None
2024-01-07 12:03:33 +00:00
@mock_aws
def test_add_tags_to_training_job():
    """Tags added to a training job are returned by ``list_tags``."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    # Use an IAM role ARN for RoleArn instead of reusing the training-job ARN.
    test_training_job = MyTrainingJobModel(
        training_job_name=name, role_arn=FAKE_ROLE_ARN
    )
    # Tag the ARN returned by create_training_job rather than a hand-built
    # one with a hard-coded account id, so the test targets the job that was
    # actually created.
    resource_arn = test_training_job.save()["TrainingJobArn"]

    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == tags
2024-01-07 12:03:33 +00:00
@mock_aws
def test_delete_tags_from_training_job():
    """Tags removed with ``delete_tags`` no longer show up in ``list_tags``."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    # Use an IAM role ARN for RoleArn instead of reusing the training-job ARN.
    test_training_job = MyTrainingJobModel(
        training_job_name=name, role_arn=FAKE_ROLE_ARN
    )
    # Operate on the ARN returned by create_training_job rather than a
    # hand-built one with a hard-coded account id, so the test targets the job
    # that was actually created.
    resource_arn = test_training_job.save()["TrainingJobArn"]

    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    tag_keys = [tag["Key"] for tag in tags]
    response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    # Every key was deleted, so the job must end up with no tags.
    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == []
2024-01-07 12:03:33 +00:00
@mock_aws
def test_describe_unknown_training_job():
    """Describing a job that was never created raises ValidationException."""
    client = boto3.client("sagemaker", region_name="us-east-1")
    expected_message = (
        "Could not find training job 'arn:aws:sagemaker:us-east-1:"
        f"{ACCOUNT_ID}:training-job/unknown'."
    )

    with pytest.raises(ClientError) as exc:
        client.describe_training_job(TrainingJobName="unknown")

    error = exc.value.response["Error"]
    assert error["Code"] == "ValidationException"
    assert error["Message"] == expected_message