2020-07-19 14:06:48 +00:00
|
|
|
import datetime
|
|
|
|
import boto3
|
2020-10-06 05:54:49 +00:00
|
|
|
from botocore.exceptions import ClientError
|
2021-10-18 19:44:29 +00:00
|
|
|
import sure # noqa # pylint: disable=unused-import
|
2020-07-19 14:06:48 +00:00
|
|
|
|
|
|
|
from moto import mock_sagemaker
|
|
|
|
from moto.sts.models import ACCOUNT_ID
|
2020-10-06 05:54:49 +00:00
|
|
|
import pytest
|
2020-07-19 14:06:48 +00:00
|
|
|
|
|
|
|
# Region used for every boto3 client created in this module.
TEST_REGION_NAME = "us-east-1"

# Execution role ARN handed to create_model; built from moto's ACCOUNT_ID.
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)

# Tag list attached to a created endpoint and compared against list_tags output.
GENERIC_TAGS_PARAM = [
    {"Key": "newkey1", "Value": "newval1"},
    {"Key": "newkey2", "Value": "newval2"},
]
|
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_create_endpoint_config():
    """Create an endpoint config and read it back with describe.

    Creation is attempted once before the referenced model exists and must be
    rejected with a "Could not find model" error. After the model is created,
    the same request succeeds and ``describe_endpoint_config`` returns the
    expected ARN shape, name and production variants.
    """
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    config_name = "MyEndpointConfig"
    variants = [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": "ml.t2.medium",
        },
    ]
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(config_name)

    # The model has not been created yet, so the request must fail.
    with pytest.raises(ClientError) as exc_info:
        client.create_endpoint_config(
            EndpointConfigName=config_name,
            ProductionVariants=variants,
        )
    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")

    _create_model(client, model_name)
    created = client.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=variants,
    )
    created["EndpointConfigArn"].should.match(arn_pattern)

    described = client.describe_endpoint_config(EndpointConfigName=config_name)
    described["EndpointConfigArn"].should.match(arn_pattern)
    described["EndpointConfigName"].should.equal(config_name)
    described["ProductionVariants"].should.equal(variants)
|
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_delete_endpoint_config():
    """Delete an endpoint config and verify it can no longer be found.

    After deletion, both ``describe_endpoint_config`` and a second
    ``delete_endpoint_config`` must raise a ``ClientError`` whose message
    starts with "Could not find endpoint configuration".
    """
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(client, model_name)

    config_name = "MyEndpointConfig"
    variants = [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": "ml.t2.medium",
        },
    ]
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(config_name)
    not_found_prefix = "Could not find endpoint configuration"

    created = client.create_endpoint_config(
        EndpointConfigName=config_name,
        ProductionVariants=variants,
    )
    created["EndpointConfigArn"].should.match(arn_pattern)

    described = client.describe_endpoint_config(EndpointConfigName=config_name)
    described["EndpointConfigArn"].should.match(arn_pattern)

    client.delete_endpoint_config(EndpointConfigName=config_name)

    # Once deleted, the config is no longer describable ...
    with pytest.raises(ClientError) as describe_error:
        client.describe_endpoint_config(EndpointConfigName=config_name)
    describe_message = describe_error.value.response["Error"]["Message"]
    assert describe_message.startswith(not_found_prefix)

    # ... and deleting it a second time fails the same way.
    with pytest.raises(ClientError) as delete_error:
        client.delete_endpoint_config(EndpointConfigName=config_name)
    delete_message = delete_error.value.response["Error"]["Message"]
    assert delete_message.startswith(not_found_prefix)
|
2020-07-19 14:06:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_create_endpoint_invalid_instance_type():
    """Reject an endpoint config whose variant names an unknown instance type.

    The call must raise a ``ClientError`` with code ``ValidationException`` and
    a message that quotes the offending value and the enum constraint.
    """
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(client, model_name)

    instance_type = "InvalidInstanceType"
    config_name = "MyEndpointConfig"
    variants = [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": instance_type,
        },
    ]

    with pytest.raises(ClientError) as exc_info:
        client.create_endpoint_config(
            EndpointConfigName=config_name,
            ProductionVariants=variants,
        )

    error = exc_info.value.response["Error"]
    assert error["Code"] == "ValidationException"

    # Only the leading part of the message is checked; the enum listing that
    # follows the opening bracket is not pinned here.
    expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
        instance_type
    )
    assert expected_message in error["Message"]
|
2020-07-19 14:06:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_create_endpoint():
    """Create a tagged endpoint and verify describe and list_tags output.

    Creation against a non-existent endpoint config must fail first. With a
    model and config in place, the endpoint is created with tags; its
    description and tag listing are then checked field by field.
    """
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    endpoint_name = "MyEndpoint"
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)

    # No endpoint config exists yet, so creation must be rejected.
    with pytest.raises(ClientError) as exc_info:
        client.create_endpoint(
            EndpointName=endpoint_name,
            EndpointConfigName="NonexistentEndpointConfig",
        )
    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find endpoint configuration")

    model_name = "MyModel"
    _create_model(client, model_name)

    config_name = "MyEndpointConfig"
    _create_endpoint_config(client, config_name, model_name)

    created = client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=config_name,
        Tags=GENERIC_TAGS_PARAM,
    )
    created["EndpointArn"].should.match(arn_pattern)

    described = client.describe_endpoint(EndpointName=endpoint_name)
    described["EndpointArn"].should.match(arn_pattern)
    described["EndpointName"].should.equal(endpoint_name)
    described["EndpointConfigName"].should.equal(config_name)
    described["EndpointStatus"].should.equal("InService")
    assert isinstance(described["CreationTime"], datetime.datetime)
    assert isinstance(described["LastModifiedTime"], datetime.datetime)
    first_variant = described["ProductionVariants"][0]
    first_variant["VariantName"].should.equal("MyProductionVariant")

    # Tags passed at creation time must come back unchanged.
    tag_listing = client.list_tags(ResourceArn=described["EndpointArn"])
    assert tag_listing["Tags"] == GENERIC_TAGS_PARAM
|
2020-07-19 14:06:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
@mock_sagemaker
def test_delete_endpoint():
    """Delete an endpoint and verify later lookups and deletes fail.

    After deletion, both ``describe_endpoint`` and a repeated
    ``delete_endpoint`` must raise a ``ClientError`` whose message starts with
    "Could not find endpoint".
    """
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    # Build the model -> endpoint config -> endpoint chain to delete from.
    model_name = "MyModel"
    config_name = "MyEndpointConfig"
    endpoint_name = "MyEndpoint"
    _create_model(client, model_name)
    _create_endpoint_config(client, config_name, model_name)
    _create_endpoint(client, endpoint_name, config_name)

    client.delete_endpoint(EndpointName=endpoint_name)

    with pytest.raises(ClientError) as describe_error:
        client.describe_endpoint(EndpointName=endpoint_name)
    describe_message = describe_error.value.response["Error"]["Message"]
    assert describe_message.startswith("Could not find endpoint")

    with pytest.raises(ClientError) as delete_error:
        client.delete_endpoint(EndpointName=endpoint_name)
    delete_message = delete_error.value.response["Error"]["Message"]
    assert delete_message.startswith("Could not find endpoint")
|
2020-07-19 14:06:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _create_model(boto_client, model_name):
    """Create a model named ``model_name`` and assert the call returned HTTP 200.

    The model uses a fixed container image and model-data URL, with
    ``FAKE_ROLE_ARN`` as its execution role.
    """
    primary_container = {
        "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
        "ModelDataUrl": "s3://MyBucket/model.tar.gz",
    }
    response = boto_client.create_model(
        ModelName=model_name,
        PrimaryContainer=primary_container,
        ExecutionRoleArn=FAKE_ROLE_ARN,
    )
    status_code = response["ResponseMetadata"]["HTTPStatusCode"]
    assert status_code == 200
|
2020-07-19 14:06:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
    """Create an endpoint config with one variant that references ``model_name``.

    The returned ARN is checked against the expected
    ``endpoint-config/<name>`` pattern.
    """
    variant = {
        "VariantName": "MyProductionVariant",
        "ModelName": model_name,
        "InitialInstanceCount": 1,
        "InstanceType": "ml.t2.medium",
    }
    response = boto_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=[variant],
    )
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    response["EndpointConfigArn"].should.match(arn_pattern)
|
|
|
|
|
|
|
|
|
|
|
|
def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
    """Create an endpoint from ``endpoint_config_name`` and check its ARN.

    The returned ARN must match the expected ``endpoint/<name>`` pattern.
    """
    response = boto_client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
    )
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
    response["EndpointArn"].should.match(arn_pattern)
|