diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index e4c9c2523..0f2778c9c 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -398,7 +398,25 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): def validate_production_variants(self, production_variants): for production_variant in production_variants: - self.validate_instance_type(production_variant["InstanceType"]) + if "InstanceType" in production_variant.keys(): + self.validate_instance_type(production_variant["InstanceType"]) + elif "ServerlessConfig" in production_variant.keys(): + self.validate_serverless_config(production_variant["ServerlessConfig"]) + else: + message = "Invalid Keys for ProductionVariant: received {} but expected it to contain one of {}".format( + production_variant.keys(), ["InstanceType", "ServerlessConfig"] + ) + raise ValidationError(message=message) + + def validate_serverless_config(self, serverless_config): + VALID_SERVERLESS_MEMORY_SIZE = [1024, 2048, 3072, 4096, 5120, 6144] + if not validators.is_one_of( + serverless_config["MemorySizeInMB"], VALID_SERVERLESS_MEMORY_SIZE + ): + message = "Value '{}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: {}".format( + serverless_config["MemorySizeInMB"], VALID_SERVERLESS_MEMORY_SIZE + ) + raise ValidationError(message=message) def validate_instance_type(self, instance_type): VALID_INSTANCE_TYPES = [ diff --git a/tests/test_sagemaker/test_sagemaker_endpoint.py b/tests/test_sagemaker/test_sagemaker_endpoint.py index 5f73ab819..f49918b14 100644 --- a/tests/test_sagemaker/test_sagemaker_endpoint.py +++ b/tests/test_sagemaker/test_sagemaker_endpoint.py @@ -20,6 +20,8 @@ TEST_ENDPOINT_NAME = "MyEndpoint" TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig" TEST_VARIANT_NAME = "MyProductionVariant" TEST_INSTANCE_TYPE = "ml.t2.medium" +TEST_MEMORY_SIZE = 1024 +TEST_CONCURRENCY = 10 TEST_PRODUCTION_VARIANTS = [ { "VariantName": TEST_VARIANT_NAME, @@ -28,6 +30,16 @@ TEST_PRODUCTION_VARIANTS = [ "InstanceType": TEST_INSTANCE_TYPE, }, ] +TEST_SERVERLESS_PRODUCTION_VARIANTS = [ + { + "VariantName": TEST_VARIANT_NAME, + "ModelName": TEST_MODEL_NAME, + "ServerlessConfig": { + "MemorySizeInMB": TEST_MEMORY_SIZE, + "MaxConcurrency": TEST_CONCURRENCY, + }, + }, +] @pytest.fixture @@ -35,19 +47,12 @@ def sagemaker_client(): return boto3.client("sagemaker", region_name=TEST_REGION_NAME) -@mock_sagemaker -def test_create_endpoint_config(sagemaker_client): - with pytest.raises(ClientError) as e: - sagemaker_client.create_endpoint_config( - EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, - ProductionVariants=TEST_PRODUCTION_VARIANTS, - ) - assert e.value.response["Error"]["Message"].startswith("Could not find model") - +def create_endpoint_config_helper(sagemaker_client, production_variants): _create_model(sagemaker_client, TEST_MODEL_NAME) + resp = sagemaker_client.create_endpoint_config( EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, - ProductionVariants=TEST_PRODUCTION_VARIANTS, + ProductionVariants=production_variants, ) resp["EndpointConfigArn"].should.match( r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format( @@ -64,7 +69,33 @@ def test_create_endpoint_config(sagemaker_client): ) ) resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME) - resp["ProductionVariants"].should.equal(TEST_PRODUCTION_VARIANTS) + resp["ProductionVariants"].should.equal(production_variants) + + +@mock_sagemaker +def test_create_endpoint_config(sagemaker_client): + with pytest.raises(ClientError) as e: + sagemaker_client.create_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, + ProductionVariants=TEST_PRODUCTION_VARIANTS, + ) + assert e.value.response["Error"]["Message"].startswith("Could not find model") + + # Testing instance-based endpoint configuration + create_endpoint_config_helper(sagemaker_client, TEST_PRODUCTION_VARIANTS) + + +@mock_sagemaker +def test_create_endpoint_config_serverless(sagemaker_client): + with pytest.raises(ClientError) as e: + sagemaker_client.create_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, + ProductionVariants=TEST_SERVERLESS_PRODUCTION_VARIANTS, + ) + assert e.value.response["Error"]["Message"].startswith("Could not find model") + + # Testing serverless endpoint configuration + create_endpoint_config_helper(sagemaker_client, TEST_SERVERLESS_PRODUCTION_VARIANTS) @mock_sagemaker @@ -129,6 +160,26 @@ def test_create_endpoint_invalid_instance_type(sagemaker_client): assert expected_message in e.value.response["Error"]["Message"] +@mock_sagemaker +def test_create_endpoint_invalid_memory_size(sagemaker_client): + _create_model(sagemaker_client, TEST_MODEL_NAME) + + memory_size = 1111 + production_variants = TEST_SERVERLESS_PRODUCTION_VARIANTS + production_variants[0]["ServerlessConfig"]["MemorySizeInMB"] = memory_size + + with pytest.raises(ClientError) as e: + sagemaker_client.create_endpoint_config( + EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, + ProductionVariants=production_variants, + ) + assert e.value.response["Error"]["Code"] == "ValidationException" + expected_message = "Value '{}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: [".format( + memory_size + ) + assert expected_message in e.value.response["Error"]["Message"] + + @mock_sagemaker def test_create_endpoint(sagemaker_client): with pytest.raises(ClientError) as e: