Sagemaker: Add support for Serverless Endpoint Configurations (#5445)
This commit is contained in:
parent
47a358fc35
commit
6f3c391812
@ -398,7 +398,25 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel):
|
||||
|
||||
def validate_production_variants(self, production_variants):
    """Validate every entry of an endpoint config's ``ProductionVariants``.

    Each variant must contain either an ``InstanceType`` key or a
    ``ServerlessConfig`` key; the value is handed to the matching
    validator. ``InstanceType`` wins when both keys are present.

    Raises:
        ValidationError: if a variant contains neither key. Errors raised
            by the delegated validators propagate unchanged.
    """
    for production_variant in production_variants:
        # Decide on the hosting style before touching either key: indexing
        # "InstanceType" unconditionally would raise KeyError for a
        # serverless variant and the branches below would never run.
        if "InstanceType" in production_variant:
            self.validate_instance_type(production_variant["InstanceType"])
        elif "ServerlessConfig" in production_variant:
            self.validate_serverless_config(production_variant["ServerlessConfig"])
        else:
            message = "Invalid Keys for ProductionVariant: received {} but expected it to contain one of {}".format(
                production_variant.keys(), ["InstanceType", "ServerlessConfig"]
            )
            raise ValidationError(message=message)
|
||||
|
||||
def validate_serverless_config(self, serverless_config):
    """Check that a serverless config requests an allowed memory size.

    Raises:
        ValidationError: if ``serverless_config["MemorySizeInMB"]`` is not
            one of the sizes in ``VALID_SERVERLESS_MEMORY_SIZE``.
    """
    VALID_SERVERLESS_MEMORY_SIZE = [1024, 2048, 3072, 4096, 5120, 6144]
    requested_size = serverless_config["MemorySizeInMB"]

    # Guard clause: an allowed size needs no further handling.
    if validators.is_one_of(requested_size, VALID_SERVERLESS_MEMORY_SIZE):
        return

    raise ValidationError(
        message="Value '{}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
            requested_size, VALID_SERVERLESS_MEMORY_SIZE
        )
    )
|
||||
|
||||
def validate_instance_type(self, instance_type):
|
||||
VALID_INSTANCE_TYPES = [
|
||||
|
@ -20,6 +20,8 @@ TEST_ENDPOINT_NAME = "MyEndpoint"
|
||||
# Identifiers shared by the endpoint-config tests in this module.
TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig"
TEST_VARIANT_NAME = "MyProductionVariant"

# Value used for "InstanceType" in the instance-based production variants.
TEST_INSTANCE_TYPE = "ml.t2.medium"

# Values used for "MemorySizeInMB" / "MaxConcurrency" in the serverless
# production variants.
TEST_MEMORY_SIZE = 1024
TEST_CONCURRENCY = 10
|
||||
TEST_PRODUCTION_VARIANTS = [
|
||||
{
|
||||
"VariantName": TEST_VARIANT_NAME,
|
||||
@ -28,6 +30,16 @@ TEST_PRODUCTION_VARIANTS = [
|
||||
"InstanceType": TEST_INSTANCE_TYPE,
|
||||
},
|
||||
]
|
||||
# Serverless counterpart of TEST_PRODUCTION_VARIANTS: a single variant that
# carries a "ServerlessConfig" instead of an "InstanceType".
TEST_SERVERLESS_PRODUCTION_VARIANTS = [
    {
        "VariantName": TEST_VARIANT_NAME,
        "ModelName": TEST_MODEL_NAME,
        "ServerlessConfig": {
            "MemorySizeInMB": TEST_MEMORY_SIZE,
            "MaxConcurrency": TEST_CONCURRENCY,
        },
    },
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -35,19 +47,12 @@ def sagemaker_client():
|
||||
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_endpoint_config(sagemaker_client):
|
||||
with pytest.raises(ClientError) as e:
|
||||
sagemaker_client.create_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
ProductionVariants=TEST_PRODUCTION_VARIANTS,
|
||||
)
|
||||
assert e.value.response["Error"]["Message"].startswith("Could not find model")
|
||||
|
||||
def create_endpoint_config_helper(sagemaker_client, production_variants):
|
||||
_create_model(sagemaker_client, TEST_MODEL_NAME)
|
||||
|
||||
resp = sagemaker_client.create_endpoint_config(
|
||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||
ProductionVariants=TEST_PRODUCTION_VARIANTS,
|
||||
ProductionVariants=production_variants,
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
|
||||
@ -64,7 +69,33 @@ def test_create_endpoint_config(sagemaker_client):
|
||||
)
|
||||
)
|
||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||
resp["ProductionVariants"].should.equal(TEST_PRODUCTION_VARIANTS)
|
||||
resp["ProductionVariants"].should.equal(production_variants)
|
||||
|
||||
|
||||
@mock_sagemaker
def test_create_endpoint_config(sagemaker_client):
    """Instance-based config: fails without a model, then is created."""
    # Nothing has registered the model yet, so the request must be rejected.
    with pytest.raises(ClientError) as exc_info:
        sagemaker_client.create_endpoint_config(
            EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
            ProductionVariants=TEST_PRODUCTION_VARIANTS,
        )
    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")

    # Testing instance-based endpoint configuration
    create_endpoint_config_helper(sagemaker_client, TEST_PRODUCTION_VARIANTS)
|
||||
|
||||
|
||||
@mock_sagemaker
def test_create_endpoint_config_serverless(sagemaker_client):
    """Serverless config: fails without a model, then is created."""
    # Nothing has registered the model yet, so the request must be rejected.
    with pytest.raises(ClientError) as exc_info:
        sagemaker_client.create_endpoint_config(
            EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
            ProductionVariants=TEST_SERVERLESS_PRODUCTION_VARIANTS,
        )
    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")

    # Testing serverless endpoint configuration
    create_endpoint_config_helper(sagemaker_client, TEST_SERVERLESS_PRODUCTION_VARIANTS)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
@ -129,6 +160,26 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client):
|
||||
assert expected_message in e.value.response["Error"]["Message"]
|
||||
|
||||
|
||||
@mock_sagemaker
def test_create_endpoint_invalid_memory_size(sagemaker_client):
    """An unsupported ``MemorySizeInMB`` yields a ``ValidationException``."""
    import copy

    _create_model(sagemaker_client, TEST_MODEL_NAME)

    memory_size = 1111
    # Work on a deep copy: assigning into the module-level fixture would
    # leave the invalid size in place for every test that runs afterwards.
    production_variants = copy.deepcopy(TEST_SERVERLESS_PRODUCTION_VARIANTS)
    production_variants[0]["ServerlessConfig"]["MemorySizeInMB"] = memory_size

    with pytest.raises(ClientError) as e:
        sagemaker_client.create_endpoint_config(
            EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
            ProductionVariants=production_variants,
        )
    assert e.value.response["Error"]["Code"] == "ValidationException"
    # Only the prefix is checked; the message ends with the list of
    # allowed sizes.
    expected_message = "Value '{}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: [".format(
        memory_size
    )
    assert expected_message in e.value.response["Error"]["Message"]
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_endpoint(sagemaker_client):
|
||||
with pytest.raises(ClientError) as e:
|
||||
|
Loading…
Reference in New Issue
Block a user