From eb4989111817015dd24e9e8742387f3cca31ab15 Mon Sep 17 00:00:00 2001 From: Killian O'Daly Date: Wed, 4 May 2022 10:36:46 +0100 Subject: [PATCH] Sagemaker: add support for update_endpoint_weights_and_capacities (#5082) --- IMPLEMENTATION_COVERAGE.md | 2 +- docs/docs/services/sagemaker.rst | 2 +- moto/sagemaker/models.py | 92 +++++- moto/sagemaker/responses.py | 11 + .../test_sagemaker/test_sagemaker_endpoint.py | 278 +++++++++++++++++- 5 files changed, 366 insertions(+), 19 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index 303eeddaf..ea4bf0c41 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -5142,7 +5142,7 @@ - [ ] update_devices - [ ] update_domain - [ ] update_endpoint -- [ ] update_endpoint_weights_and_capacities +- [X] update_endpoint_weights_and_capacities - [ ] update_experiment - [ ] update_image - [ ] update_model_package diff --git a/docs/docs/services/sagemaker.rst b/docs/docs/services/sagemaker.rst index 681aeb497..a5089f1a8 100644 --- a/docs/docs/services/sagemaker.rst +++ b/docs/docs/services/sagemaker.rst @@ -260,7 +260,7 @@ sagemaker - [ ] update_devices - [ ] update_domain - [ ] update_endpoint -- [ ] update_endpoint_weights_and_capacities +- [X] update_endpoint_weights_and_capacities - [ ] update_experiment - [ ] update_image - [ ] update_model_package diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 8adb8779f..59b8f5052 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -249,7 +249,9 @@ class FakeEndpoint(BaseObject, CloudFormationModel): self.endpoint_name = endpoint_name self.endpoint_arn = FakeEndpoint.arn_formatter(endpoint_name, region_name) self.endpoint_config_name = endpoint_config_name - self.production_variants = production_variants + self.production_variants = self._process_production_variants( + production_variants + ) self.data_capture_config = data_capture_config self.tags = tags or [] self.endpoint_status = "InService" @@ -258,6 +260,42 @@ class FakeEndpoint(BaseObject, CloudFormationModel): "%Y-%m-%d %H:%M:%S" ) + def _process_production_variants(self, production_variants): + endpoint_variants = [] + for production_variant in production_variants: + temp_variant = {} + + # VariantName is the only required param + temp_variant["VariantName"] = production_variant["VariantName"] + + if production_variant.get("InitialInstanceCount", None): + temp_variant["CurrentInstanceCount"] = production_variant[ + "InitialInstanceCount" + ] + temp_variant["DesiredInstanceCount"] = production_variant[ + "InitialInstanceCount" + ] + + if production_variant.get("InitialVariantWeight", None): + temp_variant["CurrentWeight"] = production_variant[ + "InitialVariantWeight" + ] + temp_variant["DesiredWeight"] = production_variant[ + "InitialVariantWeight" + ] + + if production_variant.get("ServerlessConfig", None): + temp_variant["CurrentServerlessConfig"] = production_variant[ + "ServerlessConfig" + ] + temp_variant["DesiredServerlessConfig"] = production_variant[ + "ServerlessConfig" + ] + + endpoint_variants.append(temp_variant) + + return endpoint_variants + @property def response_object(self): response_object = self.gen_response_object() @@ -1607,7 +1645,7 @@ class SageMakerModelBackend(BaseBackend): try: return self.endpoints[endpoint_name].response_object except KeyError: - message = "Could not find endpoint configuration '{}'.".format( + message = "Could not find endpoint '{}'.".format( FakeEndpoint.arn_formatter(endpoint_name, self.region_name) ) raise ValidationError(message=message) @@ -1616,7 +1654,7 @@ class SageMakerModelBackend(BaseBackend): try: del self.endpoints[endpoint_name] except KeyError: - message = "Could not find endpoint configuration '{}'.".format( + message = "Could not find endpoint '{}'.".format( FakeEndpoint.arn_formatter(endpoint_name, self.region_name) ) raise ValidationError(message=message) @@ -1890,6 +1928,54 @@ class SageMakerModelBackend(BaseBackend): "NextToken": str(next_index) if next_index is not None else None, } + def update_endpoint_weights_and_capacities( + self, endpoint_name, desired_weights_and_capacities + ): + # Validate inputs + endpoint = self.endpoints.get(endpoint_name, None) + if not endpoint: + raise AWSValidationException( + f'Could not find endpoint "{FakeEndpoint.arn_formatter(endpoint_name, self.region_name)}".' + ) + + names_checked = [] + for variant_config in desired_weights_and_capacities: + name = variant_config.get("VariantName") + + if name in names_checked: + raise AWSValidationException( + f'The variant name "{name}" was non-unique within the request.' + ) + + if not any( + variant["VariantName"] == name + for variant in endpoint.production_variants + ): + raise AWSValidationException( + f'The variant name(s) "{name}" is/are not present within endpoint configuration "{endpoint.endpoint_config_name}".' + ) + + names_checked.append(name) + + # Update endpoint variants + endpoint.endpoint_status = "Updating" + + for variant_config in desired_weights_and_capacities: + name = variant_config.get("VariantName") + desired_weight = variant_config.get("DesiredWeight") + desired_instance_count = variant_config.get("DesiredInstanceCount") + + for variant in endpoint.production_variants: + if variant.get("VariantName") == name: + variant["DesiredWeight"] = desired_weight + variant["CurrentWeight"] = desired_weight + variant["DesiredInstanceCount"] = desired_instance_count + variant["CurrentInstanceCount"] = desired_instance_count + break + + endpoint.endpoint_status = "InService" + return endpoint.endpoint_arn + class FakeExperiment(BaseObject): def __init__(self, region_name, experiment_name, tags): diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 78f268e53..b4be4cd0e 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -1,4 +1,5 @@ import json + from moto.sagemaker.exceptions import AWSValidationException from moto.core.responses import BaseResponse @@ -576,3 +577,13 @@ class SageMakerResponse(BaseResponse): status_equals=status_equals, ) return 200, {}, json.dumps(response) + + def update_endpoint_weights_and_capacities(self): + endpoint_name = self._get_param("EndpointName") + desired_weights_and_capacities = self._get_param("DesiredWeightsAndCapacities") + endpoint_arn = self.sagemaker_backend.update_endpoint_weights_and_capacities( + endpoint_name=endpoint_name, + desired_weights_and_capacities=desired_weights_and_capacities, + ) + response = {"EndpointArn": endpoint_arn} + return 200, {}, json.dumps(response) diff --git a/tests/test_sagemaker/test_sagemaker_endpoint.py b/tests/test_sagemaker/test_sagemaker_endpoint.py index 4a1b30218..dd68f86ee 100644 --- a/tests/test_sagemaker/test_sagemaker_endpoint.py +++ b/tests/test_sagemaker/test_sagemaker_endpoint.py @@ -18,12 +18,14 @@ GENERIC_TAGS_PARAM = [ TEST_MODEL_NAME = "MyModel" TEST_ENDPOINT_NAME = "MyEndpoint" TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig" +TEST_VARIANT_NAME = "MyProductionVariant" +TEST_INSTANCE_TYPE = "ml.t2.medium" TEST_PRODUCTION_VARIANTS = [ { - "VariantName": "MyProductionVariant", + "VariantName": TEST_VARIANT_NAME, "ModelName": TEST_MODEL_NAME, "InitialInstanceCount": 1, - "InstanceType": "ml.t2.medium", + "InstanceType": TEST_INSTANCE_TYPE, }, ] @@ -162,7 +164,7 @@ def test_create_endpoint(sagemaker_client): resp["EndpointStatus"].should.equal("InService") assert isinstance(resp["CreationTime"], datetime.datetime) assert isinstance(resp["LastModifiedTime"], datetime.datetime) - resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant") + resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME) resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"]) assert resp["Tags"] == GENERIC_TAGS_PARAM @@ -245,11 +247,256 @@ def test_list_tags_endpoint(sagemaker_client): assert response["Tags"] == tags[50:] +@mock_sagemaker +def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client): + _set_up_sagemaker_resources( + sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME + ) + + new_desired_weight = 1.5 + new_desired_instance_count = 123 + + response = sagemaker_client.update_endpoint_weights_and_capacities( + EndpointName=TEST_ENDPOINT_NAME, + DesiredWeightsAndCapacities=[ + { + "VariantName": TEST_VARIANT_NAME, + "DesiredWeight": new_desired_weight, + "DesiredInstanceCount": new_desired_instance_count, + }, + ], + ) + response["EndpointArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME) + ) + + resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) + resp["EndpointArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME) + ) + resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME) + resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME) + resp["EndpointStatus"].should.equal("InService") + assert isinstance(resp["CreationTime"], datetime.datetime) + assert isinstance(resp["LastModifiedTime"], datetime.datetime) + + resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME) + resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal( + new_desired_instance_count + ) + resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal( + new_desired_instance_count + ) + resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight) + resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight) + + +@mock_sagemaker +def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client): + production_variants = [ + { + "VariantName": "MyProductionVariant1", + "ModelName": TEST_MODEL_NAME, + "InitialInstanceCount": 1, + "InstanceType": TEST_INSTANCE_TYPE, + }, + { + "VariantName": "MyProductionVariant2", + "ModelName": TEST_MODEL_NAME, + "InitialInstanceCount": 1, + "InstanceType": TEST_INSTANCE_TYPE, + }, + ] + + _set_up_sagemaker_resources( + sagemaker_client, + TEST_ENDPOINT_NAME, + TEST_ENDPOINT_CONFIG_NAME, + TEST_MODEL_NAME, + production_variants, + ) + + desired_weights_and_capacities = [ + { + "VariantName": "MyProductionVariant1", + "DesiredWeight": 1.5, + "DesiredInstanceCount": 123, + }, + { + "VariantName": "MyProductionVariant2", + "DesiredWeight": 1.5, + "DesiredInstanceCount": 123, + }, + ] + + new_desired_weight = 1.5 + new_desired_instance_count = 123 + + response = sagemaker_client.update_endpoint_weights_and_capacities( + EndpointName=TEST_ENDPOINT_NAME, + DesiredWeightsAndCapacities=desired_weights_and_capacities, + ) + response["EndpointArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME) + ) + + resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) + resp["EndpointArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME) + ) + resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME) + resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME) + resp["EndpointStatus"].should.equal("InService") + assert isinstance(resp["CreationTime"], datetime.datetime) + assert isinstance(resp["LastModifiedTime"], datetime.datetime) + + resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant1") + resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal( + new_desired_instance_count + ) + resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal( + new_desired_instance_count + ) + resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight) + resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight) + + resp["ProductionVariants"][1]["VariantName"].should.equal("MyProductionVariant2") + resp["ProductionVariants"][1]["DesiredInstanceCount"].should.equal( + new_desired_instance_count + ) + resp["ProductionVariants"][1]["CurrentInstanceCount"].should.equal( + new_desired_instance_count + ) + resp["ProductionVariants"][1]["DesiredWeight"].should.equal(new_desired_weight) + resp["ProductionVariants"][1]["CurrentWeight"].should.equal(new_desired_weight) + + +@mock_sagemaker +def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant( + sagemaker_client, +): + _set_up_sagemaker_resources( + sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME + ) + + old_resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) + del old_resp["ResponseMetadata"] + + variant_name = "SillyNotCorrectName" + new_desired_weight = 1.5 + new_desired_instance_count = 123 + + with pytest.raises(ClientError) as exc: + sagemaker_client.update_endpoint_weights_and_capacities( + EndpointName=TEST_ENDPOINT_NAME, + DesiredWeightsAndCapacities=[ + { + "VariantName": variant_name, + "DesiredWeight": new_desired_weight, + "DesiredInstanceCount": new_desired_instance_count, + }, + ], + ) + + err = exc.value.response["Error"] + err["Message"].should.equal( + f'The variant name(s) "{variant_name}" is/are not present within endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".' + ) + + resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) + del resp["ResponseMetadata"] + resp.should.equal(old_resp) + + +@mock_sagemaker +def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint( + sagemaker_client, +): + _set_up_sagemaker_resources( + sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME + ) + + old_resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) + del old_resp["ResponseMetadata"] + + endpoint_name = "SillyEndpointName" + variant_name = "SillyNotCorrectName" + new_desired_weight = 1.5 + new_desired_instance_count = 123 + + with pytest.raises(ClientError) as exc: + sagemaker_client.update_endpoint_weights_and_capacities( + EndpointName=endpoint_name, + DesiredWeightsAndCapacities=[ + { + "VariantName": variant_name, + "DesiredWeight": new_desired_weight, + "DesiredInstanceCount": new_desired_instance_count, + }, + ], + ) + + err = exc.value.response["Error"] + err["Message"].should.equal( + f'Could not find endpoint "arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:endpoint/{endpoint_name}".' + ) + + resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) + del resp["ResponseMetadata"] + resp.should.equal(old_resp) + + +@mock_sagemaker +def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant( + sagemaker_client, +): + _set_up_sagemaker_resources( + sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME + ) + + old_resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) + del old_resp["ResponseMetadata"] + + desired_weights_and_capacities = [ + { + "VariantName": TEST_VARIANT_NAME, + "DesiredWeight": 1.5, + "DesiredInstanceCount": 123, + }, + { + "VariantName": TEST_VARIANT_NAME, + "DesiredWeight": 1.5, + "DesiredInstanceCount": 123, + }, + ] + + with pytest.raises(ClientError) as exc: + sagemaker_client.update_endpoint_weights_and_capacities( + EndpointName=TEST_ENDPOINT_NAME, + DesiredWeightsAndCapacities=desired_weights_and_capacities, + ) + + err = exc.value.response["Error"] + err["Message"].should.equal( + f'The variant name "{TEST_VARIANT_NAME}" was non-unique within the request.' + ) + + resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) + del resp["ResponseMetadata"] + resp.should.equal(old_resp) + + def _set_up_sagemaker_resources( - boto_client, endpoint_name, endpoint_config_name, model_name + boto_client, + endpoint_name, + endpoint_config_name, + model_name, + production_variants=None, ): _create_model(boto_client, model_name) - _create_endpoint_config(boto_client, endpoint_config_name, model_name) + _create_endpoint_config( + boto_client, endpoint_config_name, model_name, production_variants + ) _create_endpoint(boto_client, endpoint_name, endpoint_config_name) @@ -265,15 +512,18 @@ def _create_model(boto_client, model_name): assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 -def _create_endpoint_config(boto_client, endpoint_config_name, model_name): - production_variants = [ - { - "VariantName": "MyProductionVariant", - "ModelName": model_name, - "InitialInstanceCount": 1, - "InstanceType": "ml.t2.medium", - }, - ] +def _create_endpoint_config( + boto_client, endpoint_config_name, model_name, production_variants=None +): + if not production_variants: + production_variants = [ + { + "VariantName": TEST_VARIANT_NAME, + "ModelName": model_name, + "InitialInstanceCount": 1, + "InstanceType": TEST_INSTANCE_TYPE, + }, + ] resp = boto_client.create_endpoint_config( EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants )