import datetime
import re

import boto3
from botocore.exceptions import ClientError
import sure  # noqa # pylint: disable=unused-import
from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
import pytest

TEST_REGION_NAME = "us-east-1"
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
GENERIC_TAGS_PARAM = [
    {"Key": "newkey1", "Value": "newval1"},
    {"Key": "newkey2", "Value": "newval2"},
]


@mock_sagemaker
def test_create_endpoint_config():
    """Endpoint config creation needs an existing model and is describable."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    production_variants = _production_variants(model_name)
    endpoint_config_name = "MyEndpointConfig"

    # The referenced model has not been created yet, so the call must fail.
    with pytest.raises(ClientError) as e:
        sagemaker.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=production_variants,
        )
    assert e.value.response["Error"]["Message"].startswith("Could not find model")

    _create_model(sagemaker, model_name)
    resp = sagemaker.create_endpoint_config(
        EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
    )
    resp["EndpointConfigArn"].should.match(
        _arn_pattern("endpoint-config", endpoint_config_name)
    )

    resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    resp["EndpointConfigArn"].should.match(
        _arn_pattern("endpoint-config", endpoint_config_name)
    )
    resp["EndpointConfigName"].should.equal(endpoint_config_name)
    resp["ProductionVariants"].should.equal(production_variants)


@mock_sagemaker
def test_delete_endpoint_config():
    """A deleted endpoint config can be neither described nor deleted again."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    # The helper also checks the ARN returned by the create call.
    _create_endpoint_config(sagemaker, endpoint_config_name, model_name)

    resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    resp["EndpointConfigArn"].should.match(
        _arn_pattern("endpoint-config", endpoint_config_name)
    )

    sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

    with pytest.raises(ClientError) as e:
        sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert e.value.response["Error"]["Message"].startswith(
        "Could not find endpoint configuration"
    )

    # Deleting a second time must fail the same way.
    with pytest.raises(ClientError) as e:
        sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert e.value.response["Error"]["Message"].startswith(
        "Could not find endpoint configuration"
    )


@mock_sagemaker
def test_create_endpoint_invalid_instance_type():
    """An unknown instance type is rejected with a ValidationException."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    instance_type = "InvalidInstanceType"
    production_variants = _production_variants(model_name, instance_type)
    endpoint_config_name = "MyEndpointConfig"

    with pytest.raises(ClientError) as e:
        sagemaker.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=production_variants,
        )
    assert e.value.response["Error"]["Code"] == "ValidationException"
    # Only the prefix is checked; the list of valid enum values follows "[".
    expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
        instance_type
    )
    assert expected_message in e.value.response["Error"]["Message"]


@mock_sagemaker
def test_create_endpoint():
    """Endpoint creation needs an existing config; describe and tags match."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    endpoint_name = "MyEndpoint"
    with pytest.raises(ClientError) as e:
        sagemaker.create_endpoint(
            EndpointName=endpoint_name, EndpointConfigName="NonexistentEndpointConfig"
        )
    assert e.value.response["Error"]["Message"].startswith(
        "Could not find endpoint configuration"
    )

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    _create_endpoint_config(sagemaker, endpoint_config_name, model_name)

    resp = sagemaker.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
        Tags=GENERIC_TAGS_PARAM,
    )
    resp["EndpointArn"].should.match(_arn_pattern("endpoint", endpoint_name))

    resp = sagemaker.describe_endpoint(EndpointName=endpoint_name)
    resp["EndpointArn"].should.match(_arn_pattern("endpoint", endpoint_name))
    resp["EndpointName"].should.equal(endpoint_name)
    resp["EndpointConfigName"].should.equal(endpoint_config_name)
    resp["EndpointStatus"].should.equal("InService")
    assert isinstance(resp["CreationTime"], datetime.datetime)
    assert isinstance(resp["LastModifiedTime"], datetime.datetime)
    resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant")

    # Tags passed at creation time are looked up via the described ARN.
    resp = sagemaker.list_tags(ResourceArn=resp["EndpointArn"])
    assert resp["Tags"] == GENERIC_TAGS_PARAM


@mock_sagemaker
def test_delete_endpoint():
    """A deleted endpoint can be neither described nor deleted again."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    _create_endpoint_config(sagemaker, endpoint_config_name, model_name)

    endpoint_name = "MyEndpoint"
    _create_endpoint(sagemaker, endpoint_name, endpoint_config_name)

    sagemaker.delete_endpoint(EndpointName=endpoint_name)

    with pytest.raises(ClientError) as e:
        sagemaker.describe_endpoint(EndpointName=endpoint_name)
    assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")

    # Deleting a second time must fail the same way.
    with pytest.raises(ClientError) as e:
        sagemaker.delete_endpoint(EndpointName=endpoint_name)
    assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")


def _arn_pattern(resource_type, resource_name):
    """Return an anchored regex for a SageMaker ARN ending in type/name.

    Region and account are matched loosely with ``.*``. The resource name is
    escaped so that it is always matched literally.
    """
    return r"^arn:aws:sagemaker:.*:.*:{}/{}$".format(
        resource_type, re.escape(resource_name)
    )


def _production_variants(model_name, instance_type="ml.t2.medium"):
    """Return the single-variant ProductionVariants list used by these tests."""
    return [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": instance_type,
        },
    ]


def _create_model(boto_client, model_name):
    """Create a model named ``model_name`` and assert the call returned 200."""
    resp = boto_client.create_model(
        ModelName=model_name,
        PrimaryContainer={
            "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
            "ModelDataUrl": "s3://MyBucket/model.tar.gz",
        },
        ExecutionRoleArn=FAKE_ROLE_ARN,
    )
    assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200


def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
    """Create an endpoint config for ``model_name`` and check the returned ARN."""
    resp = boto_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=_production_variants(model_name),
    )
    resp["EndpointConfigArn"].should.match(
        _arn_pattern("endpoint-config", endpoint_config_name)
    )


def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
    """Create an endpoint from an existing config and check the returned ARN."""
    resp = boto_client.create_endpoint(
        EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
    )
    resp["EndpointArn"].should.match(_arn_pattern("endpoint", endpoint_name))