Co-authored-by: nivla1 <keongalvin@gmail.com>
This commit is contained in:
parent
cfc793014f
commit
5044df98bc
@ -1,5 +1,5 @@
|
||||
from __future__ import unicode_literals
|
||||
from moto.core.exceptions import RESTError, JsonRESTError
|
||||
from moto.core.exceptions import RESTError, JsonRESTError, AWSError
|
||||
|
||||
ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %}
|
||||
{% block extra %}<ModelName>{{ model }}</ModelName>{% endblock %}
|
||||
@ -32,3 +32,7 @@ class MissingModel(ModelError):
|
||||
class ValidationError(JsonRESTError):
|
||||
def __init__(self, message, **kwargs):
|
||||
super(ValidationError, self).__init__("ValidationException", message, **kwargs)
|
||||
|
||||
|
||||
class AWSValidationException(AWSError):
|
||||
TYPE = "ValidationException"
|
||||
|
@ -3,11 +3,10 @@ from __future__ import unicode_literals
|
||||
import os
|
||||
from boto3 import Session
|
||||
from datetime import datetime
|
||||
|
||||
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel
|
||||
from moto.core.exceptions import RESTError
|
||||
from moto.sagemaker import validators
|
||||
from .exceptions import MissingModel, ValidationError
|
||||
from .exceptions import MissingModel, ValidationError, AWSValidationException
|
||||
|
||||
|
||||
class BaseObject(BaseModel):
|
||||
@ -1243,6 +1242,90 @@ class SageMakerModelBackend(BaseBackend):
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def list_training_jobs(
|
||||
self,
|
||||
next_token,
|
||||
max_results,
|
||||
creation_time_after,
|
||||
creation_time_before,
|
||||
last_modified_time_after,
|
||||
last_modified_time_before,
|
||||
name_contains,
|
||||
status_equals,
|
||||
sort_by,
|
||||
sort_order,
|
||||
):
|
||||
if next_token:
|
||||
try:
|
||||
starting_index = int(next_token)
|
||||
if starting_index > len(self.training_jobs):
|
||||
raise ValueError # invalid next_token
|
||||
except ValueError:
|
||||
raise AWSValidationException('Invalid pagination token because "{0}".')
|
||||
else:
|
||||
starting_index = 0
|
||||
|
||||
if max_results:
|
||||
end_index = max_results + starting_index
|
||||
training_jobs_fetched = list(self.training_jobs.values())[
|
||||
starting_index:end_index
|
||||
]
|
||||
if end_index >= len(self.training_jobs):
|
||||
next_index = None
|
||||
else:
|
||||
next_index = end_index
|
||||
else:
|
||||
training_jobs_fetched = list(self.training_jobs.values())
|
||||
next_index = None
|
||||
|
||||
if name_contains is not None:
|
||||
training_jobs_fetched = filter(
|
||||
lambda x: name_contains in x.training_job_name, training_jobs_fetched
|
||||
)
|
||||
|
||||
if creation_time_after is not None:
|
||||
training_jobs_fetched = filter(
|
||||
lambda x: x.creation_time > creation_time_after, training_jobs_fetched
|
||||
)
|
||||
|
||||
if creation_time_before is not None:
|
||||
training_jobs_fetched = filter(
|
||||
lambda x: x.creation_time < creation_time_before, training_jobs_fetched
|
||||
)
|
||||
|
||||
if last_modified_time_after is not None:
|
||||
training_jobs_fetched = filter(
|
||||
lambda x: x.last_modified_time > last_modified_time_after,
|
||||
training_jobs_fetched,
|
||||
)
|
||||
|
||||
if last_modified_time_before is not None:
|
||||
training_jobs_fetched = filter(
|
||||
lambda x: x.last_modified_time < last_modified_time_before,
|
||||
training_jobs_fetched,
|
||||
)
|
||||
if status_equals is not None:
|
||||
training_jobs_fetched = filter(
|
||||
lambda x: x.training_job_status == status_equals, training_jobs_fetched
|
||||
)
|
||||
|
||||
training_job_summaries = [
|
||||
{
|
||||
"TrainingJobName": training_job_data.training_job_name,
|
||||
"TrainingJobArn": training_job_data.training_job_arn,
|
||||
"CreationTime": training_job_data.creation_time,
|
||||
"TrainingEndTime": training_job_data.training_end_time,
|
||||
"LastModifiedTime": training_job_data.last_modified_time,
|
||||
"TrainingJobStatus": training_job_data.training_job_status,
|
||||
}
|
||||
for training_job_data in training_jobs_fetched
|
||||
]
|
||||
|
||||
return {
|
||||
"TrainingJobSummaries": training_job_summaries,
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
|
||||
|
||||
sagemaker_backends = {}
|
||||
for region in Session().get_available_regions("sagemaker"):
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
from moto.sagemaker.exceptions import AWSValidationException
|
||||
|
||||
from moto.core.exceptions import AWSError
|
||||
from moto.core.responses import BaseResponse
|
||||
@ -8,6 +9,10 @@ from moto.core.utils import amzn_request_id
|
||||
from .models import sagemaker_backends
|
||||
|
||||
|
||||
def format_enum_error(value, attribute, allowed):
|
||||
return f"Value '{value}' at '{attribute}' failed to satisfy constraint: Member must satisfy enum value set: {allowed}"
|
||||
|
||||
|
||||
class SageMakerResponse(BaseResponse):
|
||||
@property
|
||||
def sagemaker_backend(self):
|
||||
@ -274,3 +279,65 @@ class SageMakerResponse(BaseResponse):
|
||||
)
|
||||
)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def list_training_jobs(self):
|
||||
max_results_range = range(1, 101)
|
||||
allowed_sort_by = ["Name", "CreationTime", "Status"]
|
||||
allowed_sort_order = ["Ascending", "Descending"]
|
||||
allowed_status_equals = [
|
||||
"Completed",
|
||||
"Stopped",
|
||||
"InProgress",
|
||||
"Stopping",
|
||||
"Failed",
|
||||
]
|
||||
|
||||
try:
|
||||
max_results = self._get_int_param("MaxResults")
|
||||
sort_by = self._get_param("SortBy", "CreationTime")
|
||||
sort_order = self._get_param("SortOrder", "Ascending")
|
||||
status_equals = self._get_param("StatusEquals")
|
||||
next_token = self._get_param("NextToken")
|
||||
errors = []
|
||||
if max_results and max_results not in max_results_range:
|
||||
errors.append(
|
||||
"Value '%s' at 'maxResults' failed to satisfy constraint: Member must have value less than or equal to %s".format(
|
||||
max_results, max_results_range[-1]
|
||||
)
|
||||
)
|
||||
|
||||
if sort_by not in allowed_sort_by:
|
||||
errors.append(format_enum_error(sort_by, "sortBy", allowed_sort_by))
|
||||
if sort_order not in allowed_sort_order:
|
||||
errors.append(
|
||||
format_enum_error(sort_order, "sortOrder", allowed_sort_order)
|
||||
)
|
||||
|
||||
if status_equals and status_equals not in allowed_status_equals:
|
||||
errors.append(
|
||||
format_enum_error(
|
||||
status_equals, "statusEquals", allowed_status_equals
|
||||
)
|
||||
)
|
||||
|
||||
if errors != []:
|
||||
raise AWSValidationException(
|
||||
f"{len(errors)} validation errors detected: {';'.join(errors)}"
|
||||
)
|
||||
|
||||
response = self.sagemaker_backend.list_training_jobs(
|
||||
next_token=next_token,
|
||||
max_results=max_results,
|
||||
creation_time_after=self._get_param("CreationTimeAfter"),
|
||||
creation_time_before=self._get_param("CreationTimeBefore"),
|
||||
last_modified_time_after=self._get_param("LastModifiedTimeAfter"),
|
||||
last_modified_time_before=self._get_param("LastModifiedTimeBefore"),
|
||||
name_contains=self._get_param("NameContains"),
|
||||
status_equals=status_equals,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
@ -9,4 +9,4 @@ click
|
||||
inflection==0.3.1
|
||||
lxml
|
||||
packaging
|
||||
|
||||
prompt_toolkit
|
||||
|
@ -15,10 +15,8 @@ class MySageMakerModel(object):
|
||||
def __init__(self, name, arn, container=None, vpc_config=None):
|
||||
self.name = name
|
||||
self.arn = arn
|
||||
self.container = container if container else {}
|
||||
self.vpc_config = (
|
||||
vpc_config if vpc_config else {"sg-groups": ["sg-123"], "subnets": ["123"]}
|
||||
)
|
||||
self.container = container or {}
|
||||
self.vpc_config = vpc_config or {"sg-groups": ["sg-123"], "subnets": ["123"]}
|
||||
|
||||
def save(self):
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
|
@ -1,8 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
from moto.core.exceptions import JsonRESTError
|
||||
|
||||
from re import M
|
||||
from moto.core import responses
|
||||
from os import O_DSYNC, scandir
|
||||
import pytest
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import datetime
|
||||
from botocore.configloader import raw_config_parse
|
||||
import sure # noqa
|
||||
|
||||
from moto import mock_sagemaker
|
||||
@ -12,28 +19,114 @@ FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
class MyTrainingJobModel(object):
|
||||
def __init__(
|
||||
self,
|
||||
training_job_name,
|
||||
role_arn,
|
||||
container=None,
|
||||
bucket=None,
|
||||
prefix=None,
|
||||
algorithm_specification=None,
|
||||
resource_config=None,
|
||||
input_data_config=None,
|
||||
output_data_config=None,
|
||||
hyper_parameters=None,
|
||||
stopping_condition=None,
|
||||
):
|
||||
self.training_job_name = training_job_name
|
||||
self.role_arn = role_arn
|
||||
self.container = (
|
||||
container or "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
|
||||
)
|
||||
self.bucket = bucket or "my-bucket"
|
||||
self.prefix = prefix or "sagemaker/DEMO-breast-cancer-prediction/"
|
||||
self.algorithm_specification = algorithm_specification or {
|
||||
"TrainingImage": self.container,
|
||||
"TrainingInputMode": "File",
|
||||
}
|
||||
self.resource_config = resource_config or {
|
||||
"InstanceCount": 1,
|
||||
"InstanceType": "ml.c4.2xlarge",
|
||||
"VolumeSizeInGB": 10,
|
||||
}
|
||||
self.input_data_config = input_data_config or [
|
||||
{
|
||||
"ChannelName": "train",
|
||||
"DataSource": {
|
||||
"S3DataSource": {
|
||||
"S3DataType": "S3Prefix",
|
||||
"S3Uri": "s3://{}/{}/train/".format(self.bucket, self.prefix),
|
||||
"S3DataDistributionType": "ShardedByS3Key",
|
||||
}
|
||||
},
|
||||
"CompressionType": "None",
|
||||
"RecordWrapperType": "None",
|
||||
},
|
||||
{
|
||||
"ChannelName": "validation",
|
||||
"DataSource": {
|
||||
"S3DataSource": {
|
||||
"S3DataType": "S3Prefix",
|
||||
"S3Uri": "s3://{}/{}/validation/".format(
|
||||
self.bucket, self.prefix
|
||||
),
|
||||
"S3DataDistributionType": "FullyReplicated",
|
||||
}
|
||||
},
|
||||
"CompressionType": "None",
|
||||
"RecordWrapperType": "None",
|
||||
},
|
||||
]
|
||||
self.output_data_config = output_data_config or {
|
||||
"S3OutputPath": "s3://{}/{}/".format(self.bucket, self.prefix)
|
||||
}
|
||||
self.hyper_parameters = hyper_parameters or {
|
||||
"feature_dim": "30",
|
||||
"mini_batch_size": "100",
|
||||
"predictor_type": "regressor",
|
||||
"epochs": "10",
|
||||
"num_models": "32",
|
||||
"loss": "absolute_loss",
|
||||
}
|
||||
|
||||
self.stopping_condition = stopping_condition or {"MaxRuntimeInSeconds": 60 * 60}
|
||||
|
||||
def save(self):
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
params = {
|
||||
"RoleArn": self.role_arn,
|
||||
"TrainingJobName": self.training_job_name,
|
||||
"AlgorithmSpecification": self.algorithm_specification,
|
||||
"ResourceConfig": self.resource_config,
|
||||
"InputDataConfig": self.input_data_config,
|
||||
"OutputDataConfig": self.output_data_config,
|
||||
"HyperParameters": self.hyper_parameters,
|
||||
"StoppingCondition": self.stopping_condition,
|
||||
}
|
||||
return sagemaker.create_training_job(**params)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_training_job():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
training_job_name = "MyTrainingJob"
|
||||
role_arn = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
|
||||
bucket = "my-bucket"
|
||||
prefix = "sagemaker/DEMO-breast-cancer-prediction/"
|
||||
|
||||
params = {
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"TrainingJobName": training_job_name,
|
||||
"AlgorithmSpecification": {
|
||||
algorithm_specification = {
|
||||
"TrainingImage": container,
|
||||
"TrainingInputMode": "File",
|
||||
},
|
||||
"ResourceConfig": {
|
||||
}
|
||||
resource_config = {
|
||||
"InstanceCount": 1,
|
||||
"InstanceType": "ml.c4.2xlarge",
|
||||
"VolumeSizeInGB": 10,
|
||||
},
|
||||
"InputDataConfig": [
|
||||
}
|
||||
input_data_config = [
|
||||
{
|
||||
"ChannelName": "train",
|
||||
"DataSource": {
|
||||
@ -58,20 +151,32 @@ def test_create_training_job():
|
||||
"CompressionType": "None",
|
||||
"RecordWrapperType": "None",
|
||||
},
|
||||
],
|
||||
"OutputDataConfig": {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)},
|
||||
"HyperParameters": {
|
||||
]
|
||||
output_data_config = {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)}
|
||||
hyper_parameters = {
|
||||
"feature_dim": "30",
|
||||
"mini_batch_size": "100",
|
||||
"predictor_type": "regressor",
|
||||
"epochs": "10",
|
||||
"num_models": "32",
|
||||
"loss": "absolute_loss",
|
||||
},
|
||||
"StoppingCondition": {"MaxRuntimeInSeconds": 60 * 60},
|
||||
}
|
||||
stopping_condition = {"MaxRuntimeInSeconds": 60 * 60}
|
||||
|
||||
resp = sagemaker.create_training_job(**params)
|
||||
job = MyTrainingJobModel(
|
||||
training_job_name,
|
||||
role_arn,
|
||||
container=container,
|
||||
bucket=bucket,
|
||||
prefix=prefix,
|
||||
algorithm_specification=algorithm_specification,
|
||||
resource_config=resource_config,
|
||||
input_data_config=input_data_config,
|
||||
output_data_config=output_data_config,
|
||||
hyper_parameters=hyper_parameters,
|
||||
stopping_condition=stopping_condition,
|
||||
)
|
||||
resp = job.save()
|
||||
resp["TrainingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name)
|
||||
)
|
||||
@ -82,29 +187,29 @@ def test_create_training_job():
|
||||
r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name)
|
||||
)
|
||||
assert resp["ModelArtifacts"]["S3ModelArtifacts"].startswith(
|
||||
params["OutputDataConfig"]["S3OutputPath"]
|
||||
output_data_config["S3OutputPath"]
|
||||
)
|
||||
assert training_job_name in (resp["ModelArtifacts"]["S3ModelArtifacts"])
|
||||
assert resp["ModelArtifacts"]["S3ModelArtifacts"].endswith("output/model.tar.gz")
|
||||
assert resp["TrainingJobStatus"] == "Completed"
|
||||
assert resp["SecondaryStatus"] == "Completed"
|
||||
assert resp["HyperParameters"] == params["HyperParameters"]
|
||||
assert resp["HyperParameters"] == hyper_parameters
|
||||
assert (
|
||||
resp["AlgorithmSpecification"]["TrainingImage"]
|
||||
== params["AlgorithmSpecification"]["TrainingImage"]
|
||||
== algorithm_specification["TrainingImage"]
|
||||
)
|
||||
assert (
|
||||
resp["AlgorithmSpecification"]["TrainingInputMode"]
|
||||
== params["AlgorithmSpecification"]["TrainingInputMode"]
|
||||
== algorithm_specification["TrainingInputMode"]
|
||||
)
|
||||
assert "MetricDefinitions" in resp["AlgorithmSpecification"]
|
||||
assert "Name" in resp["AlgorithmSpecification"]["MetricDefinitions"][0]
|
||||
assert "Regex" in resp["AlgorithmSpecification"]["MetricDefinitions"][0]
|
||||
assert resp["RoleArn"] == FAKE_ROLE_ARN
|
||||
assert resp["InputDataConfig"] == params["InputDataConfig"]
|
||||
assert resp["OutputDataConfig"] == params["OutputDataConfig"]
|
||||
assert resp["ResourceConfig"] == params["ResourceConfig"]
|
||||
assert resp["StoppingCondition"] == params["StoppingCondition"]
|
||||
assert resp["RoleArn"] == role_arn
|
||||
assert resp["InputDataConfig"] == input_data_config
|
||||
assert resp["OutputDataConfig"] == output_data_config
|
||||
assert resp["ResourceConfig"] == resource_config
|
||||
assert resp["StoppingCondition"] == stopping_condition
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["TrainingStartTime"], datetime.datetime)
|
||||
assert isinstance(resp["TrainingEndTime"], datetime.datetime)
|
||||
@ -120,3 +225,183 @@ def test_create_training_job():
|
||||
assert "Timestamp" in resp["FinalMetricDataList"][0]
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name = "blah"
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
|
||||
test_training_job = MyTrainingJobModel(training_job_name=name, role_arn=arn)
|
||||
test_training_job.save()
|
||||
training_jobs = client.list_training_jobs()
|
||||
assert len(training_jobs["TrainingJobSummaries"]).should.equal(1)
|
||||
assert training_jobs["TrainingJobSummaries"][0]["TrainingJobName"].should.equal(
|
||||
name
|
||||
)
|
||||
|
||||
assert training_jobs["TrainingJobSummaries"][0]["TrainingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(name)
|
||||
)
|
||||
assert training_jobs.get("NextToken") is None
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs_multiple():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
name_job_1 = "blah"
|
||||
arn_job_1 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"
|
||||
test_training_job_1 = MyTrainingJobModel(
|
||||
training_job_name=name_job_1, role_arn=arn_job_1
|
||||
)
|
||||
test_training_job_1.save()
|
||||
|
||||
name_job_2 = "blah2"
|
||||
arn_job_2 = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2"
|
||||
test_training_job_2 = MyTrainingJobModel(
|
||||
training_job_name=name_job_2, role_arn=arn_job_2
|
||||
)
|
||||
test_training_job_2.save()
|
||||
training_jobs_limit = client.list_training_jobs(MaxResults=1)
|
||||
assert len(training_jobs_limit["TrainingJobSummaries"]).should.equal(1)
|
||||
|
||||
training_jobs = client.list_training_jobs()
|
||||
assert len(training_jobs["TrainingJobSummaries"]).should.equal(2)
|
||||
assert training_jobs.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs_none():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
training_jobs = client.list_training_jobs()
|
||||
assert len(training_jobs["TrainingJobSummaries"]).should.equal(0)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs_should_validate_input():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
junk_status_equals = "blah"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.list_training_jobs(StatusEquals=junk_status_equals)
|
||||
expected_error = f"1 validation errors detected: Value '{junk_status_equals}' at 'statusEquals' failed to satisfy constraint: Member must satisfy enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', 'Failed']"
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert ex.value.response["Error"]["Message"] == expected_error
|
||||
|
||||
junk_next_token = "asdf"
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.list_training_jobs(NextToken=junk_next_token)
|
||||
assert ex.value.response["Error"]["Code"] == "ValidationException"
|
||||
assert (
|
||||
ex.value.response["Error"]["Message"]
|
||||
== 'Invalid pagination token because "{0}".'
|
||||
)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs_with_name_filters():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
xgboost_training_jobs = client.list_training_jobs(NameContains="xgboost")
|
||||
assert len(xgboost_training_jobs["TrainingJobSummaries"]).should.equal(5)
|
||||
|
||||
training_jobs_with_2 = client.list_training_jobs(NameContains="2")
|
||||
assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs_paginated():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
xgboost_training_job_1 = client.list_training_jobs(
|
||||
NameContains="xgboost", MaxResults=1
|
||||
)
|
||||
assert len(xgboost_training_job_1["TrainingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_training_job_1["TrainingJobSummaries"][0][
|
||||
"TrainingJobName"
|
||||
].should.equal("xgboost-0")
|
||||
assert xgboost_training_job_1.get("NextToken").should_not.be.none
|
||||
|
||||
xgboost_training_job_next = client.list_training_jobs(
|
||||
NameContains="xgboost",
|
||||
MaxResults=1,
|
||||
NextToken=xgboost_training_job_1.get("NextToken"),
|
||||
)
|
||||
assert len(xgboost_training_job_next["TrainingJobSummaries"]).should.equal(1)
|
||||
assert xgboost_training_job_next["TrainingJobSummaries"][0][
|
||||
"TrainingJobName"
|
||||
].should.equal("xgboost-1")
|
||||
assert xgboost_training_job_next.get("NextToken").should_not.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs_paginated_with_target_in_middle():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
|
||||
vgg_training_job_1 = client.list_training_jobs(NameContains="vgg", MaxResults=1)
|
||||
assert len(vgg_training_job_1["TrainingJobSummaries"]).should.equal(0)
|
||||
assert vgg_training_job_1.get("NextToken").should_not.be.none
|
||||
|
||||
vgg_training_job_6 = client.list_training_jobs(NameContains="vgg", MaxResults=6)
|
||||
|
||||
assert len(vgg_training_job_6["TrainingJobSummaries"]).should.equal(1)
|
||||
assert vgg_training_job_6["TrainingJobSummaries"][0][
|
||||
"TrainingJobName"
|
||||
].should.equal("vgg-0")
|
||||
assert vgg_training_job_6.get("NextToken").should_not.be.none
|
||||
|
||||
vgg_training_job_10 = client.list_training_jobs(NameContains="vgg", MaxResults=10)
|
||||
|
||||
assert len(vgg_training_job_10["TrainingJobSummaries"]).should.equal(5)
|
||||
assert vgg_training_job_10["TrainingJobSummaries"][-1][
|
||||
"TrainingJobName"
|
||||
].should.equal("vgg-4")
|
||||
assert vgg_training_job_10.get("NextToken").should.be.none
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_list_training_jobs_paginated_with_fragmented_targets():
|
||||
client = boto3.client("sagemaker", region_name="us-east-1")
|
||||
for i in range(5):
|
||||
name = "xgboost-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-{}".format(i)
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
for i in range(5):
|
||||
name = "vgg-{}".format(i)
|
||||
arn = "arn:aws:sagemaker:us-east-1:000000000000:x-x/barfoo-{}".format(i)
|
||||
MyTrainingJobModel(training_job_name=name, role_arn=arn).save()
|
||||
|
||||
training_jobs_with_2 = client.list_training_jobs(NameContains="2", MaxResults=8)
|
||||
assert len(training_jobs_with_2["TrainingJobSummaries"]).should.equal(2)
|
||||
assert training_jobs_with_2.get("NextToken").should_not.be.none
|
||||
|
||||
training_jobs_with_2_next = client.list_training_jobs(
|
||||
NameContains="2", MaxResults=1, NextToken=training_jobs_with_2.get("NextToken"),
|
||||
)
|
||||
assert len(training_jobs_with_2_next["TrainingJobSummaries"]).should.equal(0)
|
||||
assert training_jobs_with_2_next.get("NextToken").should_not.be.none
|
||||
|
||||
training_jobs_with_2_next_next = client.list_training_jobs(
|
||||
NameContains="2",
|
||||
MaxResults=1,
|
||||
NextToken=training_jobs_with_2_next.get("NextToken"),
|
||||
)
|
||||
assert len(training_jobs_with_2_next_next["TrainingJobSummaries"]).should.equal(0)
|
||||
assert training_jobs_with_2_next_next.get("NextToken").should.be.none
|
||||
|
Loading…
Reference in New Issue
Block a user