Sagemaker: Add support for Serverless Endpoint Configurations (#5445)

This commit is contained in:
Arnaud Stiegler 2022-09-02 14:42:50 -04:00 committed by GitHub
parent 47a358fc35
commit 6f3c391812
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 12 deletions

View File

@ -398,7 +398,25 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel):
def validate_production_variants(self, production_variants):
    """Validate every production variant of an endpoint configuration.

    A variant is validated as instance-based when it carries an
    ``InstanceType`` key, otherwise as serverless when it carries a
    ``ServerlessConfig`` key.

    :param production_variants: iterable of production-variant dicts.
    :raises ValidationError: if a variant contains neither ``InstanceType``
        nor ``ServerlessConfig`` (or if the delegated validator rejects it).
    """
    for production_variant in production_variants:
        # Only read "InstanceType" after checking it is present: serverless
        # variants do not have that key and would otherwise hit a KeyError.
        if "InstanceType" in production_variant:
            self.validate_instance_type(production_variant["InstanceType"])
        elif "ServerlessConfig" in production_variant:
            self.validate_serverless_config(production_variant["ServerlessConfig"])
        else:
            message = "Invalid Keys for ProductionVariant: received {} but expected it to contain one of {}".format(
                production_variant.keys(), ["InstanceType", "ServerlessConfig"]
            )
            raise ValidationError(message=message)
def validate_serverless_config(self, serverless_config):
    """Check the ``MemorySizeInMB`` value of a serverless configuration.

    :param serverless_config: dict holding at least a ``MemorySizeInMB`` key.
    :raises ValidationError: if the memory size is not one of the values
        accepted below.
    """
    allowed_memory_sizes = [1024, 2048, 3072, 4096, 5120, 6144]
    requested_memory = serverless_config["MemorySizeInMB"]

    # Nothing more to check once the requested size is an accepted value.
    if validators.is_one_of(requested_memory, allowed_memory_sizes):
        return

    raise ValidationError(
        message="Value '{}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
            requested_memory, allowed_memory_sizes
        )
    )
def validate_instance_type(self, instance_type):
VALID_INSTANCE_TYPES = [

View File

@ -20,6 +20,8 @@ TEST_ENDPOINT_NAME = "MyEndpoint"
TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig"
TEST_VARIANT_NAME = "MyProductionVariant"
TEST_INSTANCE_TYPE = "ml.t2.medium"
# Values plugged into the serverless variant's "ServerlessConfig"
# ("MemorySizeInMB" and "MaxConcurrency" respectively).
TEST_MEMORY_SIZE = 1024
TEST_CONCURRENCY = 10
TEST_PRODUCTION_VARIANTS = [
{
"VariantName": TEST_VARIANT_NAME,
@ -28,6 +30,16 @@ TEST_PRODUCTION_VARIANTS = [
"InstanceType": TEST_INSTANCE_TYPE,
},
]
# Production variants describing a serverless endpoint: the variant carries a
# "ServerlessConfig" entry and no "InstanceType".
# NOTE: this list is shared module state; tests should copy it before mutating.
TEST_SERVERLESS_PRODUCTION_VARIANTS = [
{
"VariantName": TEST_VARIANT_NAME,
"ModelName": TEST_MODEL_NAME,
"ServerlessConfig": {
"MemorySizeInMB": TEST_MEMORY_SIZE,
"MaxConcurrency": TEST_CONCURRENCY,
},
},
]
@pytest.fixture
@ -35,19 +47,12 @@ def sagemaker_client():
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
@mock_sagemaker
def test_create_endpoint_config(sagemaker_client):
    """Creating an endpoint config is rejected when no model was created first.

    The call is expected to raise a ``ClientError`` whose message starts with
    "Could not find model".
    """
    with pytest.raises(ClientError) as exc_info:
        sagemaker_client.create_endpoint_config(
            EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
            ProductionVariants=TEST_PRODUCTION_VARIANTS,
        )

    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")
def create_endpoint_config_helper(sagemaker_client, production_variants):
_create_model(sagemaker_client, TEST_MODEL_NAME)
resp = sagemaker_client.create_endpoint_config(
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
ProductionVariants=TEST_PRODUCTION_VARIANTS,
ProductionVariants=production_variants,
)
resp["EndpointConfigArn"].should.match(
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
@ -64,7 +69,33 @@ def test_create_endpoint_config(sagemaker_client):
)
)
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
resp["ProductionVariants"].should.equal(TEST_PRODUCTION_VARIANTS)
resp["ProductionVariants"].should.equal(production_variants)
@mock_sagemaker
def test_create_endpoint_config(sagemaker_client):
    """Exercise endpoint-config creation with instance-based variants.

    First checks that the request is rejected while no model has been
    created, then delegates the successful path to
    ``create_endpoint_config_helper``.
    """
    # No model has been created in this test yet, so the call must fail.
    with pytest.raises(ClientError) as exc_info:
        sagemaker_client.create_endpoint_config(
            EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
            ProductionVariants=TEST_PRODUCTION_VARIANTS,
        )
    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")

    # Testing instance-based endpoint configuration
    create_endpoint_config_helper(sagemaker_client, TEST_PRODUCTION_VARIANTS)
@mock_sagemaker
def test_create_endpoint_config_serverless(sagemaker_client):
    """Exercise endpoint-config creation with serverless variants.

    First checks that the request is rejected while no model has been
    created, then delegates the successful path to
    ``create_endpoint_config_helper``.
    """
    # No model has been created in this test yet, so the call must fail.
    with pytest.raises(ClientError) as exc_info:
        sagemaker_client.create_endpoint_config(
            EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
            ProductionVariants=TEST_SERVERLESS_PRODUCTION_VARIANTS,
        )
    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")

    # Testing serverless endpoint configuration
    create_endpoint_config_helper(sagemaker_client, TEST_SERVERLESS_PRODUCTION_VARIANTS)
@mock_sagemaker
@ -129,6 +160,26 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client):
assert expected_message in e.value.response["Error"]["Message"]
@mock_sagemaker
def test_create_endpoint_invalid_memory_size(sagemaker_client):
    """An unsupported ``MemorySizeInMB`` is rejected with a ValidationException.

    The model is created first so that the only invalid part of the request
    is the serverless memory size.
    """
    import copy

    _create_model(sagemaker_client, TEST_MODEL_NAME)

    memory_size = 1111
    # Work on a deep copy: assigning the module-level list directly would
    # alias it, and the in-place edit below would leak the invalid memory
    # size into every other test that uses TEST_SERVERLESS_PRODUCTION_VARIANTS.
    production_variants = copy.deepcopy(TEST_SERVERLESS_PRODUCTION_VARIANTS)
    production_variants[0]["ServerlessConfig"]["MemorySizeInMB"] = memory_size

    with pytest.raises(ClientError) as e:
        sagemaker_client.create_endpoint_config(
            EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
            ProductionVariants=production_variants,
        )

    assert e.value.response["Error"]["Code"] == "ValidationException"
    # Only the prefix up to the opening bracket is compared, so the test does
    # not depend on the exact list of accepted sizes.
    expected_message = "Value '{}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: [".format(
        memory_size
    )
    assert expected_message in e.value.response["Error"]["Message"]
@mock_sagemaker
def test_create_endpoint(sagemaker_client):
with pytest.raises(ClientError) as e: