| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | import datetime | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  | import re | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | import uuid | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | import boto3 | 
					
						
							| 
									
										
										
										
											2023-11-30 07:55:51 -08:00
										 |  |  | import pytest | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  | from botocore.exceptions import ClientError | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | from moto import mock_sagemaker | 
					
						
							| 
									
										
										
										
											2022-08-13 09:49:43 +00:00
										 |  |  | from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | TEST_REGION_NAME = "us-east-1" | 
					
						
							| 
									
										
										
										
											2022-11-17 21:41:08 -01:00
										 |  |  | TEST_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole" | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | GENERIC_TAGS_PARAM = [ | 
					
						
							|  |  |  |     {"Key": "newkey1", "Value": "newval1"}, | 
					
						
							|  |  |  |     {"Key": "newkey2", "Value": "newval2"}, | 
					
						
							|  |  |  | ] | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | TEST_MODEL_NAME = "MyModel" | 
					
						
							|  |  |  | TEST_ENDPOINT_NAME = "MyEndpoint" | 
					
						
							|  |  |  | TEST_ENDPOINT_CONFIG_NAME = "MyEndpointConfig" | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  | TEST_VARIANT_NAME = "MyProductionVariant" | 
					
						
							|  |  |  | TEST_INSTANCE_TYPE = "ml.t2.medium" | 
					
						
							| 
									
										
										
										
											2022-09-02 14:42:50 -04:00
										 |  |  | TEST_MEMORY_SIZE = 1024 | 
					
						
							|  |  |  | TEST_CONCURRENCY = 10 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | TEST_PRODUCTION_VARIANTS = [ | 
					
						
							|  |  |  |     { | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |         "VariantName": TEST_VARIANT_NAME, | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         "ModelName": TEST_MODEL_NAME, | 
					
						
							|  |  |  |         "InitialInstanceCount": 1, | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |         "InstanceType": TEST_INSTANCE_TYPE, | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     }, | 
					
						
							|  |  |  | ] | 
					
						
							| 
									
										
										
										
											2022-09-02 14:42:50 -04:00
										 |  |  | TEST_SERVERLESS_PRODUCTION_VARIANTS = [ | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         "VariantName": TEST_VARIANT_NAME, | 
					
						
							|  |  |  |         "ModelName": TEST_MODEL_NAME, | 
					
						
							|  |  |  |         "ServerlessConfig": { | 
					
						
							|  |  |  |             "MemorySizeInMB": TEST_MEMORY_SIZE, | 
					
						
							|  |  |  |             "MaxConcurrency": TEST_CONCURRENCY, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     }, | 
					
						
							|  |  |  | ] | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-04 16:28:30 +00:00
										 |  |  | @pytest.fixture(name="sagemaker_client") | 
					
						
							|  |  |  | def fixture_sagemaker_client(): | 
					
						
							|  |  |  |     with mock_sagemaker(): | 
					
						
							|  |  |  |         yield boto3.client("sagemaker", region_name=TEST_REGION_NAME) | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-02 14:42:50 -04:00
										 |  |  | def create_endpoint_config_helper(sagemaker_client, production_variants): | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     _create_model(sagemaker_client, TEST_MODEL_NAME) | 
					
						
							| 
									
										
										
										
											2022-09-02 14:42:50 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.create_endpoint_config( | 
					
						
							|  |  |  |         EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, | 
					
						
							| 
									
										
										
										
											2022-09-02 14:42:50 -04:00
										 |  |  |         ProductionVariants=production_variants, | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$", | 
					
						
							|  |  |  |         resp["EndpointConfigArn"], | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.describe_endpoint_config( | 
					
						
							|  |  |  |         EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$", | 
					
						
							|  |  |  |         resp["EndpointConfigArn"], | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME | 
					
						
							|  |  |  |     assert resp["ProductionVariants"] == production_variants | 
					
						
							| 
									
										
										
										
											2022-09-02 14:42:50 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_create_endpoint_config(sagemaker_client): | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as e: | 
					
						
							|  |  |  |         sagemaker_client.create_endpoint_config( | 
					
						
							|  |  |  |             EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, | 
					
						
							|  |  |  |             ProductionVariants=TEST_PRODUCTION_VARIANTS, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     assert e.value.response["Error"]["Message"].startswith("Could not find model") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Testing instance-based endpoint configuration | 
					
						
							|  |  |  |     create_endpoint_config_helper(sagemaker_client, TEST_PRODUCTION_VARIANTS) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_create_endpoint_config_serverless(sagemaker_client): | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as e: | 
					
						
							|  |  |  |         sagemaker_client.create_endpoint_config( | 
					
						
							|  |  |  |             EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, | 
					
						
							|  |  |  |             ProductionVariants=TEST_SERVERLESS_PRODUCTION_VARIANTS, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     assert e.value.response["Error"]["Message"].startswith("Could not find model") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Testing serverless endpoint configuration | 
					
						
							|  |  |  |     create_endpoint_config_helper(sagemaker_client, TEST_SERVERLESS_PRODUCTION_VARIANTS) | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_delete_endpoint_config(sagemaker_client): | 
					
						
							|  |  |  |     _create_model(sagemaker_client, TEST_MODEL_NAME) | 
					
						
							|  |  |  |     resp = sagemaker_client.create_endpoint_config( | 
					
						
							|  |  |  |         EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, | 
					
						
							|  |  |  |         ProductionVariants=TEST_PRODUCTION_VARIANTS, | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$", | 
					
						
							|  |  |  |         resp["EndpointConfigArn"], | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.describe_endpoint_config( | 
					
						
							|  |  |  |         EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{TEST_ENDPOINT_CONFIG_NAME}$", | 
					
						
							|  |  |  |         resp["EndpointConfigArn"], | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     sagemaker_client.delete_endpoint_config( | 
					
						
							|  |  |  |         EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     with pytest.raises(ClientError) as e: | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         sagemaker_client.describe_endpoint_config( | 
					
						
							|  |  |  |             EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-10-06 08:46:05 +02:00
										 |  |  |     assert e.value.response["Error"]["Message"].startswith( | 
					
						
							|  |  |  |         "Could not find endpoint configuration" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     with pytest.raises(ClientError) as e: | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         sagemaker_client.delete_endpoint_config( | 
					
						
							|  |  |  |             EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-10-06 08:46:05 +02:00
										 |  |  |     assert e.value.response["Error"]["Message"].startswith( | 
					
						
							|  |  |  |         "Could not find endpoint configuration" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_create_endpoint_invalid_instance_type(sagemaker_client): | 
					
						
							|  |  |  |     _create_model(sagemaker_client, TEST_MODEL_NAME) | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     instance_type = "InvalidInstanceType" | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     production_variants = TEST_PRODUCTION_VARIANTS | 
					
						
							|  |  |  |     production_variants[0]["InstanceType"] = instance_type | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     with pytest.raises(ClientError) as e: | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         sagemaker_client.create_endpoint_config( | 
					
						
							|  |  |  |             EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |             ProductionVariants=production_variants, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-10-06 08:04:09 +02:00
										 |  |  |     assert e.value.response["Error"]["Code"] == "ValidationException" | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     expected_message = ( | 
					
						
							|  |  |  |         f"Value '{instance_type}' at 'instanceType' failed to satisfy " | 
					
						
							|  |  |  |         "constraint: Member must satisfy enum value set: [" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-10-06 08:04:09 +02:00
										 |  |  |     assert expected_message in e.value.response["Error"]["Message"] | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-02 14:42:50 -04:00
										 |  |  | def test_create_endpoint_invalid_memory_size(sagemaker_client): | 
					
						
							|  |  |  |     _create_model(sagemaker_client, TEST_MODEL_NAME) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     memory_size = 1111 | 
					
						
							|  |  |  |     production_variants = TEST_SERVERLESS_PRODUCTION_VARIANTS | 
					
						
							|  |  |  |     production_variants[0]["ServerlessConfig"]["MemorySizeInMB"] = memory_size | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as e: | 
					
						
							|  |  |  |         sagemaker_client.create_endpoint_config( | 
					
						
							|  |  |  |             EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, | 
					
						
							|  |  |  |             ProductionVariants=production_variants, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |     assert e.value.response["Error"]["Code"] == "ValidationException" | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     expected_message = ( | 
					
						
							|  |  |  |         f"Value '{memory_size}' at 'MemorySizeInMB' failed to satisfy " | 
					
						
							|  |  |  |         "constraint: Member must satisfy enum value set: [" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-09-02 14:42:50 -04:00
										 |  |  |     assert expected_message in e.value.response["Error"]["Message"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_create_endpoint(sagemaker_client): | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     with pytest.raises(ClientError) as e: | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         sagemaker_client.create_endpoint( | 
					
						
							|  |  |  |             EndpointName=TEST_ENDPOINT_NAME, | 
					
						
							|  |  |  |             EndpointConfigName="NonexistentEndpointConfig", | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2020-10-06 08:46:05 +02:00
										 |  |  |     assert e.value.response["Error"]["Message"].startswith( | 
					
						
							|  |  |  |         "Could not find endpoint configuration" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     _create_model(sagemaker_client, TEST_MODEL_NAME) | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     _create_endpoint_config( | 
					
						
							|  |  |  |         sagemaker_client, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.create_endpoint( | 
					
						
							|  |  |  |         EndpointName=TEST_ENDPOINT_NAME, | 
					
						
							|  |  |  |         EndpointConfigName=TEST_ENDPOINT_CONFIG_NAME, | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |         Tags=GENERIC_TAGS_PARAM, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"] | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"] | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["EndpointName"] == TEST_ENDPOINT_NAME | 
					
						
							|  |  |  |     assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME | 
					
						
							|  |  |  |     assert resp["EndpointStatus"] == "InService" | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     assert isinstance(resp["CreationTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert isinstance(resp["LastModifiedTime"], datetime.datetime) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["ProductionVariants"][0]["VariantName"] == TEST_VARIANT_NAME | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     resp = sagemaker_client.list_tags(ResourceArn=resp["EndpointArn"]) | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     assert resp["Tags"] == GENERIC_TAGS_PARAM | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_delete_endpoint(sagemaker_client): | 
					
						
							|  |  |  |     _set_up_sagemaker_resources( | 
					
						
							|  |  |  |         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     sagemaker_client.delete_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     with pytest.raises(ClientError) as e: | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							| 
									
										
										
										
											2020-10-06 08:04:09 +02:00
										 |  |  |     assert e.value.response["Error"]["Message"].startswith("Could not find endpoint") | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     with pytest.raises(ClientError) as e: | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         sagemaker_client.delete_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							| 
									
										
										
										
											2020-10-06 08:04:09 +02:00
										 |  |  |     assert e.value.response["Error"]["Message"].startswith("Could not find endpoint") | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def test_add_tags_endpoint(sagemaker_client): | 
					
						
							|  |  |  |     _set_up_sagemaker_resources( | 
					
						
							|  |  |  |         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     resource_arn = ( | 
					
						
							|  |  |  |         f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" | 
					
						
							|  |  |  |         f":endpoint/{TEST_ENDPOINT_NAME}" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     response = sagemaker_client.add_tags( | 
					
						
							|  |  |  |         ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = sagemaker_client.list_tags(ResourceArn=resource_arn) | 
					
						
							|  |  |  |     assert response["Tags"] == GENERIC_TAGS_PARAM | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_delete_tags_endpoint(sagemaker_client): | 
					
						
							|  |  |  |     _set_up_sagemaker_resources( | 
					
						
							|  |  |  |         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     resource_arn = ( | 
					
						
							|  |  |  |         f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" | 
					
						
							|  |  |  |         f":endpoint/{TEST_ENDPOINT_NAME}" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     response = sagemaker_client.add_tags( | 
					
						
							|  |  |  |         ResourceArn=resource_arn, Tags=GENERIC_TAGS_PARAM | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tag_keys = [tag["Key"] for tag in GENERIC_TAGS_PARAM] | 
					
						
							|  |  |  |     response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = sagemaker_client.list_tags(ResourceArn=resource_arn) | 
					
						
							|  |  |  |     assert response["Tags"] == [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_list_tags_endpoint(sagemaker_client): | 
					
						
							|  |  |  |     _set_up_sagemaker_resources( | 
					
						
							|  |  |  |         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tags = [] | 
					
						
							|  |  |  |     for _ in range(80): | 
					
						
							|  |  |  |         tags.append({"Key": str(uuid.uuid4()), "Value": "myValue"}) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     resource_arn = ( | 
					
						
							|  |  |  |         f"arn:aws:sagemaker:{TEST_REGION_NAME}:{ACCOUNT_ID}" | 
					
						
							|  |  |  |         f":endpoint/{TEST_ENDPOINT_NAME}" | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags) | 
					
						
							|  |  |  |     assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = sagemaker_client.list_tags(ResourceArn=resource_arn) | 
					
						
							|  |  |  |     assert len(response["Tags"]) == 50 | 
					
						
							|  |  |  |     assert response["Tags"] == tags[:50] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = sagemaker_client.list_tags( | 
					
						
							|  |  |  |         ResourceArn=resource_arn, NextToken=response["NextToken"] | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     assert len(response["Tags"]) == 30 | 
					
						
							|  |  |  |     assert response["Tags"] == tags[50:] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  | def test_update_endpoint_weights_and_capacities_one_variant(sagemaker_client): | 
					
						
							|  |  |  |     _set_up_sagemaker_resources( | 
					
						
							|  |  |  |         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     new_desired_weight = 1.5 | 
					
						
							|  |  |  |     new_desired_instance_count = 123 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = sagemaker_client.update_endpoint_weights_and_capacities( | 
					
						
							|  |  |  |         EndpointName=TEST_ENDPOINT_NAME, | 
					
						
							|  |  |  |         DesiredWeightsAndCapacities=[ | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "VariantName": TEST_VARIANT_NAME, | 
					
						
							|  |  |  |                 "DesiredWeight": new_desired_weight, | 
					
						
							|  |  |  |                 "DesiredInstanceCount": new_desired_instance_count, | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |         ], | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", | 
					
						
							|  |  |  |         response["EndpointArn"], | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"] | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["EndpointName"] == TEST_ENDPOINT_NAME | 
					
						
							|  |  |  |     assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME | 
					
						
							|  |  |  |     assert resp["EndpointStatus"] == "InService" | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     assert isinstance(resp["CreationTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert isinstance(resp["LastModifiedTime"], datetime.datetime) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["ProductionVariants"][0]["VariantName"] == TEST_VARIANT_NAME | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         resp["ProductionVariants"][0]["DesiredInstanceCount"] | 
					
						
							|  |  |  |         == new_desired_instance_count | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert ( | 
					
						
							|  |  |  |         resp["ProductionVariants"][0]["CurrentInstanceCount"] | 
					
						
							|  |  |  |         == new_desired_instance_count | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["ProductionVariants"][0]["DesiredWeight"] == new_desired_weight | 
					
						
							|  |  |  |     assert resp["ProductionVariants"][0]["CurrentWeight"] == new_desired_weight | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_update_endpoint_weights_and_capacities_two_variants(sagemaker_client): | 
					
						
							|  |  |  |     production_variants = [ | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "VariantName": "MyProductionVariant1", | 
					
						
							|  |  |  |             "ModelName": TEST_MODEL_NAME, | 
					
						
							|  |  |  |             "InitialInstanceCount": 1, | 
					
						
							|  |  |  |             "InstanceType": TEST_INSTANCE_TYPE, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "VariantName": "MyProductionVariant2", | 
					
						
							|  |  |  |             "ModelName": TEST_MODEL_NAME, | 
					
						
							|  |  |  |             "InitialInstanceCount": 1, | 
					
						
							|  |  |  |             "InstanceType": TEST_INSTANCE_TYPE, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     _set_up_sagemaker_resources( | 
					
						
							|  |  |  |         sagemaker_client, | 
					
						
							|  |  |  |         TEST_ENDPOINT_NAME, | 
					
						
							|  |  |  |         TEST_ENDPOINT_CONFIG_NAME, | 
					
						
							|  |  |  |         TEST_MODEL_NAME, | 
					
						
							|  |  |  |         production_variants, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     desired_weights_and_capacities = [ | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "VariantName": "MyProductionVariant1", | 
					
						
							|  |  |  |             "DesiredWeight": 1.5, | 
					
						
							|  |  |  |             "DesiredInstanceCount": 123, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "VariantName": "MyProductionVariant2", | 
					
						
							|  |  |  |             "DesiredWeight": 1.5, | 
					
						
							|  |  |  |             "DesiredInstanceCount": 123, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     new_desired_weight = 1.5 | 
					
						
							|  |  |  |     new_desired_instance_count = 123 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     response = sagemaker_client.update_endpoint_weights_and_capacities( | 
					
						
							|  |  |  |         EndpointName=TEST_ENDPOINT_NAME, | 
					
						
							|  |  |  |         DesiredWeightsAndCapacities=desired_weights_and_capacities, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", | 
					
						
							|  |  |  |         response["EndpointArn"], | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint/{TEST_ENDPOINT_NAME}$", resp["EndpointArn"] | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["EndpointName"] == TEST_ENDPOINT_NAME | 
					
						
							|  |  |  |     assert resp["EndpointConfigName"] == TEST_ENDPOINT_CONFIG_NAME | 
					
						
							|  |  |  |     assert resp["EndpointStatus"] == "InService" | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     assert isinstance(resp["CreationTime"], datetime.datetime) | 
					
						
							|  |  |  |     assert isinstance(resp["LastModifiedTime"], datetime.datetime) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["ProductionVariants"][0]["VariantName"] == "MyProductionVariant1" | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         resp["ProductionVariants"][0]["DesiredInstanceCount"] | 
					
						
							|  |  |  |         == new_desired_instance_count | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert ( | 
					
						
							|  |  |  |         resp["ProductionVariants"][0]["CurrentInstanceCount"] | 
					
						
							|  |  |  |         == new_desired_instance_count | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["ProductionVariants"][0]["DesiredWeight"] == new_desired_weight | 
					
						
							|  |  |  |     assert resp["ProductionVariants"][0]["CurrentWeight"] == new_desired_weight | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["ProductionVariants"][1]["VariantName"] == "MyProductionVariant2" | 
					
						
							|  |  |  |     assert ( | 
					
						
							|  |  |  |         resp["ProductionVariants"][1]["DesiredInstanceCount"] | 
					
						
							|  |  |  |         == new_desired_instance_count | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert ( | 
					
						
							|  |  |  |         resp["ProductionVariants"][1]["CurrentInstanceCount"] | 
					
						
							|  |  |  |         == new_desired_instance_count | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp["ProductionVariants"][1]["DesiredWeight"] == new_desired_weight | 
					
						
							|  |  |  |     assert resp["ProductionVariants"][1]["CurrentWeight"] == new_desired_weight | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_variant( | 
					
						
							|  |  |  |     sagemaker_client, | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     _set_up_sagemaker_resources( | 
					
						
							|  |  |  |         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     old_resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							|  |  |  |     del old_resp["ResponseMetadata"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     variant_name = "SillyNotCorrectName" | 
					
						
							|  |  |  |     new_desired_weight = 1.5 | 
					
						
							|  |  |  |     new_desired_instance_count = 123 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as exc: | 
					
						
							|  |  |  |         sagemaker_client.update_endpoint_weights_and_capacities( | 
					
						
							|  |  |  |             EndpointName=TEST_ENDPOINT_NAME, | 
					
						
							|  |  |  |             DesiredWeightsAndCapacities=[ | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     "VariantName": variant_name, | 
					
						
							|  |  |  |                     "DesiredWeight": new_desired_weight, | 
					
						
							|  |  |  |                     "DesiredInstanceCount": new_desired_instance_count, | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     err = exc.value.response["Error"] | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert err["Message"] == ( | 
					
						
							|  |  |  |         f'The variant name(s) "{variant_name}" is/are not present within ' | 
					
						
							|  |  |  |         f'endpoint configuration "{TEST_ENDPOINT_CONFIG_NAME}".' | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							|  |  |  |     del resp["ResponseMetadata"] | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp == old_resp | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_update_endpoint_weights_and_capacities_should_throw_clienterror_no_endpoint( | 
					
						
							|  |  |  |     sagemaker_client, | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     _set_up_sagemaker_resources( | 
					
						
							|  |  |  |         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     old_resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							|  |  |  |     del old_resp["ResponseMetadata"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     endpoint_name = "SillyEndpointName" | 
					
						
							|  |  |  |     variant_name = "SillyNotCorrectName" | 
					
						
							|  |  |  |     new_desired_weight = 1.5 | 
					
						
							|  |  |  |     new_desired_instance_count = 123 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as exc: | 
					
						
							|  |  |  |         sagemaker_client.update_endpoint_weights_and_capacities( | 
					
						
							|  |  |  |             EndpointName=endpoint_name, | 
					
						
							|  |  |  |             DesiredWeightsAndCapacities=[ | 
					
						
							|  |  |  |                 { | 
					
						
							|  |  |  |                     "VariantName": variant_name, | 
					
						
							|  |  |  |                     "DesiredWeight": new_desired_weight, | 
					
						
							|  |  |  |                     "DesiredInstanceCount": new_desired_instance_count, | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     err = exc.value.response["Error"] | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert err["Message"] == ( | 
					
						
							|  |  |  |         f'Could not find endpoint "arn:aws:sagemaker:us-east-1:' | 
					
						
							|  |  |  |         f'{ACCOUNT_ID}:endpoint/{endpoint_name}".' | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							|  |  |  |     del resp["ResponseMetadata"] | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp == old_resp | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_update_endpoint_weights_and_capacities_should_throw_clienterror_nonunique_variant( | 
					
						
							|  |  |  |     sagemaker_client, | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     _set_up_sagemaker_resources( | 
					
						
							|  |  |  |         sagemaker_client, TEST_ENDPOINT_NAME, TEST_ENDPOINT_CONFIG_NAME, TEST_MODEL_NAME | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     old_resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							|  |  |  |     del old_resp["ResponseMetadata"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     desired_weights_and_capacities = [ | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "VariantName": TEST_VARIANT_NAME, | 
					
						
							|  |  |  |             "DesiredWeight": 1.5, | 
					
						
							|  |  |  |             "DesiredInstanceCount": 123, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |         { | 
					
						
							|  |  |  |             "VariantName": TEST_VARIANT_NAME, | 
					
						
							|  |  |  |             "DesiredWeight": 1.5, | 
					
						
							|  |  |  |             "DesiredInstanceCount": 123, | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with pytest.raises(ClientError) as exc: | 
					
						
							|  |  |  |         sagemaker_client.update_endpoint_weights_and_capacities( | 
					
						
							|  |  |  |             EndpointName=TEST_ENDPOINT_NAME, | 
					
						
							|  |  |  |             DesiredWeightsAndCapacities=desired_weights_and_capacities, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     err = exc.value.response["Error"] | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert err["Message"] == ( | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |         f'The variant name "{TEST_VARIANT_NAME}" was non-unique within the request.' | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     resp = sagemaker_client.describe_endpoint(EndpointName=TEST_ENDPOINT_NAME) | 
					
						
							|  |  |  |     del resp["ResponseMetadata"] | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert resp == old_resp | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | def _set_up_sagemaker_resources( | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     boto_client, | 
					
						
							|  |  |  |     endpoint_name, | 
					
						
							|  |  |  |     endpoint_config_name, | 
					
						
							|  |  |  |     model_name, | 
					
						
							|  |  |  |     production_variants=None, | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  | ): | 
					
						
							|  |  |  |     _create_model(boto_client, model_name) | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  |     _create_endpoint_config( | 
					
						
							|  |  |  |         boto_client, endpoint_config_name, model_name, production_variants | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |     _create_endpoint(boto_client, endpoint_name, endpoint_config_name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | def _create_model(boto_client, model_name): | 
					
						
							|  |  |  |     resp = boto_client.create_model( | 
					
						
							|  |  |  |         ModelName=model_name, | 
					
						
							|  |  |  |         PrimaryContainer={ | 
					
						
							|  |  |  |             "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1", | 
					
						
							|  |  |  |             "ModelDataUrl": "s3://MyBucket/model.tar.gz", | 
					
						
							|  |  |  |         }, | 
					
						
							| 
									
										
										
										
											2022-04-27 12:56:08 +01:00
										 |  |  |         ExecutionRoleArn=TEST_ROLE_ARN, | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-10-06 07:54:49 +02:00
										 |  |  |     assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200 | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-04 10:36:46 +01:00
										 |  |  | def _create_endpoint_config( | 
					
						
							|  |  |  |     boto_client, endpoint_config_name, model_name, production_variants=None | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     if not production_variants: | 
					
						
							|  |  |  |         production_variants = [ | 
					
						
							|  |  |  |             { | 
					
						
							|  |  |  |                 "VariantName": TEST_VARIANT_NAME, | 
					
						
							|  |  |  |                 "ModelName": model_name, | 
					
						
							|  |  |  |                 "InitialInstanceCount": 1, | 
					
						
							|  |  |  |                 "InstanceType": TEST_INSTANCE_TYPE, | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |         ] | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     resp = boto_client.create_endpoint_config( | 
					
						
							|  |  |  |         EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint-config/{endpoint_config_name}$", | 
					
						
							|  |  |  |         resp["EndpointConfigArn"], | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _create_endpoint(boto_client, endpoint_name, endpoint_config_name): | 
					
						
							|  |  |  |     resp = boto_client.create_endpoint( | 
					
						
							|  |  |  |         EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2023-08-08 06:06:51 -04:00
										 |  |  |     assert re.match( | 
					
						
							|  |  |  |         rf"^arn:aws:sagemaker:.*:.*:endpoint/{endpoint_name}$", resp["EndpointArn"] | 
					
						
							| 
									
										
										
										
											2020-07-19 10:06:48 -04:00
										 |  |  |     ) |