| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  | import datetime | 
					
						
							|  |  |  | import re | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | import boto3 | 
					
						
							|  |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2023-11-30 07:55:51 -08:00
										 |  |  | from botocore.exceptions import ClientError | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | from moto import mock_sagemaker | 
					
						
							| 
									
										
										
										
											2022-08-13 09:49:43 +00:00
										 |  |  | from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" | 
					
						
							|  |  |  | FAKE_PROCESSING_JOB_NAME = "MyProcessingJob" | 
					
						
							|  |  |  | FAKE_CONTAINER = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | TEST_REGION_NAME = "us-east-1" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-04 16:28:30 +00:00
										 |  |  | @pytest.fixture(name="sagemaker_client") | 
					
						
							|  |  |  | def fixture_sagemaker_client(): | 
					
						
							|  |  |  |     with mock_sagemaker(): | 
					
						
							|  |  |  |         yield boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  | class MyProcessingJobModel: | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         processing_job_name, | 
					
						
							|  |  |  |         role_arn, | 
					
						
							|  |  |  |         container=None, | 
					
						
							|  |  |  |         bucket=None, | 
					
						
							|  |  |  |         prefix=None, | 
					
						
							|  |  |  |         app_specification=None, | 
					
						
							|  |  |  |         network_config=None, | 
					
						
							|  |  |  |         processing_inputs=None, | 
					
						
							|  |  |  |         processing_output_config=None, | 
					
						
							|  |  |  |         processing_resources=None, | 
					
						
							|  |  |  |         stopping_condition=None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         self.processing_job_name = processing_job_name | 
					
						
							|  |  |  |         self.role_arn = role_arn | 
					
						
							|  |  |  |         self.container = ( | 
					
						
							|  |  |  |             container | 
					
						
							|  |  |  |             or "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3" | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.bucket = bucket or "my-bucket" | 
					
						
							|  |  |  |         self.prefix = prefix or "sagemaker" | 
					
						
							|  |  |  |         self.app_specification = app_specification or { | 
					
						
							|  |  |  |             "ImageUri": self.container, | 
					
						
							|  |  |  |             "ContainerEntrypoint": ["python3"], | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         self.network_config = network_config or { | 
					
						
							|  |  |  |             "EnableInterContainerTrafficEncryption": False, | 
					
						
							|  |  |  |             "EnableNetworkIsolation": False, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         self.processing_inputs = processing_inputs or [ | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "InputName": "input", | 
					
						
							|  |  |  |                 "AppManaged": False, | 
					
						
							|  |  |  |                 "S3Input": { | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |                     "S3Uri": f"s3://{self.bucket}/{self.prefix}/processing/", | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |                     "LocalPath": "/opt/ml/processing/input", | 
					
						
							|  |  |  |                     "S3DataType": "S3Prefix", | 
					
						
							|  |  |  |                     "S3InputMode": "File", | 
					
						
							|  |  |  |                     "S3DataDistributionType": "FullyReplicated", | 
					
						
							|  |  |  |                     "S3CompressionType": "None", | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  |         self.processing_output_config = processing_output_config or { | 
					
						
							|  |  |  |             "Outputs": [ | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     "OutputName": "output", | 
					
						
							|  |  |  |                     "S3Output": { | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |                         "S3Uri": f"s3://{self.bucket}/{self.prefix}/processing/", | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |                         "LocalPath": "/opt/ml/processing/output", | 
					
						
							|  |  |  |                         "S3UploadMode": "EndOfJob", | 
					
						
							|  |  |  |                     }, | 
					
						
							|  |  |  |                     "AppManaged": False, | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             ] | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         self.processing_resources = processing_resources or { | 
					
						
							|  |  |  |             "ClusterConfig": { | 
					
						
							|  |  |  |                 "InstanceCount": 1, | 
					
						
							|  |  |  |                 "InstanceType": "ml.m5.large", | 
					
						
							|  |  |  |                 "VolumeSizeInGB": 10, | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         self.stopping_condition = stopping_condition or { | 
					
						
							|  |  |  |             "MaxRuntimeInSeconds": 3600, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     def save(self, sagemaker_client): | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |         params = { | 
					
						
							|  |  |  |             "AppSpecification": self.app_specification, | 
					
						
							|  |  |  |             "NetworkConfig": self.network_config, | 
					
						
							|  |  |  |             "ProcessingInputs": self.processing_inputs, | 
					
						
							|  |  |  |             "ProcessingJobName": self.processing_job_name, | 
					
						
							|  |  |  |             "ProcessingOutputConfig": self.processing_output_config, | 
					
						
							|  |  |  |             "ProcessingResources": self.processing_resources, | 
					
						
							|  |  |  |             "RoleArn": self.role_arn, | 
					
						
							|  |  |  |             "StoppingCondition": self.stopping_condition, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         return sagemaker_client.create_processing_job(**params) | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_create_processing_job(sagemaker_client): | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     bucket = "my-bucket" | 
					
						
							|  |  |  |     prefix = "my-prefix" | 
					
						
							|  |  |  |     app_specification = { | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         "ImageUri": FAKE_CONTAINER, | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |         "ContainerEntrypoint": ["python3", "app.py"], | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     processing_resources = { | 
					
						
							|  |  |  |         "ClusterConfig": { | 
					
						
							|  |  |  |             "InstanceCount": 2, | 
					
						
							|  |  |  |             "InstanceType": "ml.m5.xlarge", | 
					
						
							|  |  |  |             "VolumeSizeInGB": 20, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     stopping_condition = {"MaxRuntimeInSeconds": 60 * 60} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     job = MyProcessingJobModel( | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         processing_job_name=FAKE_PROCESSING_JOB_NAME, | 
					
						
							|  |  |  |         role_arn=FAKE_ROLE_ARN, | 
					
						
							|  |  |  |         container=FAKE_CONTAINER, | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |         bucket=bucket, | 
					
						
							|  |  |  |         prefix=prefix, | 
					
						
							|  |  |  |         app_specification=app_specification, | 
					
						
							|  |  |  |         processing_resources=processing_resources, | 
					
						
							|  |  |  |         stopping_condition=stopping_condition, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = job.save(sagemaker_client) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$", | 
					
						
							|  |  |  |         resp["ProcessingJobArn"], | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.describe_processing_job( | 
					
						
							|  |  |  |         ProcessingJobName=FAKE_PROCESSING_JOB_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME | 
					
						
							|  |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$", | 
					
						
							|  |  |  |         resp["ProcessingJobArn"], | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  |     assert "python3" in resp["AppSpecification"]["ContainerEntrypoint"] | 
					
						
							|  |  |  |     assert "app.py" in resp["AppSpecification"]["ContainerEntrypoint"] | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     assert resp["RoleArn"] == FAKE_ROLE_ARN | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     assert resp["ProcessingJobStatus"] == "Completed" | 
					
						
							|  |  |  |     assert isinstance(resp["CreationTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert isinstance(resp["LastModifiedTime"], datetime.datetime) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_list_processing_jobs(sagemaker_client): | 
					
						
							|  |  |  |     test_processing_job = MyProcessingJobModel( | 
					
						
							|  |  |  |         processing_job_name=FAKE_PROCESSING_JOB_NAME, role_arn=FAKE_ROLE_ARN | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     test_processing_job.save(sagemaker_client) | 
					
						
							|  |  |  |     processing_jobs = sagemaker_client.list_processing_jobs() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(processing_jobs["ProcessingJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         processing_jobs["ProcessingJobSummaries"][0]["ProcessingJobName"] | 
					
						
							|  |  |  |         == FAKE_PROCESSING_JOB_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$", | 
					
						
							|  |  |  |         processing_jobs["ProcessingJobSummaries"][0]["ProcessingJobArn"], | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     assert processing_jobs.get("NextToken") is None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_list_processing_jobs_multiple(sagemaker_client): | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     name_job_1 = "blah" | 
					
						
							|  |  |  |     arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar" | 
					
						
							|  |  |  |     test_processing_job_1 = MyProcessingJobModel( | 
					
						
							|  |  |  |         processing_job_name=name_job_1, role_arn=arn_job_1 | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     test_processing_job_1.save(sagemaker_client) | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  |     name_job_2 = "blah2" | 
					
						
							|  |  |  |     arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2" | 
					
						
							|  |  |  |     test_processing_job_2 = MyProcessingJobModel( | 
					
						
							|  |  |  |         processing_job_name=name_job_2, role_arn=arn_job_2 | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     test_processing_job_2.save(sagemaker_client) | 
					
						
							|  |  |  |     processing_jobs_limit = sagemaker_client.list_processing_jobs(MaxResults=1) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(processing_jobs_limit["ProcessingJobSummaries"]) == 1 | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     processing_jobs = sagemaker_client.list_processing_jobs() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(processing_jobs["ProcessingJobSummaries"]) == 2 | 
					
						
							|  |  |  |     assert processing_jobs.get("NextToken") is None | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_list_processing_jobs_none(sagemaker_client): | 
					
						
							|  |  |  |     processing_jobs = sagemaker_client.list_processing_jobs() | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(processing_jobs["ProcessingJobSummaries"]) == 0 | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_list_processing_jobs_should_validate_input(sagemaker_client): | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     junk_status_equals = "blah" | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as ex: | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         sagemaker_client.list_processing_jobs(StatusEquals=junk_status_equals) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     expected_error = ( | 
					
						
							|  |  |  |         f"1 validation errors detected: Value '{junk_status_equals}' at " | 
					
						
							|  |  |  |         "'statusEquals' failed to satisfy constraint: Member must satisfy " | 
					
						
							|  |  |  |         "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', " | 
					
						
							|  |  |  |         "'Failed']" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     assert ex.value.response["Error"]["Code"] == "ValidationException" | 
					
						
							|  |  |  |     assert ex.value.response["Error"]["Message"] == expected_error | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     junk_next_token = "asdf" | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as ex: | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         sagemaker_client.list_processing_jobs(NextToken=junk_next_token) | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     assert ex.value.response["Error"]["Code"] == "ValidationException" | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         ex.value.response["Error"]["Message"] | 
					
						
							|  |  |  |         == 'Invalid pagination token because "{0}".' | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_list_processing_jobs_with_name_filters(sagemaker_client): | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( | 
					
						
							|  |  |  |             sagemaker_client | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"vgg-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( | 
					
						
							|  |  |  |             sagemaker_client | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     xgboost_processing_jobs = sagemaker_client.list_processing_jobs( | 
					
						
							|  |  |  |         NameContains="xgboost" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(xgboost_processing_jobs["ProcessingJobSummaries"]) == 5 | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     processing_jobs_with_2 = sagemaker_client.list_processing_jobs(NameContains="2") | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2 | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_list_processing_jobs_paginated(sagemaker_client): | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( | 
					
						
							|  |  |  |             sagemaker_client | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     xgboost_processing_job_1 = sagemaker_client.list_processing_jobs( | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |         NameContains="xgboost", MaxResults=1 | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(xgboost_processing_job_1["ProcessingJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         xgboost_processing_job_1["ProcessingJobSummaries"][0]["ProcessingJobName"] | 
					
						
							|  |  |  |         == "xgboost-0" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert xgboost_processing_job_1.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     xgboost_processing_job_next = sagemaker_client.list_processing_jobs( | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |         NameContains="xgboost", | 
					
						
							|  |  |  |         MaxResults=1, | 
					
						
							|  |  |  |         NextToken=xgboost_processing_job_1.get("NextToken"), | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(xgboost_processing_job_next["ProcessingJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         xgboost_processing_job_next["ProcessingJobSummaries"][0]["ProcessingJobName"] | 
					
						
							|  |  |  |         == "xgboost-1" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert xgboost_processing_job_next.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client): | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( | 
					
						
							|  |  |  |             sagemaker_client | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"vgg-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( | 
					
						
							|  |  |  |             sagemaker_client | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     vgg_processing_job_1 = sagemaker_client.list_processing_jobs( | 
					
						
							|  |  |  |         NameContains="vgg", MaxResults=1 | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(vgg_processing_job_1["ProcessingJobSummaries"]) == 0 | 
					
						
							|  |  |  |     assert vgg_processing_job_1.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     vgg_processing_job_6 = sagemaker_client.list_processing_jobs( | 
					
						
							|  |  |  |         NameContains="vgg", MaxResults=6 | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(vgg_processing_job_6["ProcessingJobSummaries"]) == 1 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         vgg_processing_job_6["ProcessingJobSummaries"][0]["ProcessingJobName"] | 
					
						
							|  |  |  |         == "vgg-0" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert vgg_processing_job_6.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     vgg_processing_job_10 = sagemaker_client.list_processing_jobs( | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |         NameContains="vgg", MaxResults=10 | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(vgg_processing_job_10["ProcessingJobSummaries"]) == 5 | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         vgg_processing_job_10["ProcessingJobSummaries"][-1]["ProcessingJobName"] | 
					
						
							|  |  |  |         == "vgg-4" | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert vgg_processing_job_10.get("NextToken") is None | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client): | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"xgboost-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( | 
					
						
							|  |  |  |             sagemaker_client | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |     for i in range(5): | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |         name = f"vgg-{i}" | 
					
						
							|  |  |  |         arn = f"arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{i}" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         MyProcessingJobModel(processing_job_name=name, role_arn=arn).save( | 
					
						
							|  |  |  |             sagemaker_client | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     processing_jobs_with_2 = sagemaker_client.list_processing_jobs( | 
					
						
							|  |  |  |         NameContains="2", MaxResults=8 | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2 | 
					
						
							|  |  |  |     assert processing_jobs_with_2.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     processing_jobs_with_2_next = sagemaker_client.list_processing_jobs( | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |         NameContains="2", | 
					
						
							|  |  |  |         MaxResults=1, | 
					
						
							|  |  |  |         NextToken=processing_jobs_with_2.get("NextToken"), | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]) == 0 | 
					
						
							|  |  |  |     assert processing_jobs_with_2_next.get("NextToken") is not None | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     processing_jobs_with_2_next_next = sagemaker_client.list_processing_jobs( | 
					
						
							| 
									
										
										
										
											2021-11-06 13:47:42 +01:00
										 |  |  |         NameContains="2", | 
					
						
							|  |  |  |         MaxResults=1, | 
					
						
							|  |  |  |         NextToken=processing_jobs_with_2_next.get("NextToken"), | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]) == 0 | 
					
						
							|  |  |  |     assert processing_jobs_with_2_next_next.get("NextToken") is None | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_add_and_delete_tags_in_training_job(sagemaker_client): | 
					
						
							|  |  |  |     processing_job_name = "MyProcessingJob" | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  |     role_arn = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" | 
					
						
							|  |  |  |     bucket = "my-bucket" | 
					
						
							|  |  |  |     prefix = "my-prefix" | 
					
						
							|  |  |  |     app_specification = { | 
					
						
							|  |  |  |         "ImageUri": container, | 
					
						
							|  |  |  |         "ContainerEntrypoint": ["python3", "app.py"], | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     processing_resources = { | 
					
						
							|  |  |  |         "ClusterConfig": { | 
					
						
							|  |  |  |             "InstanceCount": 2, | 
					
						
							|  |  |  |             "InstanceType": "ml.m5.xlarge", | 
					
						
							|  |  |  |             "VolumeSizeInGB": 20, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     stopping_condition = {"MaxRuntimeInSeconds": 60 * 60} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     job = MyProcessingJobModel( | 
					
						
							|  |  |  |         processing_job_name, | 
					
						
							|  |  |  |         role_arn, | 
					
						
							|  |  |  |         container=container, | 
					
						
							|  |  |  |         bucket=bucket, | 
					
						
							|  |  |  |         prefix=prefix, | 
					
						
							|  |  |  |         app_specification=app_specification, | 
					
						
							|  |  |  |         processing_resources=processing_resources, | 
					
						
							|  |  |  |         stopping_condition=stopping_condition, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     resp = job.save(sagemaker_client) | 
					
						
							|  |  |  |     resource_arn = resp["ProcessingJobArn"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tags = [ | 
					
						
							|  |  |  |         {"Key": "myKey", "Value": "myValue"}, | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  |     response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = sagemaker_client.list_tags(ResourceArn=resource_arn) | 
					
						
							|  |  |  |     assert response["Tags"] == tags | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tag_keys = [tag["Key"] for tag in tags] | 
					
						
							|  |  |  |     response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = sagemaker_client.list_tags(ResourceArn=resource_arn) | 
					
						
							|  |  |  |     assert response["Tags"] == [] |