| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  | import datetime | 
					
						
							|  |  |  | import re | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | import boto3 | 
					
						
							|  |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2023-11-30 07:55:51 -08:00
										 |  |  | from botocore.exceptions import ClientError | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | from moto import mock_sagemaker | 
					
						
							|  |  |  | from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID | 
					
						
							|  |  |  | 
 | 
					
						
# IAM role ARN built from the account id imported from moto.core.
# NOTE(review): not referenced by any code visible in this section - confirm it is still needed.
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
# Region handed to boto3.client(...) by the helper class and several tests below.
TEST_REGION_NAME = "us-east-1"
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  | class MyTransformJobModel: | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         transform_job_name, | 
					
						
							|  |  |  |         model_name, | 
					
						
							|  |  |  |         max_concurrent_transforms=None, | 
					
						
							|  |  |  |         model_client_config=None, | 
					
						
							|  |  |  |         max_payload_in_mb=None, | 
					
						
							|  |  |  |         batch_strategy=None, | 
					
						
							|  |  |  |         environment=None, | 
					
						
							|  |  |  |         transform_input=None, | 
					
						
							|  |  |  |         transform_output=None, | 
					
						
							|  |  |  |         data_capture_config=None, | 
					
						
							|  |  |  |         transform_resources=None, | 
					
						
							|  |  |  |         data_processing=None, | 
					
						
							|  |  |  |         tags=None, | 
					
						
							|  |  |  |         experiment_config=None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         self.transform_job_name = transform_job_name | 
					
						
							|  |  |  |         self.model_name = model_name | 
					
						
							|  |  |  |         self.max_concurrent_transforms = max_concurrent_transforms or 1 | 
					
						
							|  |  |  |         self.model_client_config = model_client_config or {} | 
					
						
							|  |  |  |         self.max_payload_in_mb = max_payload_in_mb or 1 | 
					
						
							|  |  |  |         self.batch_strategy = batch_strategy or "SingleRecord" | 
					
						
							|  |  |  |         self.environment = environment or {} | 
					
						
							|  |  |  |         self.transform_input = transform_input or { | 
					
						
							|  |  |  |             "DataSource": { | 
					
						
							|  |  |  |                 "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"} | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |             "ContentType": "application/json", | 
					
						
							|  |  |  |             "CompressionType": "None", | 
					
						
							|  |  |  |             "SplitType": "None", | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         self.transform_output = transform_output or { | 
					
						
							|  |  |  |             "S3OutputPath": "some-bucket", | 
					
						
							|  |  |  |             "Accept": "application/json", | 
					
						
							|  |  |  |             "AssembleWith": "None", | 
					
						
							|  |  |  |             "KmsKeyId": "None", | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         self.data_capture_config = data_capture_config or { | 
					
						
							|  |  |  |             "DestinationS3Uri": "data_capture", | 
					
						
							|  |  |  |             "KmsKeyId": "None", | 
					
						
							|  |  |  |             "GenerateInferenceId": False, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         self.transform_resources = transform_resources or { | 
					
						
							|  |  |  |             "InstanceType": "ml.m5.2xlarge", | 
					
						
							|  |  |  |             "InstanceCount": 1, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         self.data_processing = data_processing or {} | 
					
						
							|  |  |  |         self.tags = tags or [] | 
					
						
							|  |  |  |         self.experiment_config = experiment_config or {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def save(self): | 
					
						
							|  |  |  |         sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         params = { | 
					
						
							|  |  |  |             "TransformJobName": self.transform_job_name, | 
					
						
							|  |  |  |             "ModelName": self.model_name, | 
					
						
							|  |  |  |             "MaxConcurrentTransforms": self.max_concurrent_transforms, | 
					
						
							|  |  |  |             "ModelClientConfig": self.model_client_config, | 
					
						
							|  |  |  |             "MaxPayloadInMB": self.max_payload_in_mb, | 
					
						
							|  |  |  |             "BatchStrategy": self.batch_strategy, | 
					
						
							|  |  |  |             "Environment": self.environment, | 
					
						
							|  |  |  |             "TransformInput": self.transform_input, | 
					
						
							|  |  |  |             "TransformOutput": self.transform_output, | 
					
						
							|  |  |  |             "DataCaptureConfig": self.data_capture_config, | 
					
						
							|  |  |  |             "TransformResources": self.transform_resources, | 
					
						
							|  |  |  |             "DataProcessing": self.data_processing, | 
					
						
							|  |  |  |             "Tags": self.tags, | 
					
						
							|  |  |  |             "ExperimentConfig": self.experiment_config, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return sagemaker.create_transform_job(**params) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_create_transform_job(): | 
					
						
							|  |  |  |     sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							|  |  |  |     transform_job_name = "MyTransformJob" | 
					
						
							|  |  |  |     model_name = "MyModelName" | 
					
						
							|  |  |  |     bucket = "my-bucket" | 
					
						
							|  |  |  |     transform_input = { | 
					
						
							|  |  |  |         "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}}, | 
					
						
							|  |  |  |         "ContentType": "application/json", | 
					
						
							|  |  |  |         "CompressionType": "None", | 
					
						
							|  |  |  |         "SplitType": "None", | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     transform_output = { | 
					
						
							|  |  |  |         "S3OutputPath": bucket, | 
					
						
							|  |  |  |         "Accept": "application/json", | 
					
						
							|  |  |  |         "AssembleWith": "None", | 
					
						
							|  |  |  |         "KmsKeyId": "None", | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     model_client_config = { | 
					
						
							|  |  |  |         "InvocationsTimeoutInSeconds": 60, | 
					
						
							|  |  |  |         "InvocationsMaxRetries": 1, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     max_payload_in_mb = 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     data_capture_config = { | 
					
						
							|  |  |  |         "DestinationS3Uri": "data_capture", | 
					
						
							|  |  |  |         "KmsKeyId": "None", | 
					
						
							|  |  |  |         "GenerateInferenceId": False, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     transform_resources = { | 
					
						
							|  |  |  |         "InstanceType": "ml.m5.2xlarge", | 
					
						
							|  |  |  |         "InstanceCount": 1, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     data_processing = { | 
					
						
							|  |  |  |         "InputFilter": "$.features", | 
					
						
							|  |  |  |         "OutputFilter": "$['id','SageMakerOutput']", | 
					
						
							|  |  |  |         "JoinSource": "None", | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     experiment_config = { | 
					
						
							|  |  |  |         "ExperimentName": "MyExperiment", | 
					
						
							|  |  |  |         "TrialName": "MyTrial", | 
					
						
							|  |  |  |         "TrialComponentDisplayName": "MyTrialDisplay", | 
					
						
							|  |  |  |         "RunName": "MyRun", | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     job = MyTransformJobModel( | 
					
						
							|  |  |  |         transform_job_name=transform_job_name, | 
					
						
							|  |  |  |         model_name=model_name, | 
					
						
							|  |  |  |         transform_output=transform_output, | 
					
						
							|  |  |  |         model_client_config=model_client_config, | 
					
						
							|  |  |  |         max_payload_in_mb=max_payload_in_mb, | 
					
						
							|  |  |  |         data_capture_config=data_capture_config, | 
					
						
							|  |  |  |         transform_resources=transform_resources, | 
					
						
							|  |  |  |         data_processing=data_processing, | 
					
						
							|  |  |  |         experiment_config=experiment_config, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     resp = job.save() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$", | 
					
						
							|  |  |  |         resp["TransformJobArn"], | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  |     ) | 
					
						
							|  |  |  |     resp = sagemaker.describe_transform_job(TransformJobName=transform_job_name) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["TransformJobName"] == transform_job_name | 
					
						
							|  |  |  |     assert resp["TransformJobStatus"] == "Completed" | 
					
						
							|  |  |  |     assert resp["ModelName"] == model_name | 
					
						
							|  |  |  |     assert resp["MaxConcurrentTransforms"] == 1 | 
					
						
							|  |  |  |     assert resp["ModelClientConfig"] == model_client_config | 
					
						
							|  |  |  |     assert resp["MaxPayloadInMB"] == max_payload_in_mb | 
					
						
							|  |  |  |     assert resp["BatchStrategy"] == "SingleRecord" | 
					
						
							|  |  |  |     assert resp["TransformInput"] == transform_input | 
					
						
							|  |  |  |     assert resp["TransformOutput"] == transform_output | 
					
						
							|  |  |  |     assert resp["DataCaptureConfig"] == data_capture_config | 
					
						
							|  |  |  |     assert resp["TransformResources"] == transform_resources | 
					
						
							|  |  |  |     assert resp["DataProcessing"] == data_processing | 
					
						
							|  |  |  |     assert resp["ExperimentConfig"] == experiment_config | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  |     assert isinstance(resp["CreationTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert isinstance(resp["TransformStartTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert isinstance(resp["TransformEndTime"], datetime.datetime) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_transform_jobs(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     name = "blah" | 
					
						
							|  |  |  |     model_name = "blah_model" | 
					
						
							|  |  |  |     test_transform_job = MyTransformJobModel( | 
					
						
							|  |  |  |         transform_job_name=name, model_name=model_name | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_transform_job.save() | 
					
						
							|  |  |  |     transform_jobs = client.list_transform_jobs() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(transform_jobs["TransformJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert transform_jobs["TransformJobSummaries"][0]["TransformJobName"] == name | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$", | 
					
						
							|  |  |  |         transform_jobs["TransformJobSummaries"][0]["TransformJobArn"], | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  |     ) | 
					
						
							|  |  |  |     assert transform_jobs.get("NextToken") is None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_transform_jobs_multiple(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     name_job_1 = "blah" | 
					
						
							|  |  |  |     model_name1 = "blah_model" | 
					
						
							|  |  |  |     test_transform_job_1 = MyTransformJobModel( | 
					
						
							|  |  |  |         transform_job_name=name_job_1, model_name=model_name1 | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_transform_job_1.save() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     name_job_2 = "blah2" | 
					
						
							|  |  |  |     model_name2 = "blah_model2" | 
					
						
							|  |  |  |     test_transform_job_2 = MyTransformJobModel( | 
					
						
							|  |  |  |         transform_job_name=name_job_2, model_name=model_name2 | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_transform_job_2.save() | 
					
						
							|  |  |  |     transform_jobs_limit = client.list_transform_jobs(MaxResults=1) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(transform_jobs_limit["TransformJobSummaries"]) == 1 | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     transform_jobs = client.list_transform_jobs() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(transform_jobs["TransformJobSummaries"]) == 2 | 
					
						
							|  |  |  |     assert transform_jobs.get("NextToken") is None | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_transform_jobs_none(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     transform_jobs = client.list_transform_jobs() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(transform_jobs["TransformJobSummaries"]) == 0 | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_transform_jobs_should_validate_input(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     junk_status_equals = "blah" | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as ex: | 
					
						
							|  |  |  |         client.list_transform_jobs(StatusEquals=junk_status_equals) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     expected_error = ( | 
					
						
							|  |  |  |         f"1 validation errors detected: Value '{junk_status_equals}' at " | 
					
						
							|  |  |  |         "'statusEquals' failed to satisfy constraint: Member must satisfy " | 
					
						
							|  |  |  |         "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', " | 
					
						
							|  |  |  |         "'Failed']" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  |     assert ex.value.response["Error"]["Code"] == "ValidationException" | 
					
						
							|  |  |  |     assert ex.value.response["Error"]["Message"] == expected_error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     junk_next_token = "asdf" | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as ex: | 
					
						
							|  |  |  |         client.list_transform_jobs(NextToken=junk_next_token) | 
					
						
							|  |  |  |     assert ex.value.response["Error"]["Code"] == "ValidationException" | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         ex.value.response["Error"]["Message"] | 
					
						
							|  |  |  |         == 'Invalid pagination token because "{0}".' | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_transform_jobs_with_name_filters(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							|  |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         model_name = f"blah_model-{i}" | 
					
						
							|  |  |  |         MyTransformJobModel(transform_job_name=name, model_name=model_name).save() | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							|  |  |  |         name = f"vgg-{i}" | 
					
						
							|  |  |  |         model_name = f"blah_model-{i}" | 
					
						
							|  |  |  |         MyTransformJobModel(transform_job_name=name, model_name=model_name).save() | 
					
						
							|  |  |  |     xgboost_transform_jobs = client.list_transform_jobs(NameContains="xgboost") | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(xgboost_transform_jobs["TransformJobSummaries"]) == 5 | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     transform_jobs_with_2 = client.list_transform_jobs(NameContains="2") | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2 | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_transform_jobs_paginated(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							|  |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         model_name = f"my-model-{i}" | 
					
						
							|  |  |  |         MyTransformJobModel(transform_job_name=name, model_name=model_name).save() | 
					
						
							|  |  |  |     xgboost_transform_job_1 = client.list_transform_jobs( | 
					
						
							|  |  |  |         NameContains="xgboost", MaxResults=1 | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(xgboost_transform_job_1["TransformJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         xgboost_transform_job_1["TransformJobSummaries"][0]["TransformJobName"] | 
					
						
							|  |  |  |         == "xgboost-0" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert xgboost_transform_job_1.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     xgboost_transform_job_next = client.list_transform_jobs( | 
					
						
							|  |  |  |         NameContains="xgboost", | 
					
						
							|  |  |  |         MaxResults=1, | 
					
						
							|  |  |  |         NextToken=xgboost_transform_job_1.get("NextToken"), | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(xgboost_transform_job_next["TransformJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         xgboost_transform_job_next["TransformJobSummaries"][0]["TransformJobName"] | 
					
						
							|  |  |  |         == "xgboost-1" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert xgboost_transform_job_next.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_transform_jobs_paginated_with_target_in_middle(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							|  |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         model_name = f"my-model-{i}" | 
					
						
							|  |  |  |         MyTransformJobModel(transform_job_name=name, model_name=model_name).save() | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							|  |  |  |         name = f"vgg-{i}" | 
					
						
							|  |  |  |         MyTransformJobModel(transform_job_name=name, model_name=model_name).save() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     vgg_transform_job_1 = client.list_transform_jobs(NameContains="vgg", MaxResults=1) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(vgg_transform_job_1["TransformJobSummaries"]) == 0 | 
					
						
							|  |  |  |     assert vgg_transform_job_1.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     vgg_transform_job_6 = client.list_transform_jobs(NameContains="vgg", MaxResults=6) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(vgg_transform_job_6["TransformJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         vgg_transform_job_6["TransformJobSummaries"][0]["TransformJobName"] == "vgg-0" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert vgg_transform_job_6.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     vgg_transform_job_10 = client.list_transform_jobs(NameContains="vgg", MaxResults=10) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(vgg_transform_job_10["TransformJobSummaries"]) == 5 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         vgg_transform_job_10["TransformJobSummaries"][-1]["TransformJobName"] == "vgg-4" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert vgg_transform_job_10.get("NextToken") is None | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_list_transform_jobs_paginated_with_fragmented_targets(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							|  |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         model_name = f"my-model-{i}" | 
					
						
							|  |  |  |         MyTransformJobModel(transform_job_name=name, model_name=model_name).save() | 
					
						
							|  |  |  |     for i in range(5): | 
					
						
							|  |  |  |         name = f"vgg-{i}" | 
					
						
							|  |  |  |         MyTransformJobModel(transform_job_name=name, model_name=model_name).save() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     transform_jobs_with_2 = client.list_transform_jobs(NameContains="2", MaxResults=8) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2 | 
					
						
							|  |  |  |     assert transform_jobs_with_2.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     transform_jobs_with_2_next = client.list_transform_jobs( | 
					
						
							|  |  |  |         NameContains="2", MaxResults=1, NextToken=transform_jobs_with_2.get("NextToken") | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(transform_jobs_with_2_next["TransformJobSummaries"]) == 0 | 
					
						
							|  |  |  |     assert transform_jobs_with_2_next.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     transform_jobs_with_2_next_next = client.list_transform_jobs( | 
					
						
							|  |  |  |         NameContains="2", | 
					
						
							|  |  |  |         MaxResults=1, | 
					
						
							|  |  |  |         NextToken=transform_jobs_with_2_next.get("NextToken"), | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]) == 0 | 
					
						
							|  |  |  |     assert transform_jobs_with_2_next_next.get("NextToken") is None | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_add_tags_to_transform_job(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							|  |  |  |     name = "blah" | 
					
						
							|  |  |  |     model_name = "my-model" | 
					
						
							|  |  |  |     resource_arn = "arn:aws:sagemaker:us-east-1:123456789012:transform-job/blah" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     test_transform_job = MyTransformJobModel( | 
					
						
							|  |  |  |         transform_job_name=name, model_name=model_name | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_transform_job.save() | 
					
						
							|  |  |  |     tags = [ | 
					
						
							|  |  |  |         {"Key": "myKey", "Value": "myValue"}, | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  |     response = client.add_tags(ResourceArn=resource_arn, Tags=tags) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = client.list_tags(ResourceArn=resource_arn) | 
					
						
							|  |  |  |     assert response["Tags"] == tags | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_delete_tags_from_transform_job(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							|  |  |  |     name = "blah" | 
					
						
							|  |  |  |     model_name = "my-model" | 
					
						
							|  |  |  |     resource_arn = "arn:aws:sagemaker:us-east-1:123456789012:transform-job/blah" | 
					
						
							|  |  |  |     test_transform_job = MyTransformJobModel( | 
					
						
							|  |  |  |         transform_job_name=name, model_name=model_name | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_transform_job.save() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tags = [ | 
					
						
							|  |  |  |         {"Key": "myKey", "Value": "myValue"}, | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  |     response = client.add_tags(ResourceArn=resource_arn, Tags=tags) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tag_keys = [tag["Key"] for tag in tags] | 
					
						
							|  |  |  |     response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = client.list_tags(ResourceArn=resource_arn) | 
					
						
							|  |  |  |     assert response["Tags"] == [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @mock_sagemaker | 
					
						
							|  |  |  | def test_describe_unknown_transform_job(): | 
					
						
							|  |  |  |     client = boto3.client("sagemaker", region_name="us-east-1") | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as exc: | 
					
						
							|  |  |  |         client.describe_transform_job(TransformJobName="unknown") | 
					
						
							|  |  |  |     err = exc.value.response["Error"] | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert err["Code"] == "ValidationException" | 
					
						
							|  |  |  |     assert err["Message"] == ( | 
					
						
							|  |  |  |         "Could not find transform job 'arn:aws:sagemaker:us-east-1:" | 
					
						
							|  |  |  |         f"{ACCOUNT_ID}:transform-job/unknown'." | 
					
						
							| 
									
										
										
										
											2023-05-10 13:54:49 -04:00
										 |  |  |     ) |