Sagemaker: Add support for Serverless Endpoint Configurations (#5445)
This commit is contained in:
parent
47a358fc35
commit
6f3c391812
@ -398,7 +398,25 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel):
|
|||||||
|
|
||||||
def validate_production_variants(self, production_variants):
    """Validate each entry of ``production_variants``.

    A variant that has an ``InstanceType`` key is checked with
    ``validate_instance_type``; otherwise a variant that has a
    ``ServerlessConfig`` key is checked with ``validate_serverless_config``.

    Raises:
        ValidationError: if a variant contains neither of the two keys.
    """
    expected_keys = ["InstanceType", "ServerlessConfig"]
    for variant in production_variants:
        # InstanceType is checked first, so it wins when both keys are present.
        if "InstanceType" in variant:
            self.validate_instance_type(variant["InstanceType"])
            continue
        if "ServerlessConfig" in variant:
            self.validate_serverless_config(variant["ServerlessConfig"])
            continue
        raise ValidationError(
            message="Invalid Keys for ProductionVariant: received {} but expected it to contain one of {}".format(
                variant.keys(), expected_keys
            )
        )
|
||||||
|
|
||||||
|
def validate_serverless_config(self, serverless_config):
    """Validate the ``MemorySizeInMB`` entry of a serverless configuration.

    Args:
        serverless_config: mapping taken from a production variant's
            ``ServerlessConfig`` key.

    Raises:
        ValidationError: if ``MemorySizeInMB`` is absent or is not one of
            the sizes listed in ``VALID_SERVERLESS_MEMORY_SIZE``.
    """
    # Kept as a list: the error message embeds its repr.
    VALID_SERVERLESS_MEMORY_SIZE = [1024, 2048, 3072, 4096, 5120, 6144]
    # .get() so that a config without the key is rejected as a validation
    # failure rather than escaping as a bare KeyError.
    memory_size = serverless_config.get("MemorySizeInMB")
    if memory_size is None or not validators.is_one_of(
        memory_size, VALID_SERVERLESS_MEMORY_SIZE
    ):
        message = "Value '{}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
            memory_size, VALID_SERVERLESS_MEMORY_SIZE
        )
        raise ValidationError(message=message)
|
||||||
|
|
||||||
def validate_instance_type(self, instance_type):
|
def validate_instance_type(self, instance_type):
|
||||||
VALID_INSTANCE_TYPES = [
|
VALID_INSTANCE_TYPES = [
|
||||||
|
@ -20,6 +20,8 @@ TEST_ENDPOINT_NAME = "MyEndpoint"
|
|||||||
TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig"
|
TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig"
|
||||||
TEST_VARIANT_NAME = "MyProductionVariant"
|
TEST_VARIANT_NAME = "MyProductionVariant"
|
||||||
TEST_INSTANCE_TYPE = "ml.t2.medium"
|
TEST_INSTANCE_TYPE = "ml.t2.medium"
|
||||||
|
TEST_MEMORY_SIZE = 1024
|
||||||
|
TEST_CONCURRENCY = 10
|
||||||
TEST_PRODUCTION_VARIANTS = [
|
TEST_PRODUCTION_VARIANTS = [
|
||||||
{
|
{
|
||||||
"VariantName": TEST_VARIANT_NAME,
|
"VariantName": TEST_VARIANT_NAME,
|
||||||
@ -28,6 +30,16 @@ TEST_PRODUCTION_VARIANTS = [
|
|||||||
"InstanceType": TEST_INSTANCE_TYPE,
|
"InstanceType": TEST_INSTANCE_TYPE,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
# Production variant list whose single entry carries a ServerlessConfig
# (memory size and max concurrency) and no InstanceType.
TEST_SERVERLESS_PRODUCTION_VARIANTS = [
    {
        "VariantName": TEST_VARIANT_NAME,
        "ModelName": TEST_MODEL_NAME,
        "ServerlessConfig": {
            "MemorySizeInMB": TEST_MEMORY_SIZE,
            "MaxConcurrency": TEST_CONCURRENCY,
        },
    },
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -35,19 +47,12 @@ def sagemaker_client():
|
|||||||
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
return boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||||
|
|
||||||
|
|
||||||
@mock_sagemaker
|
def create_endpoint_config_helper(sagemaker_client, production_variants):
|
||||||
def test_create_endpoint_config(sagemaker_client):
|
|
||||||
with pytest.raises(ClientError) as e:
|
|
||||||
sagemaker_client.create_endpoint_config(
|
|
||||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
|
||||||
ProductionVariants=TEST_PRODUCTION_VARIANTS,
|
|
||||||
)
|
|
||||||
assert e.value.response["Error"]["Message"].startswith("Could not find model")
|
|
||||||
|
|
||||||
_create_model(sagemaker_client, TEST_MODEL_NAME)
|
_create_model(sagemaker_client, TEST_MODEL_NAME)
|
||||||
|
|
||||||
resp = sagemaker_client.create_endpoint_config(
|
resp = sagemaker_client.create_endpoint_config(
|
||||||
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
|
||||||
ProductionVariants=TEST_PRODUCTION_VARIANTS,
|
ProductionVariants=production_variants,
|
||||||
)
|
)
|
||||||
resp["EndpointConfigArn"].should.match(
|
resp["EndpointConfigArn"].should.match(
|
||||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
|
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
|
||||||
@ -64,7 +69,33 @@ def test_create_endpoint_config(sagemaker_client):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||||
resp["ProductionVariants"].should.equal(TEST_PRODUCTION_VARIANTS)
|
resp["ProductionVariants"].should.equal(production_variants)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
def test_create_endpoint_config(sagemaker_client):
    """Instance-based config: creation fails while the model does not exist,
    then the shared helper creates the model and the config."""
    request = {
        "EndpointConfigName": TEST_ENDPOINT_CONFIG_NAME,
        "ProductionVariants": TEST_PRODUCTION_VARIANTS,
    }
    with pytest.raises(ClientError) as exc_info:
        sagemaker_client.create_endpoint_config(**request)
    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")

    # Testing instance-based endpoint configuration
    create_endpoint_config_helper(sagemaker_client, TEST_PRODUCTION_VARIANTS)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
def test_create_endpoint_config_serverless(sagemaker_client):
    """Serverless config: creation fails while the model does not exist,
    then the shared helper creates the model and the config."""
    request = {
        "EndpointConfigName": TEST_ENDPOINT_CONFIG_NAME,
        "ProductionVariants": TEST_SERVERLESS_PRODUCTION_VARIANTS,
    }
    with pytest.raises(ClientError) as exc_info:
        sagemaker_client.create_endpoint_config(**request)
    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")

    # Testing serverless endpoint configuration
    create_endpoint_config_helper(
        sagemaker_client, TEST_SERVERLESS_PRODUCTION_VARIANTS
    )
|
||||||
|
|
||||||
|
|
||||||
@mock_sagemaker
|
@mock_sagemaker
|
||||||
@ -129,6 +160,26 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client):
|
|||||||
assert expected_message in e.value.response["Error"]["Message"]
|
assert expected_message in e.value.response["Error"]["Message"]
|
||||||
|
|
||||||
|
|
||||||
|
@mock_sagemaker
def test_create_endpoint_invalid_memory_size(sagemaker_client):
    """A serverless variant with an unsupported MemorySizeInMB is rejected
    with a ValidationException naming the offending value."""
    # Imported locally because the module's import block is not part of
    # this change.
    import copy

    _create_model(sagemaker_client, TEST_MODEL_NAME)

    memory_size = 1111
    # Work on a deep copy: plain assignment would alias the module-level
    # constant, and the write below would then leak the invalid size into
    # every test that runs afterwards.
    production_variants = copy.deepcopy(TEST_SERVERLESS_PRODUCTION_VARIANTS)
    production_variants[0]["ServerlessConfig"]["MemorySizeInMB"] = memory_size

    with pytest.raises(ClientError) as e:
        sagemaker_client.create_endpoint_config(
            EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME,
            ProductionVariants=production_variants,
        )
    assert e.value.response["Error"]["Code"] == "ValidationException"
    # Only the prefix is pinned; the list of allowed sizes follows the "[".
    expected_message = "Value '{}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: [".format(
        memory_size
    )
    assert expected_message in e.value.response["Error"]["Message"]
|
||||||
|
|
||||||
|
|
||||||
@mock_sagemaker
|
@mock_sagemaker
|
||||||
def test_create_endpoint(sagemaker_client):
|
def test_create_endpoint(sagemaker_client):
|
||||||
with pytest.raises(ClientError) as e:
|
with pytest.raises(ClientError) as e:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user