Sagemaker: add support for update_endpoint_weights_and_capacities (#5082)

This commit is contained in:
Killian O'Daly 2022-05-04 10:36:46 +01:00 committed by GitHub
parent 406ca1d8f9
commit eb49891118
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 366 additions and 19 deletions

View File

@ -5142,7 +5142,7 @@
- [ ] update_devices
- [ ] update_domain
- [ ] update_endpoint
- [ ] update_endpoint_weights_and_capacities
- [X] update_endpoint_weights_and_capacities
- [ ] update_experiment
- [ ] update_image
- [ ] update_model_package

View File

@ -260,7 +260,7 @@ sagemaker
- [ ] update_devices
- [ ] update_domain
- [ ] update_endpoint
- [ ] update_endpoint_weights_and_capacities
- [X] update_endpoint_weights_and_capacities
- [ ] update_experiment
- [ ] update_image
- [ ] update_model_package

View File

@ -249,7 +249,9 @@ class FakeEndpoint(BaseObject, CloudFormationModel):
self.endpoint_name = endpoint_name
self.endpoint_arn = FakeEndpoint.arn_formatter(endpoint_name, region_name)
self.endpoint_config_name = endpoint_config_name
self.production_variants = production_variants
self.production_variants = self._process_production_variants(
production_variants
)
self.data_capture_config = data_capture_config
self.tags = tags or []
self.endpoint_status = "InService"
@ -258,6 +260,42 @@ class FakeEndpoint(BaseObject, CloudFormationModel):
"%Y-%m-%d %H:%M:%S"
)
def _process_production_variants(self, production_variants):
    """Build the endpoint's variant summaries from endpoint-config variants.

    For every input variant the result contains a dict holding
    ``VariantName`` plus, for each optional setting that was supplied, a
    matching ``Current*``/``Desired*`` pair initialised to the same value:

    * ``InitialInstanceCount`` -> ``CurrentInstanceCount`` / ``DesiredInstanceCount``
    * ``InitialVariantWeight`` -> ``CurrentWeight`` / ``DesiredWeight``
    * ``ServerlessConfig`` -> ``CurrentServerlessConfig`` / ``DesiredServerlessConfig``

    :param production_variants: iterable of variant dicts (``None`` is
        treated as an empty list).
    :return: list of variant summary dicts, in input order.
    :raises KeyError: if a variant has no ``VariantName``.
    """
    # (key in the endpoint config, suffix shared by the Current*/Desired* keys)
    optional_fields = (
        ("InitialInstanceCount", "InstanceCount"),
        ("InitialVariantWeight", "Weight"),
        ("ServerlessConfig", "ServerlessConfig"),
    )

    endpoint_variants = []
    for production_variant in production_variants or []:
        # VariantName is the only required param
        temp_variant = {"VariantName": production_variant["VariantName"]}
        for source_key, suffix in optional_fields:
            value = production_variant.get(source_key)
            # Test for presence, not truthiness: a weight of 0 is a real
            # setting and must still produce Current/Desired entries.
            if value is not None:
                temp_variant[f"Current{suffix}"] = value
                temp_variant[f"Desired{suffix}"] = value
        endpoint_variants.append(temp_variant)
    return endpoint_variants
@property
def response_object(self):
response_object = self.gen_response_object()
@ -1607,7 +1645,7 @@ class SageMakerModelBackend(BaseBackend):
try:
return self.endpoints[endpoint_name].response_object
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
message = "Could not find endpoint '{}'.".format(
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
)
raise ValidationError(message=message)
@ -1616,7 +1654,7 @@ class SageMakerModelBackend(BaseBackend):
try:
del self.endpoints[endpoint_name]
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
message = "Could not find endpoint '{}'.".format(
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
)
raise ValidationError(message=message)
@ -1890,6 +1928,54 @@ class SageMakerModelBackend(BaseBackend):
"NextToken": str(next_index) if next_index is not None else None,
}
def update_endpoint_weights_and_capacities(
    self, endpoint_name, desired_weights_and_capacities
):
    """Apply new weights and/or instance counts to an endpoint's variants.

    The whole request is validated before anything is modified, so a
    rejected request leaves the endpoint untouched. For each request entry,
    only the settings that are actually supplied are changed; an omitted
    ``DesiredWeight`` or ``DesiredInstanceCount`` keeps its stored value.
    Both the ``Desired*`` and ``Current*`` keys are set to the new value.

    :param endpoint_name: key of the endpoint in ``self.endpoints``.
    :param desired_weights_and_capacities: list of dicts with
        ``VariantName`` and optional ``DesiredWeight`` /
        ``DesiredInstanceCount``.
    :return: the endpoint's ARN.
    :raises AWSValidationException: if the endpoint does not exist, a
        variant name is repeated in the request, or a variant name is not
        one of the endpoint's variants.
    """
    # Validate inputs
    endpoint = self.endpoints.get(endpoint_name, None)
    if not endpoint:
        raise AWSValidationException(
            f'Could not find endpoint "{FakeEndpoint.arn_formatter(endpoint_name, self.region_name)}".'
        )

    # Index variants by name once; setdefault keeps the first variant for a
    # given name, matching a front-to-back search of the list.
    variants_by_name = {}
    for variant in endpoint.production_variants:
        variants_by_name.setdefault(variant["VariantName"], variant)

    names_checked = set()
    for variant_config in desired_weights_and_capacities:
        name = variant_config.get("VariantName")
        if name in names_checked:
            raise AWSValidationException(
                f'The variant name "{name}" was non-unique within the request.'
            )
        if name not in variants_by_name:
            raise AWSValidationException(
                f'The variant name(s) "{name}" is/are not present within endpoint configuration "{endpoint.endpoint_config_name}".'
            )
        names_checked.add(name)

    # Update endpoint variants. The change is applied synchronously, so the
    # status goes to "Updating" and straight back to "InService".
    endpoint.endpoint_status = "Updating"
    for variant_config in desired_weights_and_capacities:
        variant = variants_by_name[variant_config.get("VariantName")]

        desired_weight = variant_config.get("DesiredWeight")
        # `is not None` so that an explicit weight of 0 is still applied,
        # while an omitted field does not wipe the stored value.
        if desired_weight is not None:
            variant["DesiredWeight"] = desired_weight
            variant["CurrentWeight"] = desired_weight

        desired_instance_count = variant_config.get("DesiredInstanceCount")
        if desired_instance_count is not None:
            variant["DesiredInstanceCount"] = desired_instance_count
            variant["CurrentInstanceCount"] = desired_instance_count
    endpoint.endpoint_status = "InService"
    return endpoint.endpoint_arn
class FakeExperiment(BaseObject):
def __init__(self, region_name, experiment_name, tags):

View File

@ -1,4 +1,5 @@
import json
from moto.sagemaker.exceptions import AWSValidationException
from moto.core.responses import BaseResponse
@ -576,3 +577,13 @@ class SageMakerResponse(BaseResponse):
status_equals=status_equals,
)
return 200, {}, json.dumps(response)
def update_endpoint_weights_and_capacities(self):
    """Handle an UpdateEndpointWeightsAndCapacities request.

    Reads ``EndpointName`` and ``DesiredWeightsAndCapacities`` from the
    request parameters, forwards them to the backend, and answers with
    status 200, no extra headers, and a JSON body holding the endpoint ARN
    returned by the backend.
    """
    name = self._get_param("EndpointName")
    targets = self._get_param("DesiredWeightsAndCapacities")
    arn = self.sagemaker_backend.update_endpoint_weights_and_capacities(
        endpoint_name=name,
        desired_weights_and_capacities=targets,
    )
    body = json.dumps({"EndpointArn": arn})
    return 200, {}, body

View File

@ -18,12 +18,14 @@ GENERIC_TAGS_PARAM = [
TEST_MODEL_NAME = "MyModel"
TEST_ENDPOINT_NAME = "MyEndpoint"
TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig"
TEST_VARIANT_NAME = "MyProductionVariant"
TEST_INSTANCE_TYPE = "ml.t2.medium"
TEST_PRODUCTION_VARIANTS = [
{
"VariantName": "MyProductionVariant",
"VariantName": TEST_VARIANT_NAME,
"ModelName": TEST_MODEL_NAME,
"InitialInstanceCount": 1,
"InstanceType": "ml.t2.medium",
"InstanceType": TEST_INSTANCE_TYPE,
},
]
@ -162,7 +164,7 @@ def test_create_endpoint(sagemaker_client):
resp["EndpointStatus"].should.equal("InService")
assert isinstance(resp["CreationTime"], datetime.datetime)
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant")
resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME)
resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"])
assert resp["Tags"] == GENERIC_TAGS_PARAM
@ -245,11 +247,256 @@ def test_list_tags_endpoint(sagemaker_client):
assert response["Tags"] == tags[50:]
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
    """Updating the single default variant is visible through describe_endpoint."""
    _set_up_sagemaker_resources(
        sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
    )
    target_weight = 1.5
    target_instance_count = 123
    # Both the update and the describe call must report this endpoint's ARN.
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)

    update_result = sagemaker_client.update_endpoint_weights_and_capacities(
        EndpointName=TEST_ENDPOINT_NAME,
        DesiredWeightsAndCapacities=[
            {
                "VariantName": TEST_VARIANT_NAME,
                "DesiredWeight": target_weight,
                "DesiredInstanceCount": target_instance_count,
            },
        ],
    )
    update_result["EndpointArn"].should.match(arn_pattern)

    described = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
    described["EndpointArn"].should.match(arn_pattern)
    described["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
    described["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
    described["EndpointStatus"].should.equal("InService")
    assert isinstance(described["CreationTime"], datetime.datetime)
    assert isinstance(described["LastModifiedTime"], datetime.datetime)

    # Desired and current values are both expected to carry the new settings.
    variant = described["ProductionVariants"][0]
    variant["VariantName"].should.equal(TEST_VARIANT_NAME)
    variant["DesiredInstanceCount"].should.equal(target_instance_count)
    variant["CurrentInstanceCount"].should.equal(target_instance_count)
    variant["DesiredWeight"].should.equal(target_weight)
    variant["CurrentWeight"].should.equal(target_weight)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
    """Updating two variants in one request changes both of them."""
    variant_names = ["MyProductionVariant1", "MyProductionVariant2"]
    target_weight = 1.5
    target_instance_count = 123

    _set_up_sagemaker_resources(
        sagemaker_client,
        TEST_ENDPOINT_NAME,
        TEST_ENDPOINT_CONFIG_NAME,
        TEST_MODEL_NAME,
        [
            {
                "VariantName": variant_name,
                "ModelName": TEST_MODEL_NAME,
                "InitialInstanceCount": 1,
                "InstanceType": TEST_INSTANCE_TYPE,
            }
            for variant_name in variant_names
        ],
    )

    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)
    update_result = sagemaker_client.update_endpoint_weights_and_capacities(
        EndpointName=TEST_ENDPOINT_NAME,
        DesiredWeightsAndCapacities=[
            {
                "VariantName": variant_name,
                "DesiredWeight": target_weight,
                "DesiredInstanceCount": target_instance_count,
            }
            for variant_name in variant_names
        ],
    )
    update_result["EndpointArn"].should.match(arn_pattern)

    described = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
    described["EndpointArn"].should.match(arn_pattern)
    described["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
    described["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
    described["EndpointStatus"].should.equal("InService")
    assert isinstance(described["CreationTime"], datetime.datetime)
    assert isinstance(described["LastModifiedTime"], datetime.datetime)

    # Index explicitly (rather than zip) so a missing variant fails the test.
    for position, expected_name in enumerate(variant_names):
        variant = described["ProductionVariants"][position]
        variant["VariantName"].should.equal(expected_name)
        variant["DesiredInstanceCount"].should.equal(target_instance_count)
        variant["CurrentInstanceCount"].should.equal(target_instance_count)
        variant["DesiredWeight"].should.equal(target_weight)
        variant["CurrentWeight"].should.equal(target_weight)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant(
    sagemaker_client,
):
    """An unknown variant name is rejected and the endpoint stays unchanged."""
    _set_up_sagemaker_resources(
        sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
    )
    # Snapshot the endpoint up front; ResponseMetadata is dropped before the
    # comparison (presumably it varies from call to call).
    before = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
    before.pop("ResponseMetadata")

    variant_name = "SillyNotCorrectName"
    request_entry = {
        "VariantName": variant_name,
        "DesiredWeight": 1.5,
        "DesiredInstanceCount": 123,
    }
    with pytest.raises(ClientError) as exc:
        sagemaker_client.update_endpoint_weights_and_capacities(
            EndpointName=TEST_ENDPOINT_NAME,
            DesiredWeightsAndCapacities=[request_entry],
        )

    expected_message = f'The variant name(s) "{variant_name}" is/are not present within endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".'
    exc.value.response["Error"]["Message"].should.equal(expected_message)

    after = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
    after.pop("ResponseMetadata")
    after.should.equal(before)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint(
    sagemaker_client,
):
    """An unknown endpoint name is rejected and the real endpoint stays unchanged."""
    _set_up_sagemaker_resources(
        sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
    )
    # Snapshot the existing endpoint; ResponseMetadata is dropped before the
    # comparison (presumably it varies from call to call).
    before = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
    before.pop("ResponseMetadata")

    endpoint_name = "SillyEndpointName"
    request_entry = {
        "VariantName": "SillyNotCorrectName",
        "DesiredWeight": 1.5,
        "DesiredInstanceCount": 123,
    }
    with pytest.raises(ClientError) as exc:
        sagemaker_client.update_endpoint_weights_and_capacities(
            EndpointName=endpoint_name,
            DesiredWeightsAndCapacities=[request_entry],
        )

    expected_message = f'Could not find endpoint "arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:endpoint/{endpoint_name}".'
    exc.value.response["Error"]["Message"].should.equal(expected_message)

    after = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
    after.pop("ResponseMetadata")
    after.should.equal(before)
@mock_sagemaker
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant(
    sagemaker_client,
):
    """Naming the same variant twice is rejected and the endpoint stays unchanged."""
    _set_up_sagemaker_resources(
        sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
    )
    # Snapshot the endpoint up front; ResponseMetadata is dropped before the
    # comparison (presumably it varies from call to call).
    before = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
    before.pop("ResponseMetadata")

    # Two separate but identical entries for the same variant name.
    duplicated_entries = [
        {
            "VariantName": TEST_VARIANT_NAME,
            "DesiredWeight": 1.5,
            "DesiredInstanceCount": 123,
        }
        for _ in range(2)
    ]
    with pytest.raises(ClientError) as exc:
        sagemaker_client.update_endpoint_weights_and_capacities(
            EndpointName=TEST_ENDPOINT_NAME,
            DesiredWeightsAndCapacities=duplicated_entries,
        )

    expected_message = (
        f'The variant name "{TEST_VARIANT_NAME}" was non-unique within the request.'
    )
    exc.value.response["Error"]["Message"].should.equal(expected_message)

    after = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
    after.pop("ResponseMetadata")
    after.should.equal(before)
def _set_up_sagemaker_resources(
boto_client, endpoint_name, endpoint_config_name, model_name
boto_client,
endpoint_name,
endpoint_config_name,
model_name,
production_variants=None,
):
_create_model(boto_client, model_name)
_create_endpoint_config(boto_client, endpoint_config_name, model_name)
_create_endpoint_config(
boto_client, endpoint_config_name, model_name, production_variants
)
_create_endpoint(boto_client, endpoint_name, endpoint_config_name)
@ -265,15 +512,18 @@ def _create_model(boto_client, model_name):
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
production_variants = [
{
"VariantName": "MyProductionVariant",
"ModelName": model_name,
"InitialInstanceCount": 1,
"InstanceType": "ml.t2.medium",
},
]
def _create_endpoint_config(
boto_client, endpoint_config_name, model_name, production_variants=None
):
if not production_variants:
production_variants = [
{
"VariantName": TEST_VARIANT_NAME,
"ModelName": model_name,
"InitialInstanceCount": 1,
"InstanceType": TEST_INSTANCE_TYPE,
},
]
resp = boto_client.create_endpoint_config(
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
)