# -*- coding: utf-8 -*- from __future__ import unicode_literals import boto3 import datetime import sure # noqa from moto import mock_sagemaker from moto.sts.models import ACCOUNT_ID FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) TEST_REGION_NAME = "us-east-1" @mock_sagemaker def test_create_training_job(): sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) training_job_name = "MyTrainingJob" container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" bucket = "my-bucket" prefix = "sagemaker/DEMO-breast-cancer-prediction/" params = { "RoleArn": FAKE_ROLE_ARN, "TrainingJobName": training_job_name, "AlgorithmSpecification": { "TrainingImage": container, "TrainingInputMode": "File", }, "ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.c4.2xlarge", "VolumeSizeInGB": 10, }, "InputDataConfig": [ { "ChannelName": "train", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", "S3Uri": "s3://{}/{}/train/".format(bucket, prefix), "S3DataDistributionType": "ShardedByS3Key", } }, "CompressionType": "None", "RecordWrapperType": "None", }, { "ChannelName": "validation", "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", "S3Uri": "s3://{}/{}/validation/".format(bucket, prefix), "S3DataDistributionType": "FullyReplicated", } }, "CompressionType": "None", "RecordWrapperType": "None", }, ], "OutputDataConfig": {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)}, "HyperParameters": { "feature_dim": "30", "mini_batch_size": "100", "predictor_type": "regressor", "epochs": "10", "num_models": "32", "loss": "absolute_loss", }, "StoppingCondition": {"MaxRuntimeInSeconds": 60 * 60}, } resp = sagemaker.create_training_job(**params) resp["TrainingJobArn"].should.match( r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name) ) resp = sagemaker.describe_training_job(TrainingJobName=training_job_name) resp["TrainingJobName"].should.equal(training_job_name) resp["TrainingJobArn"].should.match( r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name) ) assert resp["ModelArtifacts"]["S3ModelArtifacts"].startswith( params["OutputDataConfig"]["S3OutputPath"] ) assert training_job_name in (resp["ModelArtifacts"]["S3ModelArtifacts"]) assert \ resp["ModelArtifacts"]["S3ModelArtifacts"].endswith("output/model.tar.gz") assert resp["TrainingJobStatus"] == "Completed" assert resp["SecondaryStatus"] == "Completed" assert resp["HyperParameters"] == params["HyperParameters"] assert \ resp["AlgorithmSpecification"]["TrainingImage"] == \ params["AlgorithmSpecification"]["TrainingImage"] assert \ resp["AlgorithmSpecification"]["TrainingInputMode"] == \ params["AlgorithmSpecification"]["TrainingInputMode"] assert "MetricDefinitions" in resp["AlgorithmSpecification"] assert "Name" in resp["AlgorithmSpecification"]["MetricDefinitions"][0] assert "Regex" in resp["AlgorithmSpecification"]["MetricDefinitions"][0] assert resp["RoleArn"] == FAKE_ROLE_ARN assert resp["InputDataConfig"] == params["InputDataConfig"] assert resp["OutputDataConfig"] == params["OutputDataConfig"] assert resp["ResourceConfig"] == params["ResourceConfig"] assert resp["StoppingCondition"] == params["StoppingCondition"] assert isinstance(resp["CreationTime"], datetime.datetime) assert isinstance(resp["TrainingStartTime"], datetime.datetime) assert isinstance(resp["TrainingEndTime"], datetime.datetime) assert isinstance(resp["LastModifiedTime"], datetime.datetime) assert "SecondaryStatusTransitions" in resp assert "Status" in resp["SecondaryStatusTransitions"][0] assert "StartTime" in resp["SecondaryStatusTransitions"][0] assert "EndTime" in resp["SecondaryStatusTransitions"][0] assert "StatusMessage" in resp["SecondaryStatusTransitions"][0] assert "FinalMetricDataList" in resp assert "MetricName" in resp["FinalMetricDataList"][0] assert "Value" in resp["FinalMetricDataList"][0] assert "Timestamp" in resp["FinalMetricDataList"][0] pass