* Basic support for Endpoints, EndpointConfigs and TrainingJobs * Dropped extraneous pass statement. Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
		
			
				
	
	
		
			247 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			247 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # -*- coding: utf-8 -*-
 | |
| from __future__ import unicode_literals
 | |
| 
 | |
| import datetime
 | |
| import boto3
 | |
| from botocore.exceptions import ClientError, ParamValidationError
 | |
| import sure  # noqa
 | |
| 
 | |
| from moto import mock_sagemaker
 | |
| from moto.sts.models import ACCOUNT_ID
 | |
| from nose.tools import assert_true, assert_equal, assert_raises
 | |
| 
 | |
# Region used for every boto3 SageMaker client created in this module.
TEST_REGION_NAME = "us-east-1"

# Role ARN handed to create_model as ExecutionRoleArn.
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)

# Tags attached on endpoint creation and expected back from list_tags.
GENERIC_TAGS_PARAM = [
    {"Key": "newkey1", "Value": "newval1"},
    {"Key": "newkey2", "Value": "newval2"},
]
 | |
| 
 | |
| 
 | |
@mock_sagemaker
def test_create_endpoint_config():
    """Endpoint config creation fails without its model, then succeeds.

    Checks the "Could not find model" error raised while the referenced
    model is missing, the ARN returned by a successful create, and that
    describe echoes back the ARN, the name and the production variants.
    """
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    endpoint_config_name = "MyEndpointConfig"
    variant = {
        "VariantName": "MyProductionVariant",
        "ModelName": model_name,
        "InitialInstanceCount": 1,
        "InstanceType": "ml.t2.medium",
    }
    production_variants = [variant]
    create_kwargs = dict(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=production_variants,
    )
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
        endpoint_config_name
    )

    # The model has not been created yet, so the request must be rejected.
    with assert_raises(ClientError) as ctx:
        client.create_endpoint_config(**create_kwargs)
    error_message = ctx.exception.response["Error"]["Message"]
    assert_true(error_message.startswith("Could not find model"))

    # With the model in place the very same request succeeds.
    _create_model(client, model_name)
    created = client.create_endpoint_config(**create_kwargs)
    created["EndpointConfigArn"].should.match(arn_pattern)

    described = client.describe_endpoint_config(
        EndpointConfigName=endpoint_config_name
    )
    described["EndpointConfigArn"].should.match(arn_pattern)
    described["EndpointConfigName"].should.equal(endpoint_config_name)
    described["ProductionVariants"].should.equal(production_variants)
 | |
| 
 | |
| 
 | |
@mock_sagemaker
def test_delete_endpoint_config():
    """A deleted endpoint config can be neither described nor deleted again.

    Creates a model and an endpoint config, confirms the config is
    describable, deletes it, and then checks that both describe and a second
    delete fail with a "Could not find endpoint configuration" error.
    """
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    production_variants = [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": "ml.t2.medium",
        },
    ]

    resp = sagemaker.create_endpoint_config(
        EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
    )
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )

    resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )

    # The delete response is never inspected, so it is not bound to a name.
    sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    with assert_raises(ClientError) as e:
        sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert_true(
        e.exception.response["Error"]["Message"].startswith(
            "Could not find endpoint configuration"
        )
    )

    # Deleting a second time must fail the same way.
    with assert_raises(ClientError) as e:
        sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert_true(
        e.exception.response["Error"]["Message"].startswith(
            "Could not find endpoint configuration"
        )
    )
 | |
| 
 | |
| 
 | |
@mock_sagemaker
def test_create_endpoint_invalid_instance_type():
    """An unknown instance type is rejected with a ValidationException.

    The request is a create_endpoint_config call whose single production
    variant uses an instance type outside the accepted enum; the error code
    and the leading part of the error message are both checked.
    """
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(client, model_name)

    instance_type = "InvalidInstanceType"
    endpoint_config_name = "MyEndpointConfig"
    variant = {
        "VariantName": "MyProductionVariant",
        "ModelName": model_name,
        "InitialInstanceCount": 1,
        "InstanceType": instance_type,
    }

    with assert_raises(ClientError) as ctx:
        client.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=[variant],
        )

    error = ctx.exception.response["Error"]
    assert_equal(error["Code"], "ValidationException")

    # Only the prefix is compared; the message continues with the enum values.
    expected_message = (
        "Value '{}' at 'instanceType' failed to satisfy constraint: "
        "Member must satisfy enum value set: ["
    ).format(instance_type)
    assert_true(expected_message in error["Message"])
 | |
| 
 | |
| 
 | |
@mock_sagemaker
def test_create_endpoint():
    """Endpoint creation requires an existing config and records its tags.

    First checks that referencing a nonexistent endpoint config fails, then
    creates the model, config and endpoint and verifies the fields returned
    by describe_endpoint and the tags returned by list_tags.
    """
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
    endpoint_name = "MyEndpoint"

    # No endpoint config exists yet, so creation must fail.
    with assert_raises(ClientError) as ctx:
        client.create_endpoint(
            EndpointName=endpoint_name, EndpointConfigName="NonexistentEndpointConfig"
        )
    error_message = ctx.exception.response["Error"]["Message"]
    assert_true(error_message.startswith("Could not find endpoint configuration"))

    model_name = "MyModel"
    _create_model(client, model_name)

    endpoint_config_name = "MyEndpointConfig"
    _create_endpoint_config(client, endpoint_config_name, model_name)

    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)

    created = client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
        Tags=GENERIC_TAGS_PARAM,
    )
    created["EndpointArn"].should.match(arn_pattern)

    described = client.describe_endpoint(EndpointName=endpoint_name)
    described["EndpointArn"].should.match(arn_pattern)
    described["EndpointName"].should.equal(endpoint_name)
    described["EndpointConfigName"].should.equal(endpoint_config_name)
    described["EndpointStatus"].should.equal("InService")
    for timestamp_field in ("CreationTime", "LastModifiedTime"):
        assert_true(isinstance(described[timestamp_field], datetime.datetime))
    first_variant = described["ProductionVariants"][0]
    first_variant["VariantName"].should.equal("MyProductionVariant")

    # Tags supplied at creation time must be retrievable through the ARN.
    tag_listing = client.list_tags(ResourceArn=described["EndpointArn"])
    assert_equal(tag_listing["Tags"], GENERIC_TAGS_PARAM)
 | |
| 
 | |
| 
 | |
@mock_sagemaker
def test_delete_endpoint():
    """A deleted endpoint can be neither described nor deleted again."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    endpoint_config_name = "MyEndpointConfig"
    endpoint_name = "MyEndpoint"

    # Build the model -> endpoint config -> endpoint chain.
    _create_model(client, model_name)
    _create_endpoint_config(client, endpoint_config_name, model_name)
    _create_endpoint(client, endpoint_name, endpoint_config_name)

    client.delete_endpoint(EndpointName=endpoint_name)

    # After deletion, describe and then a second delete must both fail.
    for operation in (client.describe_endpoint, client.delete_endpoint):
        with assert_raises(ClientError) as ctx:
            operation(EndpointName=endpoint_name)
        error_message = ctx.exception.response["Error"]["Message"]
        assert_true(error_message.startswith("Could not find endpoint"))
 | |
| 
 | |
| 
 | |
def _create_model(boto_client, model_name):
    """Create a model named ``model_name`` and assert an HTTP 200 response.

    ``boto_client`` is the SageMaker client the request is issued on. The
    model uses a fixed container image, a fixed model data URL and
    ``FAKE_ROLE_ARN`` as its execution role.
    """
    primary_container = {
        "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
        "ModelDataUrl": "s3://MyBucket/model.tar.gz",
    }
    response = boto_client.create_model(
        ModelName=model_name,
        PrimaryContainer=primary_container,
        ExecutionRoleArn=FAKE_ROLE_ARN,
    )
    status_code = response["ResponseMetadata"]["HTTPStatusCode"]
    assert_equal(status_code, 200)
 | |
| 
 | |
| 
 | |
def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
    """Create an endpoint config with one variant and check the returned ARN.

    The single production variant references ``model_name`` and uses one
    ``ml.t2.medium`` instance.
    """
    variant = {
        "VariantName": "MyProductionVariant",
        "ModelName": model_name,
        "InitialInstanceCount": 1,
        "InstanceType": "ml.t2.medium",
    }
    response = boto_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=[variant],
    )
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
        endpoint_config_name
    )
    response["EndpointConfigArn"].should.match(arn_pattern)
 | |
| 
 | |
| 
 | |
def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
    """Create an endpoint from an existing config and check the returned ARN."""
    response = boto_client.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
    )
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
    response["EndpointArn"].should.match(arn_pattern)
 |