moto/tests/test_sagemaker/test_sagemaker_models.py

149 lines
4.8 KiB
Python
Raw Normal View History

import boto3
from botocore.exceptions import ClientError
import pytest
from moto import mock_sagemaker
2021-10-18 19:44:29 +00:00
import sure # noqa # pylint: disable=unused-import
from moto.sagemaker.models import VpcConfig
# Region used for every mocked SageMaker client in this module.
TEST_REGION_NAME: str = "us-east-1"
# ARN passed to create_model as ExecutionRoleArn.
# NOTE(review): this is not shaped like an IAM role ARN and names a different
# region than TEST_REGION_NAME; presumably the mock does not validate it.
TEST_ARN: str = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
# Default model name used by MySageMakerModel and most tests in this module.
TEST_MODEL_NAME: str = "MyModelName"
@pytest.fixture
def sagemaker_client():
    """Yield a boto3 SageMaker client created while moto's SageMaker mock is active.

    Pytest resolves fixtures *before* it calls the test function. As a result,
    a test-level ``@mock_sagemaker`` decorator has not started yet when this
    fixture runs. Starting the mock here guarantees the client is constructed
    inside the mock, as moto's documentation recommends.

    NOTE(review): the per-test ``@mock_sagemaker`` decorators then nest inside
    this mock; moto counts nested mocks, so this is expected to be harmless.
    Confirm against the pinned moto version.
    """
    with mock_sagemaker():
        yield boto3.client("sagemaker", region_name=TEST_REGION_NAME)
class MySageMakerModel:
    """Test-data builder that creates a SageMaker model through a client.

    Each constructor argument left as ``None`` (or otherwise falsy) falls back
    to a default:
      - ``name`` -> ``TEST_MODEL_NAME``
      - ``arn`` (sent as ``ExecutionRoleArn``) -> ``TEST_ARN``
      - ``container`` -> ``{}``, meaning no ``PrimaryContainer`` is sent
      - ``vpc_config`` -> one security group and one subnet
    """

    def __init__(self, name=None, arn=None, container=None, vpc_config=None):
        self.name = name or TEST_MODEL_NAME
        self.arn = arn or TEST_ARN
        self.container = container or {}
        # save() reads the "sg-groups" and "subnets" keys of this dict.
        self.vpc_config = vpc_config or {"sg-groups": ["sg-123"], "subnets": ["123"]}

    def save(self, sagemaker_client):
        """Create this model via ``sagemaker_client.create_model``.

        ``PrimaryContainer`` is forwarded only when a non-empty ``container``
        was supplied, so the default request is the same as before.

        Returns:
            The raw ``create_model`` response dict (includes ``ModelArn``).
        """
        vpc_config = VpcConfig(
            self.vpc_config.get("sg-groups"), self.vpc_config.get("subnets")
        )
        params = {
            "ModelName": self.name,
            "ExecutionRoleArn": self.arn,
            "VpcConfig": vpc_config.response_object,
        }
        if self.container:
            # Honour the container argument instead of silently discarding it.
            params["PrimaryContainer"] = self.container
        return sagemaker_client.create_model(**params)
@mock_sagemaker
def test_describe_model(sagemaker_client):
    """describe_model returns the model that MySageMakerModel.save created."""
    test_model = MySageMakerModel()
    test_model.save(sagemaker_client)
    model = sagemaker_client.describe_model(ModelName=TEST_MODEL_NAME)
    # Plain pytest assert: wrapping a sure expression in ``assert`` is
    # redundant (sure raises on mismatch and otherwise returns True).
    assert model["ModelName"] == TEST_MODEL_NAME
@mock_sagemaker
def test_describe_model_not_found(sagemaker_client):
    """Describing a model that was never created raises a ClientError."""
    with pytest.raises(ClientError) as err:
        sagemaker_client.describe_model(ModelName="unknown")
    assert "Could not find model" in err.value.response["Error"]["Message"]
@mock_sagemaker
def test_create_model(sagemaker_client):
    """create_model answers with an ARN ending in ``model/<model name>``."""
    expected_arn = r"^arn:aws:sagemaker:.*:.*:model/{}$".format(TEST_MODEL_NAME)
    vpc = VpcConfig(["sg-foobar"], ["subnet-xxx"]).response_object
    response = sagemaker_client.create_model(
        ModelName=TEST_MODEL_NAME,
        ExecutionRoleArn=TEST_ARN,
        VpcConfig=vpc,
    )
    response["ModelArn"].should.match(expected_arn)
@mock_sagemaker
def test_delete_model(sagemaker_client):
    """A saved model no longer appears in list_models after delete_model."""
    test_model = MySageMakerModel()
    test_model.save(sagemaker_client)
    assert len(sagemaker_client.list_models()["Models"]) == 1

    sagemaker_client.delete_model(ModelName=TEST_MODEL_NAME)
    # Comparing to [] shows any leftover models in the failure diff.
    assert sagemaker_client.list_models()["Models"] == []
@mock_sagemaker
def test_delete_model_not_found(sagemaker_client):
    """Deleting a model that does not exist raises a ClientError."""
    with pytest.raises(ClientError) as err:
        sagemaker_client.delete_model(ModelName="blah")
    # NOTE(review): "404" pins the mock's current error code; confirm whether
    # real SageMaker reports a different code for this case.
    assert err.value.response["Error"]["Code"] == "404"
@mock_sagemaker
def test_list_models(sagemaker_client):
    """list_models reports the single saved model with its name and ARN."""
    test_model = MySageMakerModel()
    test_model.save(sagemaker_client)
    models = sagemaker_client.list_models()["Models"]
    assert len(models) == 1
    assert models[0]["ModelName"] == TEST_MODEL_NAME
    # sure's matcher raises on mismatch by itself, so it is not wrapped in assert.
    models[0]["ModelArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:model/{}$".format(TEST_MODEL_NAME)
    )
@mock_sagemaker
def test_list_models_multiple(sagemaker_client):
    """list_models returns every saved model, not just a count of them."""
    name_model_1 = "blah"
    arn_model_1 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar"
    test_model_1 = MySageMakerModel(name=name_model_1, arn=arn_model_1)
    test_model_1.save(sagemaker_client)

    name_model_2 = "blah2"
    arn_model_2 = "arn:aws:sagemaker:eu-west-1:000000000000:x-x/foobar2"
    test_model_2 = MySageMakerModel(name=name_model_2, arn=arn_model_2)
    test_model_2.save(sagemaker_client)

    models = sagemaker_client.list_models()["Models"]
    assert len(models) == 2
    # Check which models came back; ordering is deliberately not asserted.
    assert {m["ModelName"] for m in models} == {name_model_1, name_model_2}
@mock_sagemaker
def test_list_models_none(sagemaker_client):
    """list_models on an empty mocked backend returns no models."""
    models = sagemaker_client.list_models()
    assert models["Models"] == []
@mock_sagemaker
def test_add_tags_to_model(sagemaker_client):
    """Tags attached with add_tags are returned unchanged by list_tags."""
    arn = MySageMakerModel().save(sagemaker_client)["ModelArn"]
    expected_tags = [{"Key": "myKey", "Value": "myValue"}]

    add_resp = sagemaker_client.add_tags(ResourceArn=arn, Tags=expected_tags)
    assert add_resp["ResponseMetadata"]["HTTPStatusCode"] == 200

    listed = sagemaker_client.list_tags(ResourceArn=arn)
    assert listed["Tags"] == expected_tags
@mock_sagemaker
def test_delete_tags_from_model(sagemaker_client):
    """Removing every tag key leaves the model with an empty tag list."""
    arn = MySageMakerModel().save(sagemaker_client)["ModelArn"]
    initial_tags = [{"Key": "myKey", "Value": "myValue"}]

    add_resp = sagemaker_client.add_tags(ResourceArn=arn, Tags=initial_tags)
    assert add_resp["ResponseMetadata"]["HTTPStatusCode"] == 200

    keys_to_remove = [t["Key"] for t in initial_tags]
    del_resp = sagemaker_client.delete_tags(ResourceArn=arn, TagKeys=keys_to_remove)
    assert del_resp["ResponseMetadata"]["HTTPStatusCode"] == 200

    assert sagemaker_client.list_tags(ResourceArn=arn)["Tags"] == []