import boto3
from botocore.exceptions import ClientError
import datetime
import re
import sure  # noqa # pylint: disable=unused-import
import pytest

from moto import mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID

FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
TEST_REGION_NAME = "us-east-1"


class MyTransformJobModel:
    """Test helper that builds and submits a SageMaker ``CreateTransformJob`` request.

    Every optional argument falls back to a minimal default when it is falsy
    (``None``, ``{}``, ``[]``, ``0``), so a test only needs to pass the fields it
    actually wants to check.
    """

    def __init__(
        self,
        transform_job_name,
        model_name,
        max_concurrent_transforms=None,
        model_client_config=None,
        max_payload_in_mb=None,
        batch_strategy=None,
        environment=None,
        transform_input=None,
        transform_output=None,
        data_capture_config=None,
        transform_resources=None,
        data_processing=None,
        tags=None,
        experiment_config=None,
    ):
        self.transform_job_name = transform_job_name
        self.model_name = model_name
        self.max_concurrent_transforms = max_concurrent_transforms or 1
        self.model_client_config = model_client_config or {}
        self.max_payload_in_mb = max_payload_in_mb or 1
        self.batch_strategy = batch_strategy or "SingleRecord"
        self.environment = environment or {}
        self.transform_input = transform_input or {
            "DataSource": {
                "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}
            },
            "ContentType": "application/json",
            "CompressionType": "None",
            "SplitType": "None",
        }
        self.transform_output = transform_output or {
            "S3OutputPath": "some-bucket",
            "Accept": "application/json",
            "AssembleWith": "None",
            "KmsKeyId": "None",
        }
        self.data_capture_config = data_capture_config or {
            "DestinationS3Uri": "data_capture",
            "KmsKeyId": "None",
            "GenerateInferenceId": False,
        }
        self.transform_resources = transform_resources or {
            "InstanceType": "ml.m5.2xlarge",
            "InstanceCount": 1,
        }
        self.data_processing = data_processing or {}
        self.tags = tags or []
        self.experiment_config = experiment_config or {}

    def save(self):
        """Create the transform job in ``TEST_REGION_NAME``.

        Returns:
            The raw ``create_transform_job`` response (it contains
            ``TransformJobArn``).
        """
        sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        params = {
            "TransformJobName": self.transform_job_name,
            "ModelName": self.model_name,
            "MaxConcurrentTransforms": self.max_concurrent_transforms,
            "ModelClientConfig": self.model_client_config,
            "MaxPayloadInMB": self.max_payload_in_mb,
            "BatchStrategy": self.batch_strategy,
            "Environment": self.environment,
            "TransformInput": self.transform_input,
            "TransformOutput": self.transform_output,
            "DataCaptureConfig": self.data_capture_config,
            "TransformResources": self.transform_resources,
            "DataProcessing": self.data_processing,
            "Tags": self.tags,
            "ExperimentConfig": self.experiment_config,
        }
        return sagemaker.create_transform_job(**params)


@mock_sagemaker
def test_create_transform_job():
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    transform_job_name = "MyTransformJob"
    model_name = "MyModelName"
    bucket = "my-bucket"
    transform_input = {
        "DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "input"}},
        "ContentType": "application/json",
        "CompressionType": "None",
        "SplitType": "None",
    }
    transform_output = {
        "S3OutputPath": bucket,
        "Accept": "application/json",
        "AssembleWith": "None",
        "KmsKeyId": "None",
    }
    model_client_config = {
        "InvocationsTimeoutInSeconds": 60,
        "InvocationsMaxRetries": 1,
    }
    max_payload_in_mb = 1
    data_capture_config = {
        "DestinationS3Uri": "data_capture",
        "KmsKeyId": "None",
        "GenerateInferenceId": False,
    }
    transform_resources = {
        "InstanceType": "ml.m5.2xlarge",
        "InstanceCount": 1,
    }
    data_processing = {
        "InputFilter": "$.features",
        "OutputFilter": "$['id','SageMakerOutput']",
        "JoinSource": "None",
    }
    experiment_config = {
        "ExperimentName": "MyExperiment",
        "TrialName": "MyTrial",
        "TrialComponentDisplayName": "MyTrialDisplay",
        "RunName": "MyRun",
    }
    job = MyTransformJobModel(
        transform_job_name=transform_job_name,
        model_name=model_name,
        # Pass transform_input explicitly so the TransformInput assertion below
        # checks what was sent, not the helper's default.
        transform_input=transform_input,
        transform_output=transform_output,
        model_client_config=model_client_config,
        max_payload_in_mb=max_payload_in_mb,
        data_capture_config=data_capture_config,
        transform_resources=transform_resources,
        data_processing=data_processing,
        experiment_config=experiment_config,
    )
    resp = job.save()
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:transform-job/{transform_job_name}$",
        resp["TransformJobArn"],
    )

    resp = sagemaker.describe_transform_job(TransformJobName=transform_job_name)
    assert resp["TransformJobName"] == transform_job_name
    assert resp["TransformJobStatus"] == "Completed"
    assert resp["ModelName"] == model_name
    assert resp["MaxConcurrentTransforms"] == 1
    assert resp["ModelClientConfig"] == model_client_config
    assert resp["MaxPayloadInMB"] == max_payload_in_mb
    assert resp["BatchStrategy"] == "SingleRecord"
    assert resp["TransformInput"] == transform_input
    assert resp["TransformOutput"] == transform_output
    assert resp["DataCaptureConfig"] == data_capture_config
    assert resp["TransformResources"] == transform_resources
    assert resp["DataProcessing"] == data_processing
    assert resp["ExperimentConfig"] == experiment_config
    assert isinstance(resp["CreationTime"], datetime.datetime)
    assert isinstance(resp["TransformStartTime"], datetime.datetime)
    assert isinstance(resp["TransformEndTime"], datetime.datetime)


@mock_sagemaker
def test_list_transform_jobs():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "blah_model"
    MyTransformJobModel(transform_job_name=name, model_name=model_name).save()

    transform_jobs = client.list_transform_jobs()
    summaries = transform_jobs["TransformJobSummaries"]
    assert len(summaries) == 1
    assert summaries[0]["TransformJobName"] == name
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:transform-job/{name}$",
        summaries[0]["TransformJobArn"],
    )
    assert transform_jobs.get("NextToken") is None


@mock_sagemaker
def test_list_transform_jobs_multiple():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name_job_1 = "blah"
    model_name1 = "blah_model"
    MyTransformJobModel(transform_job_name=name_job_1, model_name=model_name1).save()

    name_job_2 = "blah2"
    model_name2 = "blah_model2"
    MyTransformJobModel(transform_job_name=name_job_2, model_name=model_name2).save()

    transform_jobs_limit = client.list_transform_jobs(MaxResults=1)
    assert len(transform_jobs_limit["TransformJobSummaries"]) == 1

    transform_jobs = client.list_transform_jobs()
    assert len(transform_jobs["TransformJobSummaries"]) == 2
    assert transform_jobs.get("NextToken") is None


@mock_sagemaker
def test_list_transform_jobs_none():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    transform_jobs = client.list_transform_jobs()
    assert len(transform_jobs["TransformJobSummaries"]) == 0


@mock_sagemaker
def test_list_transform_jobs_should_validate_input():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    junk_status_equals = "blah"
    with pytest.raises(ClientError) as ex:
        client.list_transform_jobs(StatusEquals=junk_status_equals)
    expected_error = (
        f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' "
        "failed to satisfy constraint: Member must satisfy enum value set: "
        "['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
    )
    assert ex.value.response["Error"]["Code"] == "ValidationException"
    assert ex.value.response["Error"]["Message"] == expected_error

    junk_next_token = "asdf"
    with pytest.raises(ClientError) as ex:
        client.list_transform_jobs(NextToken=junk_next_token)
    assert ex.value.response["Error"]["Code"] == "ValidationException"
    # The expected message intentionally contains a literal, unformatted "{0}".
    assert (
        ex.value.response["Error"]["Message"]
        == 'Invalid pagination token because "{0}".'
    )


@mock_sagemaker
def test_list_transform_jobs_with_name_filters():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for i in range(5):
        MyTransformJobModel(
            transform_job_name=f"xgboost-{i}", model_name=f"blah_model-{i}"
        ).save()
    for i in range(5):
        MyTransformJobModel(
            transform_job_name=f"vgg-{i}", model_name=f"blah_model-{i}"
        ).save()

    xgboost_transform_jobs = client.list_transform_jobs(NameContains="xgboost")
    assert len(xgboost_transform_jobs["TransformJobSummaries"]) == 5

    # Matches exactly "xgboost-2" and "vgg-2".
    transform_jobs_with_2 = client.list_transform_jobs(NameContains="2")
    assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2


@mock_sagemaker
def test_list_transform_jobs_paginated():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for i in range(5):
        MyTransformJobModel(
            transform_job_name=f"xgboost-{i}", model_name=f"my-model-{i}"
        ).save()

    xgboost_transform_job_1 = client.list_transform_jobs(
        NameContains="xgboost", MaxResults=1
    )
    summaries_1 = xgboost_transform_job_1["TransformJobSummaries"]
    assert len(summaries_1) == 1
    assert summaries_1[0]["TransformJobName"] == "xgboost-0"
    assert xgboost_transform_job_1.get("NextToken") is not None

    xgboost_transform_job_next = client.list_transform_jobs(
        NameContains="xgboost",
        MaxResults=1,
        NextToken=xgboost_transform_job_1.get("NextToken"),
    )
    summaries_next = xgboost_transform_job_next["TransformJobSummaries"]
    assert len(summaries_next) == 1
    assert summaries_next[0]["TransformJobName"] == "xgboost-1"
    assert xgboost_transform_job_next.get("NextToken") is not None


@mock_sagemaker
def test_list_transform_jobs_paginated_with_target_in_middle():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for i in range(5):
        name = f"xgboost-{i}"
        model_name = f"my-model-{i}"
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
    for i in range(5):
        name = f"vgg-{i}"
        # model_name intentionally reuses the last value from the loop above.
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()

    # These assertions expect MaxResults to cap the jobs examined (in creation
    # order) before NameContains is applied: the first page covers only
    # xgboost-0, so it holds no "vgg" match yet still carries a NextToken.
    vgg_transform_job_1 = client.list_transform_jobs(NameContains="vgg", MaxResults=1)
    assert len(vgg_transform_job_1["TransformJobSummaries"]) == 0
    assert vgg_transform_job_1.get("NextToken") is not None

    vgg_transform_job_6 = client.list_transform_jobs(NameContains="vgg", MaxResults=6)
    assert len(vgg_transform_job_6["TransformJobSummaries"]) == 1
    assert vgg_transform_job_6["TransformJobSummaries"][0]["TransformJobName"] == "vgg-0"
    assert vgg_transform_job_6.get("NextToken") is not None

    vgg_transform_job_10 = client.list_transform_jobs(NameContains="vgg", MaxResults=10)
    assert len(vgg_transform_job_10["TransformJobSummaries"]) == 5
    assert vgg_transform_job_10["TransformJobSummaries"][-1]["TransformJobName"] == "vgg-4"
    assert vgg_transform_job_10.get("NextToken") is None


@mock_sagemaker
def test_list_transform_jobs_paginated_with_fragmented_targets():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    for i in range(5):
        name = f"xgboost-{i}"
        model_name = f"my-model-{i}"
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()
    for i in range(5):
        name = f"vgg-{i}"
        # model_name intentionally reuses the last value from the loop above.
        MyTransformJobModel(transform_job_name=name, model_name=model_name).save()

    # The first 8 jobs (xgboost-0..4, vgg-0..2) contain both "2" matches.
    transform_jobs_with_2 = client.list_transform_jobs(NameContains="2", MaxResults=8)
    assert len(transform_jobs_with_2["TransformJobSummaries"]) == 2
    assert transform_jobs_with_2.get("NextToken") is not None

    transform_jobs_with_2_next = client.list_transform_jobs(
        NameContains="2",
        MaxResults=1,
        NextToken=transform_jobs_with_2.get("NextToken"),
    )
    assert len(transform_jobs_with_2_next["TransformJobSummaries"]) == 0
    assert transform_jobs_with_2_next.get("NextToken") is not None

    transform_jobs_with_2_next_next = client.list_transform_jobs(
        NameContains="2",
        MaxResults=1,
        NextToken=transform_jobs_with_2_next.get("NextToken"),
    )
    assert len(transform_jobs_with_2_next_next["TransformJobSummaries"]) == 0
    assert transform_jobs_with_2_next_next.get("NextToken") is None


@mock_sagemaker
def test_add_tags_to_transform_job():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "my-model"
    test_transform_job = MyTransformJobModel(
        transform_job_name=name, model_name=model_name
    )
    # Use the ARN the service returned instead of hard-coding region/account.
    resource_arn = test_transform_job.save()["TransformJobArn"]
    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == tags


@mock_sagemaker
def test_delete_tags_from_transform_job():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    name = "blah"
    model_name = "my-model"
    test_transform_job = MyTransformJobModel(
        transform_job_name=name, model_name=model_name
    )
    # Use the ARN the service returned instead of hard-coding region/account.
    resource_arn = test_transform_job.save()["TransformJobArn"]
    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    tag_keys = [tag["Key"] for tag in tags]
    response = client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == []


@mock_sagemaker
def test_describe_unknown_transform_job():
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    with pytest.raises(ClientError) as exc:
        client.describe_transform_job(TransformJobName="unknown")
    err = exc.value.response["Error"]
    assert err["Code"] == "ValidationException"
    assert err["Message"] == (
        "Could not find transform job "
        f"'arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:transform-job/unknown'."
    )