moto/tests/test_sagemaker/test_sagemaker_endpoint.py

292 lines
10 KiB
Python
Raw Normal View History

import datetime
import uuid
import boto3
from botocore.exceptions import ClientError
2021-10-18 19:44:29 +00:00
import sure # noqa # pylint: disable=unused-import
from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
import pytest
TEST_REGION_NAME = "us-east-1"
TEST_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
GENERIC_TAGS_PARAM = [
{"Key": "newkey1", "Value": "newval1"},
{"Key": "newkey2", "Value": "newval2"},
]
TEST_MODEL_NAME = "MyModel"
TEST_ENDPOINT_NAME = "MyEndpoint"
TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig"
TEST_PRODUCTION_VARIANTS = [
{
"VariantName": "MyProductionVariant",
"ModelName": TEST_MODEL_NAME,
"InitialInstanceCount": 1,
"InstanceType": "ml.t2.medium",
},
]
@pytest.fixture
def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@mock_sagemaker
def test_create_endpoint_config(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.create_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
ProductionVariants=TEST_PRODUCTION_VARIANTS,
)
2020-10-06 06:04:09 +00:00
assert e.value.response["Error"]["Message"].startswith("Could not find model")
_create_model(sagemaker_client, TEST_MODEL_NAME)
resp = sagemaker_client.create_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
ProductionVariants=TEST_PRODUCTION_VARIANTS,
)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
TEST_ENDPOINT_CONFIG_NAME
)
)
resp = sagemaker_client.describe_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
TEST_ENDPOINT_CONFIG_NAME
)
)
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
resp["ProductionVariants"].should.equal(TEST_PRODUCTION_VARIANTS)
@mock_sagemaker
def test_delete_endpoint_config(sagemaker_client):
_create_model(sagemaker_client, TEST_MODEL_NAME)
resp = sagemaker_client.create_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
ProductionVariants=TEST_PRODUCTION_VARIANTS,
)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
TEST_ENDPOINT_CONFIG_NAME
)
)
resp = sagemaker_client.describe_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
TEST_ENDPOINT_CONFIG_NAME
)
)
sagemaker_client.delete_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
)
with pytest.raises(ClientError) as e:
sagemaker_client.describe_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
)
2020-10-06 06:46:05 +00:00
assert e.value.response["Error"]["Message"].startswith(
"Could not find endpoint configuration"
)
with pytest.raises(ClientError) as e:
sagemaker_client.delete_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME
)
2020-10-06 06:46:05 +00:00
assert e.value.response["Error"]["Message"].startswith(
"Could not find endpoint configuration"
)
@mock_sagemaker
def test_create_endpoint_invalid_instance_type(sagemaker_client):
_create_model(sagemaker_client, TEST_MODEL_NAME)
instance_type = "InvalidInstanceType"
production_variants = TEST_PRODUCTION_VARIANTS
production_variants[0]["InstanceType"] = instance_type
with pytest.raises(ClientError) as e:
sagemaker_client.create_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
ProductionVariants=production_variants,
)
2020-10-06 06:04:09 +00:00
assert e.value.response["Error"]["Code"] == "ValidationException"
expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
instance_type
)
2020-10-06 06:04:09 +00:00
assert expected_message in e.value.response["Error"]["Message"]
@mock_sagemaker
def test_create_endpoint(sagemaker_client):
with pytest.raises(ClientError) as e:
sagemaker_client.create_endpoint(
EndpointName=TEST_ENDPOINT_NAME,
EndpointConfigName="NonexistentEndpointConfig",
)
2020-10-06 06:46:05 +00:00
assert e.value.response["Error"]["Message"].startswith(
"Could not find endpoint configuration"
)
_create_model(sagemaker_client, TEST_MODEL_NAME)
_create_endpoint_config(
sagemaker_client, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
)
resp = sagemaker_client.create_endpoint(
EndpointName=TEST_ENDPOINT_NAME,
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
Tags=GENERIC_TAGS_PARAM,
)
resp["EndpointArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)
)
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
resp["EndpointArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)
)
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
resp["EndpointStatus"].should.equal("InService")
assert isinstance(resp["CreationTime"], datetime.datetime)
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant")
resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"])
assert resp["Tags"] == GENERIC_TAGS_PARAM
@mock_sagemaker
def test_delete_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
)
sagemaker_client.delete_endpoint(EndpointName=TEST_ENDPOINT_NAME)
with pytest.raises(ClientError) as e:
sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
2020-10-06 06:04:09 +00:00
assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")
with pytest.raises(ClientError) as e:
sagemaker_client.delete_endpoint(EndpointName=TEST_ENDPOINT_NAME)
2020-10-06 06:04:09 +00:00
assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")
@mock_sagemaker
def test_add_tags_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
)
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
response = sagemaker_client.add_tags(
ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM
)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
assert response["Tags"] == GENERIC_TAGS_PARAM
@mock_sagemaker
def test_delete_tags_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
)
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
response = sagemaker_client.add_tags(
ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM
)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
tag_keys = [tag["Key"] for tag in GENERIC_TAGS_PARAM]
response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
assert response["Tags"] == []
@mock_sagemaker
def test_list_tags_endpoint(sagemaker_client):
_set_up_sagemaker_resources(
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
)
tags = []
for _ in range(80):
tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"})
resource_arn = f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}:endpoint/{TEST_ENDPOINT_NAME}"
response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
assert response["ResponseMetadata"]["HTTPStatusCode"] == 200
response = sagemaker_client.list_tags(ResourceArn=resource_arn)
assert len(response["Tags"]) == 50
assert response["Tags"] == tags[:50]
response = sagemaker_client.list_tags(
ResourceArn=resource_arn, NextToken=response["NextToken"]
)
assert len(response["Tags"]) == 30
assert response["Tags"] == tags[50:]
def _set_up_sagemaker_resources(
boto_client, endpoint_name, endpoint_config_name, model_name
):
_create_model(boto_client, model_name)
_create_endpoint_config(boto_client, endpoint_config_name, model_name)
_create_endpoint(boto_client, endpoint_name, endpoint_config_name)
def _create_model(boto_client, model_name):
resp = boto_client.create_model(
ModelName=model_name,
PrimaryContainer={
"Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
"ModelDataUrl": "s3://MyBucket/model.tar.gz",
},
ExecutionRoleArn=TEST_ROLE_ARN,
)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
production_variants = [
{
"VariantName": "MyProductionVariant",
"ModelName": model_name,
"InitialInstanceCount": 1,
"InstanceType": "ml.t2.medium",
},
]
resp = boto_client.create_endpoint_config(
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
)
def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
resp = boto_client.create_endpoint(
EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
resp["EndpointArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
)