123 lines
4.7 KiB
Python
123 lines
4.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
from __future__ import unicode_literals
|
|
|
|
import boto3
|
|
import datetime
|
|
import sure # noqa
|
|
|
|
from moto import mock_sagemaker
|
|
from moto.sts.models import ACCOUNT_ID
|
|
|
|
# Region every mocked SageMaker client in this module is created in.
TEST_REGION_NAME = "us-east-1"
# Role ARN passed to SageMaker, scoped to the account id imported from
# moto.sts.models.
FAKE_ROLE_ARN = "arn:aws:iam::{account}:role/FakeRole".format(account=ACCOUNT_ID)
|
|
|
|
|
|
@mock_sagemaker
def test_create_training_job():
    """Create a training job, then check what ``describe_training_job`` returns.

    The describe response must:

    * echo back the request parameters;
    * report the job as completed;
    * place the model artifact under the requested output path;
    * carry timestamps, metric definitions, secondary status transitions and
      final metric data.
    """
    import re  # only this test needs regex matching

    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    training_job_name = "MyTrainingJob"
    container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
    bucket = "my-bucket"
    # No trailing slash: every S3 URI template below supplies its own "/"
    # separators, so a trailing slash here would put "//" in each URI.
    prefix = "sagemaker/DEMO-breast-cancer-prediction"

    params = {
        "RoleArn": FAKE_ROLE_ARN,
        "TrainingJobName": training_job_name,
        "AlgorithmSpecification": {
            "TrainingImage": container,
            "TrainingInputMode": "File",
        },
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.c4.2xlarge",
            "VolumeSizeInGB": 10,
        },
        "InputDataConfig": [
            {
                "ChannelName": "train",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://{}/{}/train/".format(bucket, prefix),
                        "S3DataDistributionType": "ShardedByS3Key",
                    }
                },
                "CompressionType": "None",
                "RecordWrapperType": "None",
            },
            {
                "ChannelName": "validation",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://{}/{}/validation/".format(bucket, prefix),
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
                "CompressionType": "None",
                "RecordWrapperType": "None",
            },
        ],
        "OutputDataConfig": {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)},
        "HyperParameters": {
            "feature_dim": "30",
            "mini_batch_size": "100",
            "predictor_type": "regressor",
            "epochs": "10",
            "num_models": "32",
            "loss": "absolute_loss",
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 60 * 60},
    }

    # Region and account segments never contain ':'. The job name is escaped
    # so that it is matched literally.
    arn_pattern = r"^arn:aws:sagemaker:[^:]*:[^:]*:training-job/{}$".format(
        re.escape(training_job_name)
    )

    resp = sagemaker.create_training_job(**params)
    assert re.match(arn_pattern, resp["TrainingJobArn"])

    resp = sagemaker.describe_training_job(TrainingJobName=training_job_name)
    assert resp["TrainingJobName"] == training_job_name
    assert re.match(arn_pattern, resp["TrainingJobArn"])

    # The model artifact must live under the requested output path and
    # mention the job.
    model_artifacts = resp["ModelArtifacts"]["S3ModelArtifacts"]
    assert model_artifacts.startswith(params["OutputDataConfig"]["S3OutputPath"])
    assert training_job_name in model_artifacts
    assert model_artifacts.endswith("output/model.tar.gz")

    assert resp["TrainingJobStatus"] == "Completed"
    assert resp["SecondaryStatus"] == "Completed"
    assert resp["HyperParameters"] == params["HyperParameters"]

    algorithm = resp["AlgorithmSpecification"]
    requested_algorithm = params["AlgorithmSpecification"]
    assert algorithm["TrainingImage"] == requested_algorithm["TrainingImage"]
    assert algorithm["TrainingInputMode"] == requested_algorithm["TrainingInputMode"]
    # Check non-emptiness first so a missing entry fails as an assertion,
    # not as an IndexError.
    assert algorithm.get("MetricDefinitions")
    metric_definition = algorithm["MetricDefinitions"][0]
    assert "Name" in metric_definition
    assert "Regex" in metric_definition

    # Request parameters that must be echoed back unchanged.
    for key in (
        "RoleArn",
        "InputDataConfig",
        "OutputDataConfig",
        "ResourceConfig",
        "StoppingCondition",
    ):
        assert resp[key] == params[key], key

    for key in (
        "CreationTime",
        "TrainingStartTime",
        "TrainingEndTime",
        "LastModifiedTime",
    ):
        assert isinstance(resp[key], datetime.datetime), key

    assert resp.get("SecondaryStatusTransitions")
    first_transition = resp["SecondaryStatusTransitions"][0]
    for key in ("Status", "StartTime", "EndTime", "StatusMessage"):
        assert key in first_transition, key

    assert resp.get("FinalMetricDataList")
    first_metric = resp["FinalMetricDataList"][0]
    for key in ("MetricName", "Value", "Timestamp"):
        assert key in first_metric, key
|