import datetime
import re

import boto3
import pytest
from botocore.exceptions import ClientError

from moto import mock_sagemaker
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID

# Shared identifiers used by the tests in this module.
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
FAKE_PROCESSING_JOB_NAME = "MyProcessingJob"
FAKE_CONTAINER = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
# Region passed to boto3 when the mocked SageMaker client is created.
TEST_REGION_NAME = "us-east-1"


@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
    """Yield a boto3 SageMaker client created inside an active ``mock_sagemaker`` context.

    The mock context stays open for as long as the requesting test runs and is
    exited when the fixture is torn down.
    """
    with mock_sagemaker():
        yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)


class MyProcessingJobModel:
    """Test helper that assembles the arguments for ``create_processing_job``.

    Only ``processing_job_name`` and ``role_arn`` are required.  Every other
    argument falls back to a built-in default when it is omitted or falsy
    (the fallbacks use ``or``, so an empty dict/list also selects the default).
    """

    def __init__(
        self,
        processing_job_name,
        role_arn,
        container=None,
        bucket=None,
        prefix=None,
        app_specification=None,
        network_config=None,
        processing_inputs=None,
        processing_output_config=None,
        processing_resources=None,
        stopping_condition=None,
    ):
        """Store the job settings, filling in defaults for anything not supplied.

        ``container``, ``bucket`` and ``prefix`` are only used to build the
        default app specification and the default S3 input/output locations;
        explicitly supplied structures are stored as given.
        """
        self.processing_job_name = processing_job_name
        self.role_arn = role_arn
        self.container = (
            container
            or "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3"
        )
        self.bucket = bucket or "my-bucket"
        self.prefix = prefix or "sagemaker"

        # The default input and default output share the same S3 location.
        default_s3_uri = f"s3://{self.bucket}/{self.prefix}/processing/"

        self.app_specification = app_specification or {
            "ImageUri": self.container,
            "ContainerEntrypoint": ["python3"],
        }
        self.network_config = network_config or {
            "EnableInterContainerTrafficEncryption": False,
            "EnableNetworkIsolation": False,
        }
        self.processing_inputs = processing_inputs or [
            {
                "InputName": "input",
                "AppManaged": False,
                "S3Input": {
                    "S3Uri": default_s3_uri,
                    "LocalPath": "/opt/ml/processing/input",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            }
        ]
        self.processing_output_config = processing_output_config or {
            "Outputs": [
                {
                    "OutputName": "output",
                    "S3Output": {
                        "S3Uri": default_s3_uri,
                        "LocalPath": "/opt/ml/processing/output",
                        "S3UploadMode": "EndOfJob",
                    },
                    "AppManaged": False,
                }
            ]
        }
        self.processing_resources = processing_resources or {
            "ClusterConfig": {
                "InstanceCount": 1,
                "InstanceType": "ml.m5.large",
                "VolumeSizeInGB": 10,
            },
        }
        self.stopping_condition = stopping_condition or {
            "MaxRuntimeInSeconds": 3600,
        }

    def save(self, sagemaker_client):
        """Call ``sagemaker_client.create_processing_job`` with the stored settings.

        Returns whatever ``create_processing_job`` returns.
        """
        params = {
            "AppSpecification": self.app_specification,
            "NetworkConfig": self.network_config,
            "ProcessingInputs": self.processing_inputs,
            "ProcessingJobName": self.processing_job_name,
            "ProcessingOutputConfig": self.processing_output_config,
            "ProcessingResources": self.processing_resources,
            "RoleArn": self.role_arn,
            "StoppingCondition": self.stopping_condition,
        }
        return sagemaker_client.create_processing_job(**params)


def test_create_processing_job(sagemaker_client):
    """Create a processing job and check the create and describe responses."""
    job = MyProcessingJobModel(
        processing_job_name=FAKE_PROCESSING_JOB_NAME,
        role_arn=FAKE_ROLE_ARN,
        container=FAKE_CONTAINER,
        bucket="my-bucket",
        prefix="my-prefix",
        app_specification={
            "ImageUri": FAKE_CONTAINER,
            "ContainerEntrypoint": ["python3", "app.py"],
        },
        processing_resources={
            "ClusterConfig": {
                "InstanceCount": 2,
                "InstanceType": "ml.m5.xlarge",
                "VolumeSizeInGB": 20,
            },
        },
        stopping_condition={"MaxRuntimeInSeconds": 60 * 60},
    )
    arn_pattern = (
        rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$"
    )

    created = job.save(sagemaker_client)
    # Plain asserts: this module does not import `sure`, so `.should` is not
    # guaranteed to exist on the response values.
    assert re.match(arn_pattern, created["ProcessingJobArn"])

    described = sagemaker_client.describe_processing_job(
        ProcessingJobName=FAKE_PROCESSING_JOB_NAME
    )
    assert described["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME
    assert re.match(arn_pattern, described["ProcessingJobArn"])

    entrypoint = described["AppSpecification"]["ContainerEntrypoint"]
    assert "python3" in entrypoint
    assert "app.py" in entrypoint
    assert described["RoleArn"] == FAKE_ROLE_ARN
    assert described["ProcessingJobStatus"] == "Completed"
    assert isinstance(described["CreationTime"], datetime.datetime)
    assert isinstance(described["LastModifiedTime"], datetime.datetime)


def test_list_processing_jobs(sagemaker_client):
    """A single saved job is listed with its name and ARN, without a NextToken."""
    MyProcessingJobModel(
        processing_job_name=FAKE_PROCESSING_JOB_NAME, role_arn=FAKE_ROLE_ARN
    ).save(sagemaker_client)

    listing = sagemaker_client.list_processing_jobs()
    summaries = listing["ProcessingJobSummaries"]

    # Plain asserts: this module does not import `sure`.
    assert len(summaries) == 1
    assert summaries[0]["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$",
        summaries[0]["ProcessingJobArn"],
    )
    assert listing.get("NextToken") is None


def test_list_processing_jobs_multiple(sagemaker_client):
    """With two saved jobs, MaxResults=1 returns one summary; no limit returns both."""
    MyProcessingJobModel(
        processing_job_name="blah",
        role_arn="arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar",
    ).save(sagemaker_client)
    MyProcessingJobModel(
        processing_job_name="blah2",
        role_arn="arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2",
    ).save(sagemaker_client)

    limited = sagemaker_client.list_processing_jobs(MaxResults=1)
    assert len(limited["ProcessingJobSummaries"]) == 1

    everything = sagemaker_client.list_processing_jobs()
    assert len(everything["ProcessingJobSummaries"]) == 2
    assert everything.get("NextToken") is None


def test_list_processing_jobs_none(sagemaker_client):
    """Listing before any job has been created returns no summaries."""
    listing = sagemaker_client.list_processing_jobs()
    assert len(listing["ProcessingJobSummaries"]) == 0


def test_list_processing_jobs_should_validate_input(sagemaker_client):
    """An unknown StatusEquals value or a junk NextToken raises ValidationException."""
    junk_status_equals = "blah"
    with pytest.raises(ClientError) as ex:
        sagemaker_client.list_processing_jobs(StatusEquals=junk_status_equals)
    expected_error = (
        f"1 validation errors detected: Value '{junk_status_equals}' at "
        "'statusEquals' failed to satisfy constraint: Member must satisfy "
        "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
    )
    assert ex.value.response["Error"]["Code"] == "ValidationException"
    assert ex.value.response["Error"]["Message"] == expected_error

    junk_next_token = "asdf"
    with pytest.raises(ClientError) as ex:
        sagemaker_client.list_processing_jobs(NextToken=junk_next_token)
    assert ex.value.response["Error"]["Code"] == "ValidationException"
    # The expected message contains a literal "{0}" placeholder.
    assert (
        ex.value.response["Error"]["Message"]
        == 'Invalid pagination token because "{0}".'
    )


def test_list_processing_jobs_with_name_filters(sagemaker_client):
    """NameContains limits the listing to jobs whose name contains the substring.

    Five ``xgboost-<i>`` jobs are created first, then five ``vgg-<i>`` jobs.
    """
    for family, role_suffix in (("xgboost", "foobar"), ("vgg", "barfoo")):
        for i in range(5):
            MyProcessingJobModel(
                processing_job_name=f"{family}-{i}",
                role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/{role_suffix}-{i}",
            ).save(sagemaker_client)

    xgboost_processing_jobs = sagemaker_client.list_processing_jobs(
        NameContains="xgboost"
    )
    assert len(xgboost_processing_jobs["ProcessingJobSummaries"]) == 5

    # "2" occurs in exactly one name per family: xgboost-2 and vgg-2.
    processing_jobs_with_2 = sagemaker_client.list_processing_jobs(NameContains="2")
    assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2


def test_list_processing_jobs_paginated(sagemaker_client):
    """Paging with MaxResults=1 returns xgboost-0, then xgboost-1, each with a NextToken."""
    for i in range(5):
        MyProcessingJobModel(
            processing_job_name=f"xgboost-{i}",
            role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{i}",
        ).save(sagemaker_client)

    first_page = sagemaker_client.list_processing_jobs(
        NameContains="xgboost", MaxResults=1
    )
    assert len(first_page["ProcessingJobSummaries"]) == 1
    assert first_page["ProcessingJobSummaries"][0]["ProcessingJobName"] == "xgboost-0"
    assert first_page.get("NextToken") is not None

    second_page = sagemaker_client.list_processing_jobs(
        NameContains="xgboost",
        MaxResults=1,
        NextToken=first_page.get("NextToken"),
    )
    assert len(second_page["ProcessingJobSummaries"]) == 1
    assert second_page["ProcessingJobSummaries"][0]["ProcessingJobName"] == "xgboost-1"
    assert second_page.get("NextToken") is not None


def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
    """Check paging when the matching jobs come after five non-matching ones.

    Five ``xgboost-<i>`` jobs are created before five ``vgg-<i>`` jobs.  The
    assertions expect a ``vgg`` query to return 0 matches for MaxResults=1,
    1 match for MaxResults=6 and all 5 matches (no NextToken) for
    MaxResults=10.
    """
    for family, role_suffix in (("xgboost", "foobar"), ("vgg", "barfoo")):
        for i in range(5):
            MyProcessingJobModel(
                processing_job_name=f"{family}-{i}",
                role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/{role_suffix}-{i}",
            ).save(sagemaker_client)

    vgg_processing_job_1 = sagemaker_client.list_processing_jobs(
        NameContains="vgg", MaxResults=1
    )
    assert len(vgg_processing_job_1["ProcessingJobSummaries"]) == 0
    assert vgg_processing_job_1.get("NextToken") is not None

    vgg_processing_job_6 = sagemaker_client.list_processing_jobs(
        NameContains="vgg", MaxResults=6
    )
    summaries_6 = vgg_processing_job_6["ProcessingJobSummaries"]
    assert len(summaries_6) == 1
    assert summaries_6[0]["ProcessingJobName"] == "vgg-0"
    assert vgg_processing_job_6.get("NextToken") is not None

    vgg_processing_job_10 = sagemaker_client.list_processing_jobs(
        NameContains="vgg", MaxResults=10
    )
    summaries_10 = vgg_processing_job_10["ProcessingJobSummaries"]
    assert len(summaries_10) == 5
    assert summaries_10[-1]["ProcessingJobName"] == "vgg-4"
    assert vgg_processing_job_10.get("NextToken") is None


def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client):
    """Check paging when the matching jobs are spread across the listing.

    With five ``xgboost-<i>`` and then five ``vgg-<i>`` jobs, a
    ``NameContains="2"`` query with MaxResults=8 is expected to return both
    matches plus a NextToken.  The two follow-up pages (MaxResults=1) are
    expected to be empty, and only the last one has no NextToken.
    """
    for family, role_suffix in (("xgboost", "foobar"), ("vgg", "barfoo")):
        for i in range(5):
            MyProcessingJobModel(
                processing_job_name=f"{family}-{i}",
                role_arn=f"arn:aws:sagemaker:us-east-1:000000000000:x-x/{role_suffix}-{i}",
            ).save(sagemaker_client)

    processing_jobs_with_2 = sagemaker_client.list_processing_jobs(
        NameContains="2", MaxResults=8
    )
    assert len(processing_jobs_with_2["ProcessingJobSummaries"]) == 2
    assert processing_jobs_with_2.get("NextToken") is not None

    processing_jobs_with_2_next = sagemaker_client.list_processing_jobs(
        NameContains="2",
        MaxResults=1,
        NextToken=processing_jobs_with_2.get("NextToken"),
    )
    assert len(processing_jobs_with_2_next["ProcessingJobSummaries"]) == 0
    assert processing_jobs_with_2_next.get("NextToken") is not None

    processing_jobs_with_2_next_next = sagemaker_client.list_processing_jobs(
        NameContains="2",
        MaxResults=1,
        NextToken=processing_jobs_with_2_next.get("NextToken"),
    )
    assert len(processing_jobs_with_2_next_next["ProcessingJobSummaries"]) == 0
    assert processing_jobs_with_2_next_next.get("NextToken") is None


def test_add_and_delete_tags_in_training_job(sagemaker_client):
    """Add, list and delete tags on a processing job's ARN.

    NOTE(review): the name says "training_job" but the resource created here is
    a processing job; the name is kept so existing test selections still match.
    """
    # Reuse the module-level constants; their values equal the literals this
    # test previously duplicated.
    job = MyProcessingJobModel(
        processing_job_name=FAKE_PROCESSING_JOB_NAME,
        role_arn=FAKE_ROLE_ARN,
        container=FAKE_CONTAINER,
        bucket="my-bucket",
        prefix="my-prefix",
        app_specification={
            "ImageUri": FAKE_CONTAINER,
            "ContainerEntrypoint": ["python3", "app.py"],
        },
        processing_resources={
            "ClusterConfig": {
                "InstanceCount": 2,
                "InstanceType": "ml.m5.xlarge",
                "VolumeSizeInGB": 20,
            },
        },
        stopping_condition={"MaxRuntimeInSeconds": 60 * 60},
    )
    resource_arn = job.save(sagemaker_client)["ProcessingJobArn"]

    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = sagemaker_client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == tags

    tag_keys = [tag["Key"] for tag in tags]
    response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = sagemaker_client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == []