diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index 3e0dce87b..6ff36249f 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import os from copy import deepcopy from datetime import datetime @@ -32,6 +33,288 @@ class BaseObject(BaseModel): return self.gen_response_object() +class FakeTrainingJob(BaseObject): + def __init__( + self, + region_name, + training_job_name, + hyper_parameters, + algorithm_specification, + role_arn, + input_data_config, + output_data_config, + resource_config, + vpc_config, + stopping_condition, + tags, + enable_network_isolation, + enable_inter_container_traffic_encryption, + enable_managed_spot_training, + checkpoint_config, + debug_hook_config, + debug_rule_configurations, + tensor_board_output_config, + experiment_config, + ): + self.training_job_name = training_job_name + self.hyper_parameters = hyper_parameters + self.algorithm_specification = algorithm_specification + self.role_arn = role_arn + self.input_data_config = input_data_config + self.output_data_config = output_data_config + self.resource_config = resource_config + self.vpc_config = vpc_config + self.stopping_condition = stopping_condition + self.tags = tags + self.enable_network_isolation = enable_network_isolation + self.enable_inter_container_traffic_encryption = ( + enable_inter_container_traffic_encryption + ) + self.enable_managed_spot_training = enable_managed_spot_training + self.checkpoint_config = checkpoint_config + self.debug_hook_config = debug_hook_config + self.debug_rule_configurations = debug_rule_configurations + self.tensor_board_output_config = tensor_board_output_config + self.experiment_config = experiment_config + self.training_job_arn = FakeTrainingJob.arn_formatter( + training_job_name, region_name + ) + self.creation_time = self.last_modified_time = datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) + self.model_artifacts = { + "S3ModelArtifacts": os.path.join( + self.output_data_config["S3OutputPath"], + self.training_job_name, + "output", + "model.tar.gz", + ) + } + self.training_job_status = "Completed" + self.secondary_status = "Completed" + self.algorithm_specification["MetricDefinitions"] = [ + { + "Name": "test:dcg", + "Regex": "#quality_metric: host=\\S+, test dcg =(\\S+)", + } + ] + now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self.creation_time = now_string + self.last_modified_time = now_string + self.training_start_time = now_string + self.training_end_time = now_string + self.secondary_status_transitions = [ + { + "Status": "Starting", + "StartTime": self.creation_time, + "EndTime": self.creation_time, + "StatusMessage": "Preparing the instances for training", + } + ] + self.final_metric_data_list = [ + { + "MetricName": "train:progress", + "Value": 100.0, + "Timestamp": self.creation_time, + } + ] + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } + + @property + def response_create(self): + return {"TrainingJobArn": self.training_job_arn} + + @staticmethod + def arn_formatter(endpoint_name, region_name): + return ( + "arn:aws:sagemaker:" + + region_name + + ":" + + str(ACCOUNT_ID) + + ":training-job/" + + endpoint_name + ) + + +class FakeEndpoint(BaseObject): + def __init__( + self, + region_name, + endpoint_name, + endpoint_config_name, + production_variants, + data_capture_config, + tags, + ): + self.endpoint_name = 
endpoint_name + self.endpoint_arn = FakeEndpoint.arn_formatter(endpoint_name, region_name) + self.endpoint_config_name = endpoint_config_name + self.production_variants = production_variants + self.data_capture_config = data_capture_config + self.tags = tags or [] + self.endpoint_status = "InService" + self.failure_reason = None + self.creation_time = self.last_modified_time = datetime.now().strftime( + "%Y-%m-%d %H:%M:%S" + ) + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } + + @property + def response_create(self): + return {"EndpointArn": self.endpoint_arn} + + @staticmethod + def arn_formatter(endpoint_name, region_name): + return ( + "arn:aws:sagemaker:" + + region_name + + ":" + + str(ACCOUNT_ID) + + ":endpoint/" + + endpoint_name + ) + + +class FakeEndpointConfig(BaseObject): + def __init__( + self, + region_name, + endpoint_config_name, + production_variants, + data_capture_config, + tags, + kms_key_id, + ): + self.validate_production_variants(production_variants) + + self.endpoint_config_name = endpoint_config_name + self.endpoint_config_arn = FakeEndpointConfig.arn_formatter( + endpoint_config_name, region_name + ) + self.production_variants = production_variants or [] + self.data_capture_config = data_capture_config or {} + self.tags = tags or [] + self.kms_key_id = kms_key_id + self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + def validate_production_variants(self, production_variants): + for production_variant in production_variants: + self.validate_instance_type(production_variant["InstanceType"]) + + def validate_instance_type(self, instance_type): + VALID_INSTANCE_TYPES = [ + "ml.r5d.12xlarge", + "ml.r5.12xlarge", + "ml.p2.xlarge", + "ml.m5.4xlarge", + "ml.m4.16xlarge", + "ml.r5d.24xlarge", + "ml.r5.24xlarge", + "ml.p3.16xlarge", + "ml.m5d.xlarge", + "ml.m5.large", + "ml.t2.xlarge", + "ml.p2.16xlarge", + "ml.m5d.12xlarge", + "ml.inf1.2xlarge", + "ml.m5d.24xlarge", + "ml.c4.2xlarge", + "ml.c5.2xlarge", + "ml.c4.4xlarge", + "ml.inf1.6xlarge", + "ml.c5d.2xlarge", + "ml.c5.4xlarge", + "ml.g4dn.xlarge", + "ml.g4dn.12xlarge", + "ml.c5d.4xlarge", + "ml.g4dn.2xlarge", + "ml.c4.8xlarge", + "ml.c4.large", + "ml.c5d.xlarge", + "ml.c5.large", + "ml.g4dn.4xlarge", + "ml.c5.9xlarge", + "ml.g4dn.16xlarge", + "ml.c5d.large", + "ml.c5.xlarge", + "ml.c5d.9xlarge", + "ml.c4.xlarge", + "ml.inf1.xlarge", + "ml.g4dn.8xlarge", + "ml.inf1.24xlarge", + "ml.m5d.2xlarge", + "ml.t2.2xlarge", + "ml.c5d.18xlarge", + "ml.m5d.4xlarge", + "ml.t2.medium", + "ml.c5.18xlarge", + "ml.r5d.2xlarge", + "ml.r5.2xlarge", + "ml.p3.2xlarge", + "ml.m5d.large", + "ml.m5.xlarge", + "ml.m4.10xlarge", + "ml.t2.large", + "ml.r5d.4xlarge", + "ml.r5.4xlarge", + "ml.m5.12xlarge", + "ml.m4.xlarge", + "ml.m5.24xlarge", + "ml.m4.2xlarge", + "ml.p2.8xlarge", + "ml.m5.2xlarge", + "ml.r5d.xlarge", + "ml.r5d.large", + "ml.r5.xlarge", + "ml.r5.large", + "ml.p3.8xlarge", + "ml.m4.4xlarge", + ] + if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES): + message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format( + instance_type, VALID_INSTANCE_TYPES + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + @property + def response_object(self): + response_object = self.gen_response_object() + return { + k: v for k, v in response_object.items() if v is not None and v 
!= [None] + } + + @property + def response_create(self): + return {"EndpointConfigArn": self.endpoint_config_arn} + + @staticmethod + def arn_formatter(model_name, region_name): + return ( + "arn:aws:sagemaker:" + + region_name + + ":" + + str(ACCOUNT_ID) + + ":endpoint-config/" + + model_name + ) + + class Model(BaseObject): def __init__( self, @@ -238,6 +521,9 @@ class SageMakerModelBackend(BaseBackend): def __init__(self, region_name=None): self._models = {} self.notebook_instances = {} + self.endpoint_configs = {} + self.endpoints = {} + self.training_jobs = {} self.region_name = region_name def reset(self): @@ -305,10 +591,10 @@ class SageMakerModelBackend(BaseBackend): self._validate_unique_notebook_instance_name(notebook_instance_name) notebook_instance = FakeSagemakerNotebookInstance( - self.region_name, - notebook_instance_name, - instance_type, - role_arn, + region_name=self.region_name, + notebook_instance_name=notebook_instance_name, + instance_type=instance_type, + role_arn=role_arn, subnet_id=subnet_id, security_group_ids=security_group_ids, kms_key_id=kms_key_id, @@ -392,6 +678,235 @@ class SageMakerModelBackend(BaseBackend): except RESTError: return [] + def create_endpoint_config( + self, + endpoint_config_name, + production_variants, + data_capture_config, + tags, + kms_key_id, + ): + endpoint_config = FakeEndpointConfig( + region_name=self.region_name, + endpoint_config_name=endpoint_config_name, + production_variants=production_variants, + data_capture_config=data_capture_config, + tags=tags, + kms_key_id=kms_key_id, + ) + self.validate_production_variants(production_variants) + + self.endpoint_configs[endpoint_config_name] = endpoint_config + return endpoint_config + + def validate_production_variants(self, production_variants): + for production_variant in production_variants: + if production_variant["ModelName"] not in self._models: + message = "Could not find model '{}'.".format( + Model.arn_for_model_name( + production_variant["ModelName"], self.region_name + ) + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def describe_endpoint_config(self, endpoint_config_name): + try: + return self.endpoint_configs[endpoint_config_name].response_object + except KeyError: + message = "Could not find endpoint configuration '{}'.".format( + FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name) + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def delete_endpoint_config(self, endpoint_config_name): + try: + del self.endpoint_configs[endpoint_config_name] + except KeyError: + message = "Could not find endpoint configuration '{}'.".format( + FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name) + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def create_endpoint( + self, endpoint_name, endpoint_config_name, tags, + ): + try: + endpoint_config = self.describe_endpoint_config(endpoint_config_name) + except KeyError: + message = "Could not find endpoint_config '{}'.".format( + FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name) + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + endpoint = FakeEndpoint( + region_name=self.region_name, + endpoint_name=endpoint_name, + endpoint_config_name=endpoint_config_name, + production_variants=endpoint_config["ProductionVariants"], + 
data_capture_config=endpoint_config["DataCaptureConfig"], + tags=tags, + ) + + self.endpoints[endpoint_name] = endpoint + return endpoint + + def describe_endpoint(self, endpoint_name): + try: + return self.endpoints[endpoint_name].response_object + except KeyError: + message = "Could not find endpoint configuration '{}'.".format( + FakeEndpoint.arn_formatter(endpoint_name, self.region_name) + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def delete_endpoint(self, endpoint_name): + try: + del self.endpoints[endpoint_name] + except KeyError: + message = "Could not find endpoint configuration '{}'.".format( + FakeEndpoint.arn_formatter(endpoint_name, self.region_name) + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def get_endpoint_by_arn(self, arn): + endpoints = [ + endpoint + for endpoint in self.endpoints.values() + if endpoint.endpoint_arn == arn + ] + if len(endpoints) == 0: + message = "RecordNotFound" + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + return endpoints[0] + + def get_endpoint_tags(self, arn): + try: + endpoint = self.get_endpoint_by_arn(arn) + return endpoint.tags or [] + except RESTError: + return [] + + def create_training_job( + self, + training_job_name, + hyper_parameters, + algorithm_specification, + role_arn, + input_data_config, + output_data_config, + resource_config, + vpc_config, + stopping_condition, + tags, + enable_network_isolation, + enable_inter_container_traffic_encryption, + enable_managed_spot_training, + checkpoint_config, + debug_hook_config, + debug_rule_configurations, + tensor_board_output_config, + experiment_config, + ): + training_job = FakeTrainingJob( + region_name=self.region_name, + training_job_name=training_job_name, + hyper_parameters=hyper_parameters, + algorithm_specification=algorithm_specification, + role_arn=role_arn, + input_data_config=input_data_config, + output_data_config=output_data_config, + resource_config=resource_config, + vpc_config=vpc_config, + stopping_condition=stopping_condition, + tags=tags, + enable_network_isolation=enable_network_isolation, + enable_inter_container_traffic_encryption=enable_inter_container_traffic_encryption, + enable_managed_spot_training=enable_managed_spot_training, + checkpoint_config=checkpoint_config, + debug_hook_config=debug_hook_config, + debug_rule_configurations=debug_rule_configurations, + tensor_board_output_config=tensor_board_output_config, + experiment_config=experiment_config, + ) + self.training_jobs[training_job_name] = training_job + return training_job + + def describe_training_job(self, training_job_name): + try: + return self.training_jobs[training_job_name].response_object + except KeyError: + message = "Could not find training job '{}'.".format( + FakeTrainingJob.arn_formatter(training_job_name, self.region_name) + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def delete_training_job(self, training_job_name): + try: + del self.training_jobs[training_job_name] + except KeyError: + message = "Could not find endpoint configuration '{}'.".format( + FakeTrainingJob.arn_formatter(training_job_name, self.region_name) + ) + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + + def get_training_job_by_arn(self, arn): + training_jobs = [ + training_job + for training_job in 
self.training_jobs.values() + if training_job.training_job_arn == arn + ] + if len(training_jobs) == 0: + message = "RecordNotFound" + raise RESTError( + error_type="ValidationException", + message=message, + template="error_json", + ) + return training_jobs[0] + + def get_training_job_tags(self, arn): + try: + training_job = self.get_training_job_by_arn(arn) + return training_job.tags or [] + except RESTError: + return [] + sagemaker_backends = {} for region, ec2_backend in ec2_backends.items(): diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index 58e28ef01..48a3a6432 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -122,6 +122,120 @@ class SageMakerResponse(BaseResponse): @amzn_request_id def list_tags(self): arn = self._get_param("ResourceArn") - tags = self.sagemaker_backend.get_notebook_instance_tags(arn) + try: + if ":notebook-instance/" in arn: + tags = self.sagemaker_backend.get_notebook_instance_tags(arn) + elif ":endpoint/" in arn: + tags = self.sagemaker_backend.get_endpoint_tags(arn) + elif ":training-job/" in arn: + tags = self.sagemaker_backend.get_training_job_tags(arn) + else: + tags = [] + except AWSError: + tags = [] response = {"Tags": tags} return 200, {}, json.dumps(response) + + @amzn_request_id + def create_endpoint_config(self): + try: + endpoint_config = self.sagemaker_backend.create_endpoint_config( + endpoint_config_name=self._get_param("EndpointConfigName"), + production_variants=self._get_param("ProductionVariants"), + data_capture_config=self._get_param("DataCaptureConfig"), + tags=self._get_param("Tags"), + kms_key_id=self._get_param("KmsKeyId"), + ) + response = { + "EndpointConfigArn": endpoint_config.endpoint_config_arn, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_endpoint_config(self): + endpoint_config_name = self._get_param("EndpointConfigName") + response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name) + return json.dumps(response) + + @amzn_request_id + def delete_endpoint_config(self): + endpoint_config_name = self._get_param("EndpointConfigName") + self.sagemaker_backend.delete_endpoint_config(endpoint_config_name) + return 200, {}, json.dumps("{}") + + @amzn_request_id + def create_endpoint(self): + try: + endpoint = self.sagemaker_backend.create_endpoint( + endpoint_name=self._get_param("EndpointName"), + endpoint_config_name=self._get_param("EndpointConfigName"), + tags=self._get_param("Tags"), + ) + response = { + "EndpointArn": endpoint.endpoint_arn, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_endpoint(self): + endpoint_name = self._get_param("EndpointName") + response = self.sagemaker_backend.describe_endpoint(endpoint_name) + return json.dumps(response) + + @amzn_request_id + def delete_endpoint(self): + endpoint_name = self._get_param("EndpointName") + self.sagemaker_backend.delete_endpoint(endpoint_name) + return 200, {}, json.dumps("{}") + + @amzn_request_id + def create_training_job(self): + try: + training_job = self.sagemaker_backend.create_training_job( + training_job_name=self._get_param("TrainingJobName"), + hyper_parameters=self._get_param("HyperParameters"), + algorithm_specification=self._get_param("AlgorithmSpecification"), + role_arn=self._get_param("RoleArn"), + input_data_config=self._get_param("InputDataConfig"), + output_data_config=self._get_param("OutputDataConfig"), + 
resource_config=self._get_param("ResourceConfig"), + vpc_config=self._get_param("VpcConfig"), + stopping_condition=self._get_param("StoppingCondition"), + tags=self._get_param("Tags"), + enable_network_isolation=self._get_param( + "EnableNetworkIsolation", False + ), + enable_inter_container_traffic_encryption=self._get_param( + "EnableInterContainerTrafficEncryption", False + ), + enable_managed_spot_training=self._get_param( + "EnableManagedSpotTraining", False + ), + checkpoint_config=self._get_param("CheckpointConfig"), + debug_hook_config=self._get_param("DebugHookConfig"), + debug_rule_configurations=self._get_param("DebugRuleConfigurations"), + tensor_board_output_config=self._get_param("TensorBoardOutputConfig"), + experiment_config=self._get_param("ExperimentConfig"), + ) + response = { + "TrainingJobArn": training_job.training_job_arn, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_training_job(self): + training_job_name = self._get_param("TrainingJobName") + response = self.sagemaker_backend.describe_training_job(training_job_name) + return json.dumps(response) + + @amzn_request_id + def delete_training_job(self): + training_job_name = self._get_param("TrainingJobName") + self.sagemaker_backend.delete_training_job(training_job_name) + return 200, {}, json.dumps("{}") diff --git a/moto/sagemaker/urls.py b/moto/sagemaker/urls.py index 224342ce5..9c039d899 100644 --- a/moto/sagemaker/urls.py +++ b/moto/sagemaker/urls.py @@ -3,7 +3,6 @@ from .responses import SageMakerResponse url_bases = [ "https?://api.sagemaker.(.+).amazonaws.com", - "https?://api-fips.sagemaker.(.+).amazonaws.com", ] url_paths = { diff --git a/tests/test_sagemaker/test_sagemaker_endpoint.py b/tests/test_sagemaker/test_sagemaker_endpoint.py new file mode 100644 index 000000000..b048439ff --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_endpoint.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +import datetime +import boto3 +from botocore.exceptions import ClientError, ParamValidationError +import sure # noqa + +from moto import mock_sagemaker +from moto.sts.models import ACCOUNT_ID +from nose.tools import assert_true, assert_equal, assert_raises + +TEST_REGION_NAME = "us-east-1" +FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) +GENERIC_TAGS_PARAM = [ + {"Key": "newkey1", "Value": "newval1"}, + {"Key": "newkey2", "Value": "newval2"}, +] + + +@mock_sagemaker +def test_create_endpoint_config(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + model_name = "MyModel" + production_variants = [ + { + "VariantName": "MyProductionVariant", + "ModelName": model_name, + "InitialInstanceCount": 1, + "InstanceType": "ml.t2.medium", + }, + ] + + endpoint_config_name = "MyEndpointConfig" + with assert_raises(ClientError) as e: + sagemaker.create_endpoint_config( + EndpointConfigName=endpoint_config_name, + ProductionVariants=production_variants, + ) + assert_true( + e.exception.response["Error"]["Message"].startswith("Could not find model") + ) + + _create_model(sagemaker, model_name) + resp = sagemaker.create_endpoint_config( + EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants + ) + resp["EndpointConfigArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name) + ) + + resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name) + 
resp["EndpointConfigArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name) + ) + resp["EndpointConfigName"].should.equal(endpoint_config_name) + resp["ProductionVariants"].should.equal(production_variants) + + +@mock_sagemaker +def test_delete_endpoint_config(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + model_name = "MyModel" + _create_model(sagemaker, model_name) + + endpoint_config_name = "MyEndpointConfig" + production_variants = [ + { + "VariantName": "MyProductionVariant", + "ModelName": model_name, + "InitialInstanceCount": 1, + "InstanceType": "ml.t2.medium", + }, + ] + + resp = sagemaker.create_endpoint_config( + EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants + ) + resp["EndpointConfigArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name) + ) + + resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name) + resp["EndpointConfigArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name) + ) + + resp = sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + with assert_raises(ClientError) as e: + sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name) + assert_true( + e.exception.response["Error"]["Message"].startswith( + "Could not find endpoint configuration" + ) + ) + + with assert_raises(ClientError) as e: + sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + assert_true( + e.exception.response["Error"]["Message"].startswith( + "Could not find endpoint configuration" + ) + ) + pass + + +@mock_sagemaker +def test_create_endpoint_invalid_instance_type(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + model_name = "MyModel" + _create_model(sagemaker, model_name) + + instance_type = "InvalidInstanceType" + production_variants = [ + { + "VariantName": "MyProductionVariant", + "ModelName": model_name, + "InitialInstanceCount": 1, + "InstanceType": instance_type, + }, + ] + + endpoint_config_name = "MyEndpointConfig" + with assert_raises(ClientError) as e: + sagemaker.create_endpoint_config( + EndpointConfigName=endpoint_config_name, + ProductionVariants=production_variants, + ) + assert_equal(e.exception.response["Error"]["Code"], "ValidationException") + expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format( + instance_type + ) + assert_true(expected_message in e.exception.response["Error"]["Message"]) + + +@mock_sagemaker +def test_create_endpoint(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + endpoint_name = "MyEndpoint" + with assert_raises(ClientError) as e: + sagemaker.create_endpoint( + EndpointName=endpoint_name, EndpointConfigName="NonexistentEndpointConfig" + ) + assert_true( + e.exception.response["Error"]["Message"].startswith( + "Could not find endpoint configuration" + ) + ) + + model_name = "MyModel" + _create_model(sagemaker, model_name) + + endpoint_config_name = "MyEndpointConfig" + _create_endpoint_config(sagemaker, endpoint_config_name, model_name) + + resp = sagemaker.create_endpoint( + EndpointName=endpoint_name, + EndpointConfigName=endpoint_config_name, + Tags=GENERIC_TAGS_PARAM, + ) + resp["EndpointArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name) + ) + + resp = sagemaker.describe_endpoint(EndpointName=endpoint_name) + 
resp["EndpointArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name) + ) + resp["EndpointName"].should.equal(endpoint_name) + resp["EndpointConfigName"].should.equal(endpoint_config_name) + resp["EndpointStatus"].should.equal("InService") + assert_true(isinstance(resp["CreationTime"], datetime.datetime)) + assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime)) + resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant") + + resp = sagemaker.list_tags(ResourceArn=resp["EndpointArn"]) + assert_equal(resp["Tags"], GENERIC_TAGS_PARAM) + + +@mock_sagemaker +def test_delete_endpoint(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + model_name = "MyModel" + _create_model(sagemaker, model_name) + + endpoint_config_name = "MyEndpointConfig" + _create_endpoint_config(sagemaker, endpoint_config_name, model_name) + + endpoint_name = "MyEndpoint" + _create_endpoint(sagemaker, endpoint_name, endpoint_config_name) + + sagemaker.delete_endpoint(EndpointName=endpoint_name) + with assert_raises(ClientError) as e: + sagemaker.describe_endpoint(EndpointName=endpoint_name) + assert_true( + e.exception.response["Error"]["Message"].startswith("Could not find endpoint") + ) + + with assert_raises(ClientError) as e: + sagemaker.delete_endpoint(EndpointName=endpoint_name) + assert_true( + e.exception.response["Error"]["Message"].startswith("Could not find endpoint") + ) + + +def _create_model(boto_client, model_name): + resp = boto_client.create_model( + ModelName=model_name, + PrimaryContainer={ + "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1", + "ModelDataUrl": "s3://MyBucket/model.tar.gz", + }, + ExecutionRoleArn=FAKE_ROLE_ARN, + ) + assert_equal(resp["ResponseMetadata"]["HTTPStatusCode"], 200) + + +def _create_endpoint_config(boto_client, endpoint_config_name, model_name): + production_variants = [ + { + "VariantName": "MyProductionVariant", + "ModelName": model_name, + "InitialInstanceCount": 1, + "InstanceType": "ml.t2.medium", + }, + ] + resp = boto_client.create_endpoint_config( + EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants + ) + resp["EndpointConfigArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name) + ) + + +def _create_endpoint(boto_client, endpoint_name, endpoint_config_name): + resp = boto_client.create_endpoint( + EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name + ) + resp["EndpointArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name) + ) diff --git a/tests/test_sagemaker/test_sagemaker_training.py b/tests/test_sagemaker/test_sagemaker_training.py new file mode 100644 index 000000000..feaf9f713 --- /dev/null +++ b/tests/test_sagemaker/test_sagemaker_training.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +import boto3 +import datetime +import sure # noqa + +from moto import mock_sagemaker +from moto.sts.models import ACCOUNT_ID +from nose.tools import assert_true, assert_equal, assert_raises, assert_regexp_matches + +FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID) +TEST_REGION_NAME = "us-east-1" + + +@mock_sagemaker +def test_create_training_job(): + sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME) + + training_job_name = "MyTrainingJob" + container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1" + bucket = "my-bucket" + prefix = 
"sagemaker/DEMO-breast-cancer-prediction/" + + params = { + "RoleArn": FAKE_ROLE_ARN, + "TrainingJobName": training_job_name, + "AlgorithmSpecification": { + "TrainingImage": container, + "TrainingInputMode": "File", + }, + "ResourceConfig": { + "InstanceCount": 1, + "InstanceType": "ml.c4.2xlarge", + "VolumeSizeInGB": 10, + }, + "InputDataConfig": [ + { + "ChannelName": "train", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://{}/{}/train/".format(bucket, prefix), + "S3DataDistributionType": "ShardedByS3Key", + } + }, + "CompressionType": "None", + "RecordWrapperType": "None", + }, + { + "ChannelName": "validation", + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://{}/{}/validation/".format(bucket, prefix), + "S3DataDistributionType": "FullyReplicated", + } + }, + "CompressionType": "None", + "RecordWrapperType": "None", + }, + ], + "OutputDataConfig": {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)}, + "HyperParameters": { + "feature_dim": "30", + "mini_batch_size": "100", + "predictor_type": "regressor", + "epochs": "10", + "num_models": "32", + "loss": "absolute_loss", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 60 * 60}, + } + + resp = sagemaker.create_training_job(**params) + resp["TrainingJobArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name) + ) + + resp = sagemaker.describe_training_job(TrainingJobName=training_job_name) + resp["TrainingJobName"].should.equal(training_job_name) + resp["TrainingJobArn"].should.match( + r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name) + ) + assert_true( + resp["ModelArtifacts"]["S3ModelArtifacts"].startswith( + params["OutputDataConfig"]["S3OutputPath"] + ) + ) + assert_true(training_job_name in (resp["ModelArtifacts"]["S3ModelArtifacts"])) + assert_true( + resp["ModelArtifacts"]["S3ModelArtifacts"].endswith("output/model.tar.gz") + ) + assert_equal(resp["TrainingJobStatus"], "Completed") + assert_equal(resp["SecondaryStatus"], "Completed") + assert_equal(resp["HyperParameters"], params["HyperParameters"]) + assert_equal( + resp["AlgorithmSpecification"]["TrainingImage"], + params["AlgorithmSpecification"]["TrainingImage"], + ) + assert_equal( + resp["AlgorithmSpecification"]["TrainingInputMode"], + params["AlgorithmSpecification"]["TrainingInputMode"], + ) + assert_true("MetricDefinitions" in resp["AlgorithmSpecification"]) + assert_true("Name" in resp["AlgorithmSpecification"]["MetricDefinitions"][0]) + assert_true("Regex" in resp["AlgorithmSpecification"]["MetricDefinitions"][0]) + assert_equal(resp["RoleArn"], FAKE_ROLE_ARN) + assert_equal(resp["InputDataConfig"], params["InputDataConfig"]) + assert_equal(resp["OutputDataConfig"], params["OutputDataConfig"]) + assert_equal(resp["ResourceConfig"], params["ResourceConfig"]) + assert_equal(resp["StoppingCondition"], params["StoppingCondition"]) + assert_true(isinstance(resp["CreationTime"], datetime.datetime)) + assert_true(isinstance(resp["TrainingStartTime"], datetime.datetime)) + assert_true(isinstance(resp["TrainingEndTime"], datetime.datetime)) + assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime)) + assert_true("SecondaryStatusTransitions" in resp) + assert_true("Status" in resp["SecondaryStatusTransitions"][0]) + assert_true("StartTime" in resp["SecondaryStatusTransitions"][0]) + assert_true("EndTime" in resp["SecondaryStatusTransitions"][0]) + assert_true("StatusMessage" in resp["SecondaryStatusTransitions"][0]) + 
assert_true("FinalMetricDataList" in resp) + assert_true("MetricName" in resp["FinalMetricDataList"][0]) + assert_true("Value" in resp["FinalMetricDataList"][0]) + assert_true("Timestamp" in resp["FinalMetricDataList"][0]) + + pass