| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  | import datetime | 
					
						
							|  |  |  | import re | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | import boto3 | 
					
						
							| 
									
										
										
										
											2021-10-18 19:44:29 +00:00
										 |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2023-11-30 07:55:51 -08:00
										 |  |  | from botocore.exceptions import ClientError | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | from moto import mock_sagemaker | 
					
						
							| 
									
										
										
										
											2022-08-13 09:49:43 +00:00
										 |  |  | from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  | FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | TEST_REGION_NAME = "us-east-1" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  | class MyTrainingJobModel: | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         training_job_name, | 
					
						
							|  |  |  |         role_arn, | 
					
						
							|  |  |  |         container=None, | 
					
						
							|  |  |  |         bucket=None, | 
					
						
							|  |  |  |         prefix=None, | 
					
						
							|  |  |  |         algorithm_specification=None, | 
					
						
							|  |  |  |         resource_config=None, | 
					
						
							|  |  |  |         input_data_config=None, | 
					
						
							|  |  |  |         output_data_config=None, | 
					
						
							|  |  |  |         hyper_parameters=None, | 
					
						
							|  |  |  |         stopping_condition=None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         self.training_job_name = training_job_name | 
					
						
							|  |  |  |         self.role_arn = role_arn | 
					
						
							|  |  |  |         self.container = ( | 
					
						
							|  |  |  |             container or "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.bucket = bucket or "my-bucket" | 
					
						
							|  |  |  |         self.prefix = prefix or "sagemaker/DEMO-breast-cancer-prediction/" | 
					
						
							|  |  |  |         self.algorithm_specification = algorithm_specification or { | 
					
						
							|  |  |  |             "TrainingImage": self.container, | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |             "TrainingInputMode": "File", | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         self.resource_config = resource_config or { | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |             "InstanceCount": 1, | 
					
						
							|  |  |  |             "InstanceType": "ml.c4.2xlarge", | 
					
						
							|  |  |  |             "VolumeSizeInGB": 10, | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         self.input_data_config = input_data_config or [ | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |             { | 
					
						
							|  |  |  |                 "ChannelName": "train", | 
					
						
							|  |  |  |                 "DataSource": { | 
					
						
							|  |  |  |                     "S3DataSource": { | 
					
						
							|  |  |  |                         "S3DataType": "S3Prefix", | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |                         "S3Uri": f"s3://{self.bucket}/{self.prefix}/train/", | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |                         "S3DataDistributionType": "ShardedByS3Key", | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |                 "CompressionType": "None", | 
					
						
							|  |  |  |                 "RecordWrapperType": "None", | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "ChannelName": "validation", | 
					
						
							|  |  |  |                 "DataSource": { | 
					
						
							|  |  |  |                     "S3DataSource": { | 
					
						
							|  |  |  |                         "S3DataType": "S3Prefix", | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |                         "S3Uri": f"s3://{self.bucket}/{self.prefix}/validation/", | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |                         "S3DataDistributionType": "FullyReplicated", | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |                 "CompressionType": "None", | 
					
						
							|  |  |  |                 "RecordWrapperType": "None", | 
					
						
							|  |  |  |             }, | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         ] | 
					
						
							|  |  |  |         self.output_data_config = output_data_config or { | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |             "S3OutputPath": f"s3://{self.bucket}/{self.prefix}/" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         self.hyper_parameters = hyper_parameters or { | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |             "feature_dim": "30", | 
					
						
							|  |  |  |             "mini_batch_size": "100", | 
					
						
							|  |  |  |             "predictor_type": "regressor", | 
					
						
							|  |  |  |             "epochs": "10", | 
					
						
							|  |  |  |             "num_models": "32", | 
					
						
							|  |  |  |             "loss": "absolute_loss", | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.stopping_condition = stopping_condition or {"MaxRuntimeInSeconds": 60 * 60} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def save(self): | 
					
						
							|  |  |  |         sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         params = { | 
					
						
							|  |  |  |             "RoleArn": self.role_arn, | 
					
						
							|  |  |  |             "TrainingJobName": self.training_job_name, | 
					
						
							|  |  |  |             "AlgorithmSpecification": self.algorithm_specification, | 
					
						
							|  |  |  |             "ResourceConfig": self.resource_config, | 
					
						
							|  |  |  |             "InputDataConfig": self.input_data_config, | 
					
						
							|  |  |  |             "OutputDataConfig": self.output_data_config, | 
					
						
							|  |  |  |             "HyperParameters": self.hyper_parameters, | 
					
						
							|  |  |  |             "StoppingCondition": self.stopping_condition, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return sagemaker.create_training_job(**params) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_create_training_job(): | 
					
						
							|  |  |  |     sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     training_job_name = "MyTrainingJob" | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |     role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" | 
					
						
							|  |  |  |     bucket = "my-bucket" | 
					
						
							|  |  |  |     prefix = "sagemaker/DEMO-breast-cancer-prediction/" | 
					
						
							|  |  |  |     algorithm_specification = { | 
					
						
							|  |  |  |         "TrainingImage": container, | 
					
						
							|  |  |  |         "TrainingInputMode": "File", | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     resource_config = { | 
					
						
							|  |  |  |         "InstanceCount": 1, | 
					
						
							|  |  |  |         "InstanceType": "ml.c4.2xlarge", | 
					
						
							|  |  |  |         "VolumeSizeInGB": 10, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     input_data_config = [ | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "ChannelName": "train", | 
					
						
							|  |  |  |             "DataSource": { | 
					
						
							|  |  |  |                 "S3DataSource": { | 
					
						
							|  |  |  |                     "S3DataType": "S3Prefix", | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |                     "S3Uri": f"s3://{bucket}/{prefix}/train/", | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |                     "S3DataDistributionType": "ShardedByS3Key", | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |             "CompressionType": "None", | 
					
						
							|  |  |  |             "RecordWrapperType": "None", | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |         }, | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         { | 
					
						
							|  |  |  |             "ChannelName": "validation", | 
					
						
							|  |  |  |             "DataSource": { | 
					
						
							|  |  |  |                 "S3DataSource": { | 
					
						
							|  |  |  |                     "S3DataType": "S3Prefix", | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |                     "S3Uri": f"s3://{bucket}/{prefix}/validation/", | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |                     "S3DataDistributionType": "FullyReplicated", | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |             "CompressionType": "None", | 
					
						
							|  |  |  |             "RecordWrapperType": "None", | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     ] | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |     output_data_config = {"S3OutputPath": f"s3://{bucket}/{prefix}/"} | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     hyper_parameters = { | 
					
						
							|  |  |  |         "feature_dim": "30", | 
					
						
							|  |  |  |         "mini_batch_size": "100", | 
					
						
							|  |  |  |         "predictor_type": "regressor", | 
					
						
							|  |  |  |         "epochs": "10", | 
					
						
							|  |  |  |         "num_models": "32", | 
					
						
							|  |  |  |         "loss": "absolute_loss", | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     stopping_condition = {"MaxRuntimeInSeconds": 60 * 60} | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     job = MyTrainingJobModel( | 
					
						
							|  |  |  |         training_job_name, | 
					
						
							|  |  |  |         role_arn, | 
					
						
							|  |  |  |         container=container, | 
					
						
							|  |  |  |         bucket=bucket, | 
					
						
							|  |  |  |         prefix=prefix, | 
					
						
							|  |  |  |         algorithm_specification=algorithm_specification, | 
					
						
							|  |  |  |         resource_config=resource_config, | 
					
						
							|  |  |  |         input_data_config=input_data_config, | 
					
						
							|  |  |  |         output_data_config=output_data_config, | 
					
						
							|  |  |  |         hyper_parameters=hyper_parameters, | 
					
						
							|  |  |  |         stopping_condition=stopping_condition, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     resp = job.save() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$", | 
					
						
							|  |  |  |         resp["TrainingJobArn"], | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     resp = sagemaker.describe_training_job(TrainingJobName=training_job_name) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["TrainingJobName"] == training_job_name | 
					
						
							|  |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:training-job/{training_job_name}$", | 
					
						
							|  |  |  |         resp["TrainingJobArn"], | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     assert resp["ModelArtifacts"]["S3ModelArtifacts"].startswith( | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         output_data_config["S3OutputPath"] | 
					
						
							| 
									
										
										
										
											2020-10-06 08:46:05 +02:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     assert training_job_name in (resp["ModelArtifacts"]["S3ModelArtifacts"]) | 
					
						
							| 
									
										
										
										
											2020-10-06 08:46:05 +02:00
										 |  |  |     assert resp["ModelArtifacts"]["S3ModelArtifacts"].endswith("output/model.tar.gz") | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     assert resp["TrainingJobStatus"] == "Completed" | 
					
						
							|  |  |  |     assert resp["SecondaryStatus"] == "Completed" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     assert resp["HyperParameters"] == hyper_parameters | 
					
						
							| 
									
										
										
										
											2020-10-06 08:46:05 +02:00
										 |  |  |     assert ( | 
					
						
							|  |  |  |         resp["AlgorithmSpecification"]["TrainingImage"] | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         == algorithm_specification["TrainingImage"] | 
					
						
							| 
									
										
										
										
											2020-10-06 08:46:05 +02:00
										 |  |  |     ) | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         resp["AlgorithmSpecification"]["TrainingInputMode"] | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         == algorithm_specification["TrainingInputMode"] | 
					
						
							| 
									
										
										
										
											2020-10-06 08:46:05 +02:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     assert "MetricDefinitions" in resp["AlgorithmSpecification"] | 
					
						
							|  |  |  |     assert "Name" in resp["AlgorithmSpecification"]["MetricDefinitions"][0] | 
					
						
							|  |  |  |     assert "Regex" in resp["AlgorithmSpecification"]["MetricDefinitions"][0] | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     assert resp["RoleArn"] == role_arn | 
					
						
							|  |  |  |     assert resp["InputDataConfig"] == input_data_config | 
					
						
							|  |  |  |     assert resp["OutputDataConfig"] == output_data_config | 
					
						
							|  |  |  |     assert resp["ResourceConfig"] == resource_config | 
					
						
							|  |  |  |     assert resp["StoppingCondition"] == stopping_condition | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     assert isinstance(resp["CreationTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert isinstance(resp["TrainingStartTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert isinstance(resp["TrainingEndTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert isinstance(resp["LastModifiedTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert "SecondaryStatusTransitions" in resp | 
					
						
							|  |  |  |     assert "Status" in resp["SecondaryStatusTransitions"][0] | 
					
						
							|  |  |  |     assert "StartTime" in resp["SecondaryStatusTransitions"][0] | 
					
						
							|  |  |  |     assert "EndTime" in resp["SecondaryStatusTransitions"][0] | 
					
						
							|  |  |  |     assert "StatusMessage" in resp["SecondaryStatusTransitions"][0] | 
					
						
							|  |  |  |     assert "FinalMetricDataList" in resp | 
					
						
							|  |  |  |     assert "MetricName" in resp["FinalMetricDataList"][0] | 
					
						
							|  |  |  |     assert "Value" in resp["FinalMetricDataList"][0] | 
					
						
							|  |  |  |     assert "Timestamp" in resp["FinalMetricDataList"][0] | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_training_jobs(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     name = "blah" | 
					
						
							|  |  |  |     arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" | 
					
						
							|  |  |  |     test_training_job = MyTrainingJobModel(training_job_name=name, role_arn=arn) | 
					
						
							|  |  |  |     test_training_job.save() | 
					
						
							|  |  |  |     training_jobs = client.list_training_jobs() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(training_jobs["TrainingJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert training_jobs["TrainingJobSummaries"][0]["TrainingJobName"] == name | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:training-job/{name}$", | 
					
						
							|  |  |  |         training_jobs["TrainingJobSummaries"][0]["TrainingJobArn"], | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     ) | 
					
						
							|  |  |  |     assert training_jobs.get("NextToken") is None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_training_jobs_multiple(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     name_job_1 = "blah" | 
					
						
							|  |  |  |     arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" | 
					
						
							|  |  |  |     test_training_job_1 = MyTrainingJobModel( | 
					
						
							|  |  |  |         training_job_name=name_job_1, role_arn=arn_job_1 | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_training_job_1.save() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     name_job_2 = "blah2" | 
					
						
							|  |  |  |     arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2" | 
					
						
							|  |  |  |     test_training_job_2 = MyTrainingJobModel( | 
					
						
							|  |  |  |         training_job_name=name_job_2, role_arn=arn_job_2 | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_training_job_2.save() | 
					
						
							|  |  |  |     training_jobs_limit = client.list_training_jobs(MaxResults=1) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(training_jobs_limit["TrainingJobSummaries"]) == 1 | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     training_jobs = client.list_training_jobs() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(training_jobs["TrainingJobSummaries"]) == 2 | 
					
						
							|  |  |  |     assert training_jobs.get("NextToken") is None | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_training_jobs_none(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     training_jobs = client.list_training_jobs() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(training_jobs["TrainingJobSummaries"]) == 0 | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_training_jobs_should_validate_input(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     junk_status_equals = "blah" | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as ex: | 
					
						
							|  |  |  |         client.list_training_jobs(StatusEquals=junk_status_equals) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     expected_error = ( | 
					
						
							|  |  |  |         f"1 validation errors detected: Value '{junk_status_equals}' at " | 
					
						
							|  |  |  |         "'statusEquals' failed to satisfy constraint: Member must satisfy " | 
					
						
							|  |  |  |         "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', " | 
					
						
							|  |  |  |         "'Failed']" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     assert ex.value.response["Error"]["Code"] == "ValidationException" | 
					
						
							|  |  |  |     assert ex.value.response["Error"]["Message"] == expected_error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     junk_next_token = "asdf" | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as ex: | 
					
						
							|  |  |  |         client.list_training_jobs(NextToken=junk_next_token) | 
					
						
							|  |  |  |     assert ex.value.response["Error"]["Code"] == "ValidationException" | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         ex.value.response["Error"]["Message"] | 
					
						
							|  |  |  |         == 'Invalid pagination token because "{0}".' | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_training_jobs_with_name_filters(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         MyTrainingJobModel(training_job_name=name, role_arn=arn).save() | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"vgg-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         MyTrainingJobModel(training_job_name=name, role_arn=arn).save() | 
					
						
							|  |  |  |     xgboost_training_jobs = client.list_training_jobs(NameContains="xgboost") | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(xgboost_training_jobs["TrainingJobSummaries"]) == 5 | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     training_jobs_with_2 = client.list_training_jobs(NameContains="2") | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(training_jobs_with_2["TrainingJobSummaries"]) == 2 | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_training_jobs_paginated(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         MyTrainingJobModel(training_job_name=name, role_arn=arn).save() | 
					
						
							|  |  |  |     xgboost_training_job_1 = client.list_training_jobs( | 
					
						
							|  |  |  |         NameContains="xgboost", MaxResults=1 | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(xgboost_training_job_1["TrainingJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         xgboost_training_job_1["TrainingJobSummaries"][0]["TrainingJobName"] | 
					
						
							|  |  |  |         == "xgboost-0" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert xgboost_training_job_1.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     xgboost_training_job_next = client.list_training_jobs( | 
					
						
							|  |  |  |         NameContains="xgboost", | 
					
						
							|  |  |  |         MaxResults=1, | 
					
						
							|  |  |  |         NextToken=xgboost_training_job_1.get("NextToken"), | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(xgboost_training_job_next["TrainingJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         xgboost_training_job_next["TrainingJobSummaries"][0]["TrainingJobName"] | 
					
						
							|  |  |  |         == "xgboost-1" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert xgboost_training_job_next.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_training_jobs_paginated_with_target_in_middle(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         MyTrainingJobModel(training_job_name=name, role_arn=arn).save() | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"vgg-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         MyTrainingJobModel(training_job_name=name, role_arn=arn).save() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     vgg_training_job_1 = client.list_training_jobs(NameContains="vgg", MaxResults=1) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(vgg_training_job_1["TrainingJobSummaries"]) == 0 | 
					
						
							|  |  |  |     assert vgg_training_job_1.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     vgg_training_job_6 = client.list_training_jobs(NameContains="vgg", MaxResults=6) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(vgg_training_job_6["TrainingJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert vgg_training_job_6["TrainingJobSummaries"][0]["TrainingJobName"] == "vgg-0" | 
					
						
							|  |  |  |     assert vgg_training_job_6.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     vgg_training_job_10 = client.list_training_jobs(NameContains="vgg", MaxResults=10) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(vgg_training_job_10["TrainingJobSummaries"]) == 5 | 
					
						
							|  |  |  |     assert vgg_training_job_10["TrainingJobSummaries"][-1]["TrainingJobName"] == "vgg-4" | 
					
						
							|  |  |  |     assert vgg_training_job_10.get("NextToken") is None | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_training_jobs_paginated_with_fragmented_targets(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         MyTrainingJobModel(training_job_name=name, role_arn=arn).save() | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"vgg-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}" | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |         MyTrainingJobModel(training_job_name=name, role_arn=arn).save() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     training_jobs_with_2 = client.list_training_jobs(NameContains="2", MaxResults=8) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(training_jobs_with_2["TrainingJobSummaries"]) == 2 | 
					
						
							|  |  |  |     assert training_jobs_with_2.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     training_jobs_with_2_next = client.list_training_jobs( | 
					
						
							| 
									
										
										
										
											2022-03-10 13:39:59 -01:00
										 |  |  |         NameContains="2", MaxResults=1, NextToken=training_jobs_with_2.get("NextToken") | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(training_jobs_with_2_next["TrainingJobSummaries"]) == 0 | 
					
						
							|  |  |  |     assert training_jobs_with_2_next.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-09-02 20:45:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     training_jobs_with_2_next_next = client.list_training_jobs( | 
					
						
							|  |  |  |         NameContains="2", | 
					
						
							|  |  |  |         MaxResults=1, | 
					
						
							|  |  |  |         NextToken=training_jobs_with_2_next.get("NextToken"), | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]) == 0 | 
					
						
							|  |  |  |     assert training_jobs_with_2_next_next.get("NextToken") is None | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_add_tags_to_training_job(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							|  |  |  |     name = "blah" | 
					
						
							|  |  |  |     resource_arn = f"arn:aws:sagemaker:us-east-1:000000000000:training-job/{name}" | 
					
						
							|  |  |  |     test_training_job = MyTrainingJobModel( | 
					
						
							|  |  |  |         training_job_name=name, role_arn=resource_arn | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_training_job.save() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tags = [ | 
					
						
							|  |  |  |         {"Key": "myKey", "Value": "myValue"}, | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  |     response = client.add_tags(ResourceArn=resource_arn, Tags=tags) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = client.list_tags(ResourceArn=resource_arn) | 
					
						
							|  |  |  |     assert response["Tags"] == tags | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_delete_tags_from_training_job(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							|  |  |  |     name = "blah" | 
					
						
							|  |  |  |     resource_arn = f"arn:aws:sagemaker:us-east-1:000000000000:training-job/{name}" | 
					
						
							|  |  |  |     test_training_job = MyTrainingJobModel( | 
					
						
							|  |  |  |         training_job_name=name, role_arn=resource_arn | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_training_job.save() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tags = [ | 
					
						
							|  |  |  |         {"Key": "myKey", "Value": "myValue"}, | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  |     response = client.add_tags(ResourceArn=resource_arn, Tags=tags) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tag_keys = [tag["Key"] for tag in tags] | 
					
						
							|  |  |  |     response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = client.list_tags(ResourceArn=resource_arn) | 
					
						
							|  |  |  |     assert response["Tags"] == [] | 
					
						
							| 
									
										
										
										
											2022-10-19 21:53:02 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_describe_unknown_training_job(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as exc: | 
					
						
							|  |  |  |         client.describe_training_job(TrainingJobName="unknown") | 
					
						
							|  |  |  |     err = exc.value.response["Error"] | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert err["Code"] == "ValidationException" | 
					
						
							|  |  |  |     assert err["Message"] == ( | 
					
						
							|  |  |  |         "Could not find training job 'arn:aws:sagemaker:us-east-1:" | 
					
						
							|  |  |  |         f"{ACCOUNT_ID}:training-job/unknown'." | 
					
						
							| 
									
										
										
										
											2022-10-19 21:53:02 +00:00
										 |  |  |     ) |