Sagemaker: add support for update_endpoint_weights_and_capacities (#5082)
This commit is contained in:
parent
406ca1d8f9
commit
eb49891118
@ -5142,7 +5142,7 @@
|
||||
- [ ] update_devices
|
||||
- [ ] update_domain
|
||||
- [ ] update_endpoint
|
||||
- [ ] update_endpoint_weights_and_capacities
|
||||
- [X] update_endpoint_weights_and_capacities
|
||||
- [ ] update_experiment
|
||||
- [ ] update_image
|
||||
- [ ] update_model_package
|
||||
|
@ -260,7 +260,7 @@ sagemaker
|
||||
- [ ] update_devices
|
||||
- [ ] update_domain
|
||||
- [ ] update_endpoint
|
||||
- [ ] update_endpoint_weights_and_capacities
|
||||
- [X] update_endpoint_weights_and_capacities
|
||||
- [ ] update_experiment
|
||||
- [ ] update_image
|
||||
- [ ] update_model_package
|
||||
|
@ -249,7 +249,9 @@ class FakeEndpoint(BaseObject, CloudFormationModel):
|
||||
self.endpoint_name = endpoint_name
|
||||
self.endpoint_arn = FakeEndpoint.arn_formatter(endpoint_name, region_name)
|
||||
self.endpoint_config_name = endpoint_config_name
|
||||
self.production_variants = production_variants
|
||||
self.production_variants = self._process_production_variants(
|
||||
production_variants
|
||||
)
|
||||
self.data_capture_config = data_capture_config
|
||||
self.tags = tags or []
|
||||
self.endpoint_status = "InService"
|
||||
@ -258,6 +260,42 @@ class FakeEndpoint(BaseObject, CloudFormationModel):
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
def _process_production_variants(self, production_variants):
|
||||
endpoint_variants = []
|
||||
for production_variant in production_variants:
|
||||
temp_variant = {}
|
||||
|
||||
# VariantName is the only required param
|
||||
temp_variant["VariantName"] = production_variant["VariantName"]
|
||||
|
||||
if production_variant.get("InitialInstanceCount", None):
|
||||
temp_variant["CurrentInstanceCount"] = production_variant[
|
||||
"InitialInstanceCount"
|
||||
]
|
||||
temp_variant["DesiredInstanceCount"] = production_variant[
|
||||
"InitialInstanceCount"
|
||||
]
|
||||
|
||||
if production_variant.get("InitialVariantWeight", None):
|
||||
temp_variant["CurrentWeight"] = production_variant[
|
||||
"InitialVariantWeight"
|
||||
]
|
||||
temp_variant["DesiredWeight"] = production_variant[
|
||||
"InitialVariantWeight"
|
||||
]
|
||||
|
||||
if production_variant.get("ServerlessConfig", None):
|
||||
temp_variant["CurrentServerlessConfig"] = production_variant[
|
||||
"ServerlessConfig"
|
||||
]
|
||||
temp_variant["DesiredServerlessConfig"] = production_variant[
|
||||
"ServerlessConfig"
|
||||
]
|
||||
|
||||
endpoint_variants.append(temp_variant)
|
||||
|
||||
return endpoint_variants
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
@ -1607,7 +1645,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
try:
|
||||
return self.endpoints[endpoint_name].response_object
|
||||
except KeyError:
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
message = "Could not find endpoint '{}'.".format(
|
||||
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
@ -1616,7 +1654,7 @@ class SageMakerModelBackend(BaseBackend):
|
||||
try:
|
||||
del self.endpoints[endpoint_name]
|
||||
except KeyError:
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
message = "Could not find endpoint '{}'.".format(
|
||||
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
|
||||
)
|
||||
raise ValidationError(message=message)
|
||||
@ -1890,6 +1928,54 @@ class SageMakerModelBackend(BaseBackend):
|
||||
"NextToken": str(next_index) if next_index is not None else None,
|
||||
}
|
||||
|
||||
def update_endpoint_weights_and_capacities(
|
||||
self, endpoint_name, desired_weights_and_capacities
|
||||
):
|
||||
# Validate inputs
|
||||
endpoint = self.endpoints.get(endpoint_name, None)
|
||||
if not endpoint:
|
||||
raise AWSValidationException(
|
||||
f'Could not find endpoint "{FakeEndpoint.arn_formatter(endpoint_name, self.region_name)}".'
|
||||
)
|
||||
|
||||
names_checked = []
|
||||
for variant_config in desired_weights_and_capacities:
|
||||
name = variant_config.get("VariantName")
|
||||
|
||||
if name in names_checked:
|
||||
raise AWSValidationException(
|
||||
f'The variant name "{name}" was non-unique within the request.'
|
||||
)
|
||||
|
||||
if not any(
|
||||
variant["VariantName"] == name
|
||||
for variant in endpoint.production_variants
|
||||
):
|
||||
raise AWSValidationException(
|
||||
f'The variant name(s) "{name}" is/are not present within endpoint configuration "{endpoint.endpoint_config_name}".'
|
||||
)
|
||||
|
||||
names_checked.append(name)
|
||||
|
||||
# Update endpoint variants
|
||||
endpoint.endpoint_status = "Updating"
|
||||
|
||||
for variant_config in desired_weights_and_capacities:
|
||||
name = variant_config.get("VariantName")
|
||||
desired_weight = variant_config.get("DesiredWeight")
|
||||
desired_instance_count = variant_config.get("DesiredInstanceCount")
|
||||
|
||||
for variant in endpoint.production_variants:
|
||||
if variant.get("VariantName") == name:
|
||||
variant["DesiredWeight"] = desired_weight
|
||||
variant["CurrentWeight"] = desired_weight
|
||||
variant["DesiredInstanceCount"] = desired_instance_count
|
||||
variant["CurrentInstanceCount"] = desired_instance_count
|
||||
break
|
||||
|
||||
endpoint.endpoint_status = "InService"
|
||||
return endpoint.endpoint_arn
|
||||
|
||||
|
||||
class FakeExperiment(BaseObject):
|
||||
def __init__(self, region_name, experiment_name, tags):
|
||||
|
@ -1,4 +1,5 @@
|
||||
import json
|
||||
|
||||
from moto.sagemaker.exceptions import AWSValidationException
|
||||
|
||||
from moto.core.responses import BaseResponse
|
||||
@ -576,3 +577,13 @@ class SageMakerResponse(BaseResponse):
|
||||
status_equals=status_equals,
|
||||
)
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
def update_endpoint_weights_and_capacities(self):
|
||||
endpoint_name = self._get_param("EndpointName")
|
||||
desired_weights_and_capacities = self._get_param("DesiredWeightsAndCapacities")
|
||||
endpoint_arn = self.sagemaker_backend.update_endpoint_weights_and_capacities(
|
||||
endpoint_name=endpoint_name,
|
||||
desired_weights_and_capacities=desired_weights_and_capacities,
|
||||
)
|
||||
response = {"EndpointArn": endpoint_arn}
|
||||
return 200, {}, json.dumps(response)
|
||||
|
@ -18,12 +18,14 @@ GENERIC_TAGS_PARAM = [
|
||||
TEST_MODEL_NAME = "MyModel"
|
||||
TEST_ENDPOINT_NAME = "MyEndpoint"
|
||||
TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig"
|
||||
TEST_VARIANT_NAME = "MyProductionVariant"
|
||||
TEST_INSTANCE_TYPE = "ml.t2.medium"
|
||||
TEST_PRODUCTION_VARIANTS = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant",
|
||||
"VariantName": TEST_VARIANT_NAME,
|
||||
"ModelName": TEST_MODEL_NAME,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": "ml.t2.medium",
|
||||
"InstanceType": TEST_INSTANCE_TYPE,
|
||||
},
|
||||
]
|
||||
|
||||
@ -162,7 +164,7 @@ def test_create_endpoint(sagemaker_client):
|
||||
resp["EndpointStatus"].should.equal("InService")
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant")
|
||||
resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME)
|
||||
|
||||
resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"])
|
||||
assert resp["Tags"] == GENERIC_TAGS_PARAM
|
||||
@ -245,11 +247,256 @@ def test_list_tags_endpoint(sagemaker_client):
|
||||
assert response["Tags"] == tags[50:]
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client):
|
||||
_set_up_sagemaker_resources(
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
new_desired_weight = 1.5
|
||||
new_desired_instance_count = 123
|
||||
|
||||
response = sagemaker_client.update_endpoint_weights_and_capacities(
|
||||
EndpointName=TEST_ENDPOINT_NAME,
|
||||
DesiredWeightsAndCapacities=[
|
||||
{
|
||||
"VariantName": TEST_VARIANT_NAME,
|
||||
"DesiredWeight": new_desired_weight,
|
||||
"DesiredInstanceCount": new_desired_instance_count,
|
||||
},
|
||||
],
|
||||
)
|
||||
response["EndpointArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
resp["EndpointArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)
|
||||
)
|
||||
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
|
||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||
resp["EndpointStatus"].should.equal("InService")
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
|
||||
resp["ProductionVariants"][0]["VariantName"].should.equal(TEST_VARIANT_NAME)
|
||||
resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight)
|
||||
resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client):
|
||||
production_variants = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant1",
|
||||
"ModelName": TEST_MODEL_NAME,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": TEST_INSTANCE_TYPE,
|
||||
},
|
||||
{
|
||||
"VariantName": "MyProductionVariant2",
|
||||
"ModelName": TEST_MODEL_NAME,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": TEST_INSTANCE_TYPE,
|
||||
},
|
||||
]
|
||||
|
||||
_set_up_sagemaker_resources(
|
||||
sagemaker_client,
|
||||
TEST_ENDPOINT_NAME,
|
||||
TEST_ENDPOINT_CONFIG_NAME,
|
||||
TEST_MODEL_NAME,
|
||||
production_variants,
|
||||
)
|
||||
|
||||
desired_weights_and_capacities = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant1",
|
||||
"DesiredWeight": 1.5,
|
||||
"DesiredInstanceCount": 123,
|
||||
},
|
||||
{
|
||||
"VariantName": "MyProductionVariant2",
|
||||
"DesiredWeight": 1.5,
|
||||
"DesiredInstanceCount": 123,
|
||||
},
|
||||
]
|
||||
|
||||
new_desired_weight = 1.5
|
||||
new_desired_instance_count = 123
|
||||
|
||||
response = sagemaker_client.update_endpoint_weights_and_capacities(
|
||||
EndpointName=TEST_ENDPOINT_NAME,
|
||||
DesiredWeightsAndCapacities=desired_weights_and_capacities,
|
||||
)
|
||||
response["EndpointArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
resp["EndpointArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(TEST_ENDPOINT_NAME)
|
||||
)
|
||||
resp["EndpointName"].should.equal(TEST_ENDPOINT_NAME)
|
||||
resp["EndpointConfigName"].should.equal(TEST_ENDPOINT_CONFIG_NAME)
|
||||
resp["EndpointStatus"].should.equal("InService")
|
||||
assert isinstance(resp["CreationTime"], datetime.datetime)
|
||||
assert isinstance(resp["LastModifiedTime"], datetime.datetime)
|
||||
|
||||
resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant1")
|
||||
resp["ProductionVariants"][0]["DesiredInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][0]["CurrentInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][0]["DesiredWeight"].should.equal(new_desired_weight)
|
||||
resp["ProductionVariants"][0]["CurrentWeight"].should.equal(new_desired_weight)
|
||||
|
||||
resp["ProductionVariants"][1]["VariantName"].should.equal("MyProductionVariant2")
|
||||
resp["ProductionVariants"][1]["DesiredInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][1]["CurrentInstanceCount"].should.equal(
|
||||
new_desired_instance_count
|
||||
)
|
||||
resp["ProductionVariants"][1]["DesiredWeight"].should.equal(new_desired_weight)
|
||||
resp["ProductionVariants"][1]["CurrentWeight"].should.equal(new_desired_weight)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant(
|
||||
sagemaker_client,
|
||||
):
|
||||
_set_up_sagemaker_resources(
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
old_resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
del old_resp["ResponseMetadata"]
|
||||
|
||||
variant_name = "SillyNotCorrectName"
|
||||
new_desired_weight = 1.5
|
||||
new_desired_instance_count = 123
|
||||
|
||||
with pytest.raises(ClientError) as exc:
|
||||
sagemaker_client.update_endpoint_weights_and_capacities(
|
||||
EndpointName=TEST_ENDPOINT_NAME,
|
||||
DesiredWeightsAndCapacities=[
|
||||
{
|
||||
"VariantName": variant_name,
|
||||
"DesiredWeight": new_desired_weight,
|
||||
"DesiredInstanceCount": new_desired_instance_count,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
err = exc.value.response["Error"]
|
||||
err["Message"].should.equal(
|
||||
f'The variant name(s) "{variant_name}" is/are not present within endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".'
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
del resp["ResponseMetadata"]
|
||||
resp.should.equal(old_resp)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint(
|
||||
sagemaker_client,
|
||||
):
|
||||
_set_up_sagemaker_resources(
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
old_resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
del old_resp["ResponseMetadata"]
|
||||
|
||||
endpoint_name = "SillyEndpointName"
|
||||
variant_name = "SillyNotCorrectName"
|
||||
new_desired_weight = 1.5
|
||||
new_desired_instance_count = 123
|
||||
|
||||
with pytest.raises(ClientError) as exc:
|
||||
sagemaker_client.update_endpoint_weights_and_capacities(
|
||||
EndpointName=endpoint_name,
|
||||
DesiredWeightsAndCapacities=[
|
||||
{
|
||||
"VariantName": variant_name,
|
||||
"DesiredWeight": new_desired_weight,
|
||||
"DesiredInstanceCount": new_desired_instance_count,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
err = exc.value.response["Error"]
|
||||
err["Message"].should.equal(
|
||||
f'Could not find endpoint "arn:aws:sagemaker:us-east-1:{ACCOUNT_ID}:endpoint/{endpoint_name}".'
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
del resp["ResponseMetadata"]
|
||||
resp.should.equal(old_resp)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant(
|
||||
sagemaker_client,
|
||||
):
|
||||
_set_up_sagemaker_resources(
|
||||
sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME
|
||||
)
|
||||
|
||||
old_resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
del old_resp["ResponseMetadata"]
|
||||
|
||||
desired_weights_and_capacities = [
|
||||
{
|
||||
"VariantName": TEST_VARIANT_NAME,
|
||||
"DesiredWeight": 1.5,
|
||||
"DesiredInstanceCount": 123,
|
||||
},
|
||||
{
|
||||
"VariantName": TEST_VARIANT_NAME,
|
||||
"DesiredWeight": 1.5,
|
||||
"DesiredInstanceCount": 123,
|
||||
},
|
||||
]
|
||||
|
||||
with pytest.raises(ClientError) as exc:
|
||||
sagemaker_client.update_endpoint_weights_and_capacities(
|
||||
EndpointName=TEST_ENDPOINT_NAME,
|
||||
DesiredWeightsAndCapacities=desired_weights_and_capacities,
|
||||
)
|
||||
|
||||
err = exc.value.response["Error"]
|
||||
err["Message"].should.equal(
|
||||
f'The variant name "{TEST_VARIANT_NAME}" was non-unique within the request.'
|
||||
)
|
||||
|
||||
resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME)
|
||||
del resp["ResponseMetadata"]
|
||||
resp.should.equal(old_resp)
|
||||
|
||||
|
||||
def _set_up_sagemaker_resources(
|
||||
boto_client, endpoint_name, endpoint_config_name, model_name
|
||||
boto_client,
|
||||
endpoint_name,
|
||||
endpoint_config_name,
|
||||
model_name,
|
||||
production_variants=None,
|
||||
):
|
||||
_create_model(boto_client, model_name)
|
||||
_create_endpoint_config(boto_client, endpoint_config_name, model_name)
|
||||
_create_endpoint_config(
|
||||
boto_client, endpoint_config_name, model_name, production_variants
|
||||
)
|
||||
_create_endpoint(boto_client, endpoint_name, endpoint_config_name)
|
||||
|
||||
|
||||
@ -265,15 +512,18 @@ def _create_model(boto_client, model_name):
|
||||
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200
|
||||
|
||||
|
||||
def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
|
||||
production_variants = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant",
|
||||
"ModelName": model_name,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": "ml.t2.medium",
|
||||
},
|
||||
]
|
||||
def _create_endpoint_config(
|
||||
boto_client, endpoint_config_name, model_name, production_variants=None
|
||||
):
|
||||
if not production_variants:
|
||||
production_variants = [
|
||||
{
|
||||
"VariantName": TEST_VARIANT_NAME,
|
||||
"ModelName": model_name,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": TEST_INSTANCE_TYPE,
|
||||
},
|
||||
]
|
||||
resp = boto_client.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user