diff --git a/moto/sagemaker/exceptions.py b/moto/sagemaker/exceptions.py index ba183a0f5..86dcf505a 100644 --- a/moto/sagemaker/exceptions.py +++ b/moto/sagemaker/exceptions.py @@ -1,3 +1,4 @@ +from typing import Any from moto.core.exceptions import RESTError, JsonRESTError, AWSError ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %} @@ -6,14 +7,14 @@ ERROR_WITH_MODEL_NAME = """{% extends 'single_error' %} class SagemakerClientError(RESTError): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): kwargs.setdefault("template", "single_error") self.templates["model_error"] = ERROR_WITH_MODEL_NAME super().__init__(*args, **kwargs) class ModelError(RESTError): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): kwargs.setdefault("template", "model_error") self.templates["model_error"] = ERROR_WITH_MODEL_NAME super().__init__(*args, **kwargs) @@ -22,13 +23,13 @@ class ModelError(RESTError): class MissingModel(ModelError): code = 404 - def __init__(self, *args, **kwargs): - super().__init__("NoSuchModel", "Could not find model", *args, **kwargs) + def __init__(self, model: str): + super().__init__("NoSuchModel", "Could not find model", model=model) class ValidationError(JsonRESTError): - def __init__(self, message, **kwargs): - super().__init__("ValidationException", message, **kwargs) + def __init__(self, message: str): + super().__init__("ValidationException", message) class AWSValidationException(AWSError): @@ -36,5 +37,5 @@ class AWSValidationException(AWSError): class ResourceNotFound(JsonRESTError): - def __init__(self, message, **kwargs): - super().__init__(__class__.__name__, message, **kwargs) + def __init__(self, message: str): + super().__init__(__class__.__name__, message) # type: ignore diff --git a/moto/sagemaker/models.py b/moto/sagemaker/models.py index d1ea299a4..cab49ff18 100644 --- a/moto/sagemaker/models.py +++ b/moto/sagemaker/models.py @@ -3,6 +3,7 @@ import os import random import string from datetime import datetime +from typing import Any, Dict, List, Optional, Iterable from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.sagemaker import validators @@ -54,19 +55,19 @@ PAGINATION_MODEL = { class BaseObject(BaseModel): - def camelCase(self, key): + def camelCase(self, key: str) -> str: words = [] for word in key.split("_"): words.append(word.title()) return "".join(words) - def update(self, details_json): + def update(self, details_json: str) -> None: details = json.loads(details_json) for k in details.keys(): setattr(self, k, details[k]) - def gen_response_object(self): - response_object = dict() + def gen_response_object(self) -> Dict[str, Any]: + response_object: Dict[str, Any] = dict() for key, value in self.__dict__.items(): if "_" in key: response_object[self.camelCase(key)] = value @@ -75,20 +76,20 @@ class BaseObject(BaseModel): return response_object @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] return self.gen_response_object() class FakePipelineExecution(BaseObject): def __init__( self, - pipeline_execution_arn, - pipeline_execution_display_name, - pipeline_parameters, - pipeline_execution_description, - parallelism_configuration, - pipeline_definition, - client_request_token, + pipeline_execution_arn: str, + pipeline_execution_display_name: str, + pipeline_parameters: List[Dict[str, str]], + pipeline_execution_description: str, + parallelism_configuration: Dict[str, int], + pipeline_definition: str, + client_request_token: str, ): self.pipeline_execution_arn = pipeline_execution_arn self.pipeline_execution_display_name = pipeline_execution_display_name @@ -128,15 +129,15 @@ class FakePipelineExecution(BaseObject): class FakePipeline(BaseObject): def __init__( self, - pipeline_name, - pipeline_display_name, - pipeline_definition, - pipeline_description, - role_arn, - tags, - account_id, - region_name, - parallelism_configuration, + pipeline_name: str, + pipeline_display_name: str, + pipeline_definition: str, + pipeline_description: str, + role_arn: str, + tags: List[Dict[str, str]], + account_id: str, + region_name: str, + parallelism_configuration: Dict[str, int], ): self.pipeline_name = pipeline_name self.pipeline_arn = arn_formatter( @@ -145,7 +146,7 @@ class FakePipeline(BaseObject): self.pipeline_display_name = pipeline_display_name or pipeline_name self.pipeline_definition = pipeline_definition self.pipeline_description = pipeline_description - self.pipeline_executions = dict() + self.pipeline_executions: Dict[str, FakePipelineExecution] = dict() self.role_arn = role_arn self.tags = tags or [] self.parallelism_configuration = parallelism_configuration @@ -153,7 +154,7 @@ class FakePipeline(BaseObject): now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.creation_time = now_string self.last_modified_time = now_string - self.last_execution_time = None + self.last_execution_time: Optional[str] = None self.pipeline_status = "Active" fake_user_profile_name = "fake-user-profile-name" @@ -179,21 +180,21 @@ class FakePipeline(BaseObject): class FakeProcessingJob(BaseObject): def __init__( self, - app_specification, - experiment_config, - network_config, - processing_inputs, - processing_job_name, - processing_output_config, - account_id, - region_name, - role_arn, - tags, - stopping_condition, + app_specification: Dict[str, Any], + experiment_config: Dict[str, str], + network_config: Dict[str, Any], + processing_inputs: List[Dict[str, Any]], + processing_job_name: str, + processing_output_config: Dict[str, Any], + account_id: str, + region_name: str, + role_arn: str, + tags: List[Dict[str, str]], + stopping_condition: Dict[str, int], ): self.processing_job_name = processing_job_name - self.processing_job_arn = arn_formatter( - "processing-job", processing_job_name, account_id, region_name + self.processing_job_arn = FakeProcessingJob.arn_formatter( + processing_job_name, account_id, region_name ) now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -211,40 +212,44 @@ class FakeProcessingJob(BaseObject): self.stopping_condition = stopping_condition @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property - def response_create(self): + def response_create(self) -> Dict[str, str]: return {"ProcessingJobArn": self.processing_job_arn} + @staticmethod + def arn_formatter(name: str, account_id: str, region: str) -> str: + return arn_formatter("processing-job", name, account_id, region) + class FakeTrainingJob(BaseObject): def __init__( self, - account_id, - region_name, - training_job_name, - hyper_parameters, - algorithm_specification, - role_arn, - input_data_config, - output_data_config, - resource_config, - vpc_config, - stopping_condition, - tags, - enable_network_isolation, - enable_inter_container_traffic_encryption, - enable_managed_spot_training, - checkpoint_config, - debug_hook_config, - debug_rule_configurations, - tensor_board_output_config, - experiment_config, + account_id: str, + region_name: str, + training_job_name: str, + hyper_parameters: Dict[str, str], + algorithm_specification: Dict[str, Any], + role_arn: str, + input_data_config: List[Dict[str, Any]], + output_data_config: Dict[str, str], + resource_config: Dict[str, Any], + vpc_config: Dict[str, List[str]], + stopping_condition: Dict[str, int], + tags: List[Dict[str, str]], + enable_network_isolation: bool, + enable_inter_container_traffic_encryption: bool, + enable_managed_spot_training: bool, + checkpoint_config: Dict[str, str], + debug_hook_config: Dict[str, Any], + debug_rule_configurations: List[Dict[str, Any]], + tensor_board_output_config: Dict[str, str], + experiment_config: Dict[str, str], ): self.training_job_name = training_job_name self.hyper_parameters = hyper_parameters @@ -310,31 +315,31 @@ class FakeTrainingJob(BaseObject): ] @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property - def response_create(self): + def response_create(self) -> Dict[str, str]: return {"TrainingJobArn": self.training_job_arn} @staticmethod - def arn_formatter(name, account_id, region_name): + def arn_formatter(name: str, account_id: str, region_name: str) -> str: return arn_formatter("training-job", name, account_id, region_name) class FakeEndpoint(BaseObject, CloudFormationModel): def __init__( self, - account_id, - region_name, - endpoint_name, - endpoint_config_name, - production_variants, - data_capture_config, - tags, + account_id: str, + region_name: str, + endpoint_name: str, + endpoint_config_name: str, + production_variants: List[Dict[str, Any]], + data_capture_config: Dict[str, Any], + tags: List[Dict[str, str]], ): self.endpoint_name = endpoint_name self.endpoint_arn = FakeEndpoint.arn_formatter( @@ -352,7 +357,9 @@ class FakeEndpoint(BaseObject, CloudFormationModel): "%Y-%m-%d %H:%M:%S" ) - def _process_production_variants(self, production_variants): + def _process_production_variants( + self, production_variants: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: endpoint_variants = [] for production_variant in production_variants: temp_variant = {} @@ -389,29 +396,29 @@ class FakeEndpoint(BaseObject, CloudFormationModel): return endpoint_variants @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property - def response_create(self): + def response_create(self) -> Dict[str, str]: return {"EndpointArn": self.endpoint_arn} @staticmethod - def arn_formatter(endpoint_name, account_id, region_name): + def arn_formatter(endpoint_name: str, account_id: str, region_name: str) -> str: return arn_formatter("endpoint", endpoint_name, account_id, region_name) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.endpoint_arn @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["EndpointName"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html#aws-resource-sagemaker-endpoint-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException @@ -420,18 +427,23 @@ class FakeEndpoint(BaseObject, CloudFormationModel): raise UnformattedGetAttTemplateException() @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpoint.html return "AWS::SageMaker::Endpoint" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "FakeEndpoint": sagemaker_backend = sagemaker_backends[account_id][region_name] # Get required properties from provided CloudFormation template @@ -446,14 +458,14 @@ class FakeEndpoint(BaseObject, CloudFormationModel): return endpoint @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "FakeEndpoint": # Changes to the Endpoint will not change resource name cls.delete_from_cloudformation_json( original_resource.endpoint_arn, cloudformation_json, account_id, region_name @@ -467,9 +479,13 @@ class FakeEndpoint(BaseObject, CloudFormationModel): return new_resource @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. @@ -481,13 +497,13 @@ class FakeEndpoint(BaseObject, CloudFormationModel): class FakeEndpointConfig(BaseObject, CloudFormationModel): def __init__( self, - account_id, - region_name, - endpoint_config_name, - production_variants, - data_capture_config, - tags, - kms_key_id, + account_id: str, + region_name: str, + endpoint_config_name: str, + production_variants: List[Dict[str, Any]], + data_capture_config: Dict[str, Any], + tags: List[Dict[str, Any]], + kms_key_id: str, ): self.validate_production_variants(production_variants) @@ -501,7 +517,9 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): self.kms_key_id = kms_key_id self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - def validate_production_variants(self, production_variants): + def validate_production_variants( + self, production_variants: List[Dict[str, Any]] + ) -> None: for production_variant in production_variants: if "InstanceType" in production_variant.keys(): self.validate_instance_type(production_variant["InstanceType"]) @@ -511,7 +529,7 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): message = f"Invalid Keys for ProductionVariant: received {production_variant.keys()} but expected it to contain one of {['InstanceType', 'ServerlessConfig']}" raise ValidationError(message=message) - def validate_serverless_config(self, serverless_config): + def validate_serverless_config(self, serverless_config: Dict[str, Any]) -> None: VALID_SERVERLESS_MEMORY_SIZE = [1024, 2048, 3072, 4096, 5120, 6144] if not validators.is_one_of( serverless_config["MemorySizeInMB"], VALID_SERVERLESS_MEMORY_SIZE @@ -519,7 +537,7 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): message = f"Value '{serverless_config['MemorySizeInMB']}' at 'MemorySizeInMB' failed to satisfy constraint: Member must satisfy enum value set: {VALID_SERVERLESS_MEMORY_SIZE}" raise ValidationError(message=message) - def validate_instance_type(self, instance_type): + def validate_instance_type(self, instance_type: str) -> None: VALID_INSTANCE_TYPES = [ "ml.r5d.12xlarge", "ml.r5.12xlarge", @@ -593,31 +611,33 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): raise ValidationError(message=message) @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property - def response_create(self): + def response_create(self) -> Dict[str, str]: return {"EndpointConfigArn": self.endpoint_config_arn} @staticmethod - def arn_formatter(endpoint_config_name, account_id, region_name): + def arn_formatter( + endpoint_config_name: str, account_id: str, region_name: str + ) -> str: return arn_formatter( "endpoint-config", endpoint_config_name, account_id, region_name ) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.endpoint_config_arn @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["EndpointConfigName"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html#aws-resource-sagemaker-endpointconfig-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException @@ -626,18 +646,23 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): raise UnformattedGetAttTemplateException() @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-endpointconfig.html return "AWS::SageMaker::EndpointConfig" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "FakeEndpointConfig": sagemaker_backend = sagemaker_backends[account_id][region_name] # Get required properties from provided CloudFormation template @@ -654,14 +679,14 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): return endpoint_config @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "FakeEndpointConfig": # Most changes to the endpoint config will change resource name for EndpointConfigs cls.delete_from_cloudformation_json( original_resource.endpoint_config_arn, @@ -675,9 +700,13 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): return new_resource @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. @@ -691,14 +720,14 @@ class FakeEndpointConfig(BaseObject, CloudFormationModel): class Model(BaseObject, CloudFormationModel): def __init__( self, - account_id, - region_name, - model_name, - execution_role_arn, - primary_container, - vpc_config, - containers=None, - tags=None, + account_id: str, + region_name: str, + model_name: str, + execution_role_arn: str, + primary_container: Dict[str, Any], + vpc_config: Dict[str, Any], + containers: Optional[List[Dict[str, Any]]] = None, + tags: Optional[List[Dict[str, str]]] = None, ): self.model_name = model_name self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") @@ -713,25 +742,25 @@ class Model(BaseObject, CloudFormationModel): ) @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property - def response_create(self): + def response_create(self) -> Dict[str, str]: return {"ModelArn": self.model_arn} @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.model_arn @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["ModelName"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html#aws-resource-sagemaker-model-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException @@ -740,18 +769,23 @@ class Model(BaseObject, CloudFormationModel): raise UnformattedGetAttTemplateException() @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-model.html return "AWS::SageMaker::Model" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Model": sagemaker_backend = sagemaker_backends[account_id][region_name] # Get required properties from provided CloudFormation template @@ -760,24 +794,24 @@ class Model(BaseObject, CloudFormationModel): primary_container = properties["PrimaryContainer"] model = sagemaker_backend.create_model( - ModelName=resource_name, - ExecutionRoleArn=execution_role_arn, - PrimaryContainer=primary_container, - VpcConfig=properties.get("VpcConfig", {}), - Containers=properties.get("Containers", []), - Tags=properties.get("Tags", []), + model_name=resource_name, + execution_role_arn=execution_role_arn, + primary_container=primary_container, + vpc_config=properties.get("VpcConfig", {}), + containers=properties.get("Containers", []), + tags=properties.get("Tags", []), ) return model @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "Model": # Most changes to the model will change resource name for Models cls.delete_from_cloudformation_json( original_resource.model_arn, cloudformation_json, account_id, region_name @@ -788,9 +822,13 @@ class Model(BaseObject, CloudFormationModel): return new_resource @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. @@ -800,12 +838,12 @@ class Model(BaseObject, CloudFormationModel): class VpcConfig(BaseObject): - def __init__(self, security_group_ids, subnets): + def __init__(self, security_group_ids: List[str], subnets: List[str]): self.security_group_ids = security_group_ids self.subnets = subnets @property - def response_object(self): + def response_object(self) -> Dict[str, List[str]]: response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] @@ -813,7 +851,7 @@ class VpcConfig(BaseObject): class Container(BaseObject): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self.container_hostname = kwargs.get("container_hostname", "localhost") self.model_data_url = kwargs.get("data_url", "") self.model_package_name = kwargs.get("package_name", "pkg") @@ -821,7 +859,7 @@ class Container(BaseObject): self.environment = kwargs.get("environment", {}) @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] @@ -831,22 +869,22 @@ class Container(BaseObject): class FakeSagemakerNotebookInstance(CloudFormationModel): def __init__( self, - account_id, - region_name, - notebook_instance_name, - instance_type, - role_arn, - subnet_id, - security_group_ids, - kms_key_id, - tags, - lifecycle_config_name, - direct_internet_access, - volume_size_in_gb, - accelerator_types, - default_code_repository, - additional_code_repositories, - root_access, + account_id: str, + region_name: str, + notebook_instance_name: str, + instance_type: str, + role_arn: str, + subnet_id: Optional[str], + security_group_ids: Optional[List[str]], + kms_key_id: Optional[str], + tags: Optional[List[Dict[str, str]]], + lifecycle_config_name: Optional[str], + direct_internet_access: str, + volume_size_in_gb: int, + accelerator_types: Optional[List[str]], + default_code_repository: Optional[str], + additional_code_repositories: Optional[List[str]], + root_access: Optional[str], ): self.validate_volume_size_in_gb(volume_size_in_gb) self.validate_instance_type(instance_type) @@ -866,19 +904,19 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): self.default_code_repository = default_code_repository self.additional_code_repositories = additional_code_repositories self.root_access = root_access - self.status = None + self.status: Optional[str] = None self.creation_time = self.last_modified_time = datetime.now() self.arn = arn_formatter( "notebook-instance", notebook_instance_name, account_id, region_name ) self.start() - def validate_volume_size_in_gb(self, volume_size_in_gb): + def validate_volume_size_in_gb(self, volume_size_in_gb: int) -> None: if not validators.is_integer_between(volume_size_in_gb, mn=5, optional=True): message = "Invalid range for parameter VolumeSizeInGB, value: {}, valid range: 5-inf" raise ValidationError(message=message) - def validate_instance_type(self, instance_type): + def validate_instance_type(self, instance_type: str) -> None: VALID_INSTANCE_TYPES = [ "ml.p2.xlarge", "ml.m5.4xlarge", @@ -924,30 +962,30 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): raise ValidationError(message=message) @property - def url(self): + def url(self) -> str: return ( f"{self.notebook_instance_name}.notebook.{self.region_name}.sagemaker.aws" ) - def start(self): + def start(self) -> None: self.status = "InService" @property - def is_deletable(self): + def is_deletable(self) -> bool: return self.status in ["Stopped", "Failed"] - def stop(self): + def stop(self) -> None: self.status = "Stopped" @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.arn @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["NotebookInstanceName"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstance.html#aws-resource-sagemaker-notebookinstance-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException @@ -956,18 +994,23 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): raise UnformattedGetAttTemplateException() @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstance.html return "AWS::SageMaker::NotebookInstance" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "FakeSagemakerNotebookInstance": # Get required properties from provided CloudFormation template properties = cloudformation_json["Properties"] instance_type = properties["InstanceType"] @@ -981,14 +1024,14 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): return notebook @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "FakeSagemakerNotebookInstance": # Operations keep same resource name so delete old and create new to mimic update cls.delete_from_cloudformation_json( original_resource.arn, cloudformation_json, account_id, region_name @@ -1002,9 +1045,13 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): return new_resource @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. @@ -1018,11 +1065,11 @@ class FakeSagemakerNotebookInstance(CloudFormationModel): class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationModel): def __init__( self, - account_id, - region_name, - notebook_instance_lifecycle_config_name, - on_create, - on_start, + account_id: str, + region_name: str, + notebook_instance_lifecycle_config_name: str, + on_create: List[Dict[str, str]], + on_start: List[Dict[str, str]], ): self.region_name = region_name self.notebook_instance_lifecycle_config_name = ( @@ -1040,31 +1087,27 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationMod ) @staticmethod - def arn_formatter(name, account_id, region_name): + def arn_formatter(name: str, account_id: str, region_name: str) -> str: return arn_formatter( "notebook-instance-lifecycle-configuration", name, account_id, region_name ) @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property - def response_create(self): - return {"TrainingJobArn": self.training_job_arn} - - @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.notebook_instance_lifecycle_config_arn @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["NotebookInstanceLifecycleConfigName"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html#aws-resource-sagemaker-notebookinstancelifecycleconfig-return-values from moto.cloudformation.exceptions import UnformattedGetAttTemplateException @@ -1073,18 +1116,23 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationMod raise UnformattedGetAttTemplateException() @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sagemaker-notebookinstancelifecycleconfig.html return "AWS::SageMaker::NotebookInstanceLifecycleConfig" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "FakeSageMakerNotebookInstanceLifecycleConfig": properties = cloudformation_json["Properties"] config = sagemaker_backends[account_id][ @@ -1097,14 +1145,14 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationMod return config @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "FakeSageMakerNotebookInstanceLifecycleConfig": # Operations keep same resource name so delete old and create new to mimic update cls.delete_from_cloudformation_json( original_resource.notebook_instance_lifecycle_config_arn, @@ -1121,9 +1169,13 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationMod return new_resource @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: # Get actual name because resource_name actually provides the ARN # since the Physical Resource ID is the ARN despite SageMaker # using the name for most of its operations. @@ -1134,23 +1186,27 @@ class FakeSageMakerNotebookInstanceLifecycleConfig(BaseObject, CloudFormationMod class SageMakerModelBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self._models = {} - self.notebook_instances = {} - self.endpoint_configs = {} - self.endpoints = {} - self.experiments = {} - self.pipelines = {} - self.pipeline_executions = {} - self.processing_jobs = {} - self.trials = {} - self.trial_components = {} - self.training_jobs = {} - self.notebook_instance_lifecycle_configurations = {} + self._models: Dict[str, Model] = {} + self.notebook_instances: Dict[str, FakeSagemakerNotebookInstance] = {} + self.endpoint_configs: Dict[str, FakeEndpointConfig] = {} + self.endpoints: Dict[str, FakeEndpoint] = {} + self.experiments: Dict[str, FakeExperiment] = {} + self.pipelines: Dict[str, FakePipeline] = {} + self.pipeline_executions: Dict[str, FakePipelineExecution] = {} + self.processing_jobs: Dict[str, FakeProcessingJob] = {} + self.trials: Dict[str, FakeTrial] = {} + self.trial_components: Dict[str, FakeTrialComponent] = {} + self.training_jobs: Dict[str, FakeTrainingJob] = {} + self.notebook_instance_lifecycle_configurations: Dict[ + str, FakeSageMakerNotebookInstanceLifecycleConfig + ] = {} @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint services.""" api_service = BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "api.sagemaker", special_service_name="sagemaker.api" @@ -1201,32 +1257,40 @@ class SageMakerModelBackend(BaseBackend): } return api_service + [notebook_service, studio_service] - def create_model(self, **kwargs): + def create_model( + self, + model_name: str, + execution_role_arn: str, + primary_container: Optional[Dict[str, Any]], + vpc_config: Optional[Dict[str, Any]], + containers: Optional[List[Dict[str, Any]]], + tags: Optional[List[Dict[str, str]]], + ) -> Model: model_obj = Model( account_id=self.account_id, region_name=self.region_name, - model_name=kwargs.get("ModelName"), - execution_role_arn=kwargs.get("ExecutionRoleArn"), - primary_container=kwargs.get("PrimaryContainer", {}), - vpc_config=kwargs.get("VpcConfig", {}), - containers=kwargs.get("Containers", []), - tags=kwargs.get("Tags", []), + model_name=model_name, + execution_role_arn=execution_role_arn, + primary_container=primary_container or {}, + vpc_config=vpc_config or {}, + containers=containers or [], + tags=tags or [], ) - self._models[kwargs.get("ModelName")] = model_obj + self._models[model_name] = model_obj return model_obj - def describe_model(self, model_name=None): + def describe_model(self, model_name: str) -> Model: model = self._models.get(model_name) if model: return model arn = arn_formatter("model", model_name, self.account_id, self.region_name) raise ValidationError(message=f"Could not find model '{arn}'.") - def list_models(self): + def list_models(self) -> Iterable[Model]: return self._models.values() - def delete_model(self, model_name=None): + def delete_model(self, model_name: str) -> None: for model in self._models.values(): if model.model_name == model_name: self._models.pop(model.model_name) @@ -1234,7 +1298,7 @@ class SageMakerModelBackend(BaseBackend): else: raise MissingModel(model=model_name) - def create_experiment(self, experiment_name): + def create_experiment(self, experiment_name: str) -> Dict[str, str]: experiment = FakeExperiment( account_id=self.account_id, region_name=self.region_name, @@ -1244,7 +1308,7 @@ class SageMakerModelBackend(BaseBackend): self.experiments[experiment_name] = experiment return experiment.response_create - def describe_experiment(self, experiment_name): + def describe_experiment(self, experiment_name: str) -> Dict[str, Any]: experiment_data = self.experiments[experiment_name] return { "ExperimentName": experiment_data.experiment_name, @@ -1253,7 +1317,7 @@ class SageMakerModelBackend(BaseBackend): "LastModifiedTime": experiment_data.last_modified_time, } - def _get_resource_from_arn(self, arn): + def _get_resource_from_arn(self, arn: str) -> Any: resources = { "model": self._models, "notebook-instance": self.notebook_instances, @@ -1268,30 +1332,31 @@ class SageMakerModelBackend(BaseBackend): } target_resource, target_name = arn.split(":")[-1].split("/") try: - resource = resources.get(target_resource).get(target_name) + resource = resources.get(target_resource).get(target_name) # type: ignore except KeyError: message = f"Could not find {target_resource} with name {target_name}" raise ValidationError(message=message) return resource - def add_tags(self, arn, tags): + def add_tags(self, arn: str, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: resource = self._get_resource_from_arn(arn) resource.tags.extend(tags) + return resource.tags @paginate(pagination_model=PAGINATION_MODEL) - def list_tags(self, arn): + def list_tags(self, arn: str) -> List[Dict[str, str]]: # type: ignore[misc] resource = self._get_resource_from_arn(arn) return resource.tags - def delete_tags(self, arn, tag_keys): + def delete_tags(self, arn: str, tag_keys: List[str]) -> None: resource = self._get_resource_from_arn(arn) resource.tags = [tag for tag in resource.tags if tag["Key"] not in tag_keys] @paginate(pagination_model=PAGINATION_MODEL) - def list_experiments(self): + def list_experiments(self) -> List["FakeExperiment"]: # type: ignore[misc] return list(self.experiments.values()) - def search(self, resource=None, search_expression=None): + def search(self, resource: Any = None, search_expression: Any = None) -> Any: next_index = None valid_resources = [ @@ -1315,7 +1380,7 @@ class SageMakerModelBackend(BaseBackend): f"An error occurred (ValidationException) when calling the Search operation: 1 validation error detected: Value '{resource}' at 'resource' failed to satisfy constraint: Member must satisfy enum value set: {valid_resources}" ) - def evaluate_search_expression(item): + def evaluate_search_expression(item: Any) -> bool: filters = None if search_expression is not None: filters = search_expression.get("Filters") @@ -1362,7 +1427,7 @@ class SageMakerModelBackend(BaseBackend): return True - result = { + result: Dict[str, Any] = { "Results": [], "NextToken": str(next_index) if next_index is not None else None, } @@ -1418,16 +1483,18 @@ class SageMakerModelBackend(BaseBackend): result["Results"].append({"TrialComponent": trial_component_summary}) return result - def delete_experiment(self, experiment_name): + def delete_experiment(self, experiment_name: str) -> None: try: del self.experiments[experiment_name] except KeyError: - arn = FakeTrial.arn_formatter(experiment_name, self.region_name) + arn = FakeTrial.arn_formatter( + experiment_name, self.account_id, self.region_name + ) raise ValidationError( message=f"Could not find experiment configuration '{arn}'." ) - def create_trial(self, trial_name, experiment_name): + def create_trial(self, trial_name: str, experiment_name: str) -> Dict[str, str]: trial = FakeTrial( account_id=self.account_id, region_name=self.region_name, @@ -1439,27 +1506,27 @@ class SageMakerModelBackend(BaseBackend): self.trials[trial_name] = trial return trial.response_create - def describe_trial(self, trial_name): + def describe_trial(self, trial_name: str) -> Dict[str, Any]: try: return self.trials[trial_name].response_object except KeyError: - arn = FakeTrial.arn_formatter(trial_name, self.region_name) + arn = FakeTrial.arn_formatter(trial_name, self.account_id, self.region_name) raise ValidationError(message=f"Could not find trial '{arn}'.") - def delete_trial(self, trial_name): + def delete_trial(self, trial_name: str) -> None: try: del self.trials[trial_name] except KeyError: - arn = FakeTrial.arn_formatter(trial_name, self.region_name) + arn = FakeTrial.arn_formatter(trial_name, self.account_id, self.region_name) raise ValidationError( message=f"Could not find trial configuration '{arn}'." ) @paginate(pagination_model=PAGINATION_MODEL) - def list_trials(self, experiment_name=None, trial_component_name=None): + def list_trials(self, experiment_name: Optional[str] = None, trial_component_name: Optional[str] = None) -> List["FakeTrial"]: # type: ignore[misc] trials_fetched = list(self.trials.values()) - def evaluate_filter_expression(trial_data): + def evaluate_filter_expression(trial_data: FakeTrial) -> bool: if experiment_name is not None: if trial_data.experiment_name != experiment_name: return False @@ -1476,7 +1543,9 @@ class SageMakerModelBackend(BaseBackend): if evaluate_filter_expression(trial_data) ] - def create_trial_component(self, trial_component_name, trial_name): + def create_trial_component( + self, trial_component_name: str, trial_name: str + ) -> Dict[str, Any]: trial_component = FakeTrialComponent( account_id=self.account_id, region_name=self.region_name, @@ -1487,16 +1556,18 @@ class SageMakerModelBackend(BaseBackend): self.trial_components[trial_component_name] = trial_component return trial_component.response_create - def delete_trial_component(self, trial_component_name): + def delete_trial_component(self, trial_component_name: str) -> None: try: del self.trial_components[trial_component_name] except KeyError: - arn = FakeTrial.arn_formatter(trial_component_name, self.region_name) + arn = FakeTrial.arn_formatter( + trial_component_name, self.account_id, self.region_name + ) raise ValidationError( message=f"Could not find trial-component configuration '{arn}'." ) - def describe_trial_component(self, trial_component_name): + def describe_trial_component(self, trial_component_name: str) -> Dict[str, Any]: try: return self.trial_components[trial_component_name].response_object except KeyError: @@ -1505,11 +1576,13 @@ class SageMakerModelBackend(BaseBackend): ) raise ValidationError(message=f"Could not find trial component '{arn}'.") - def _update_trial_component_details(self, trial_component_name, details_json): + def _update_trial_component_details( + self, trial_component_name: str, details_json: str + ) -> None: self.trial_components[trial_component_name].update(details_json) @paginate(pagination_model=PAGINATION_MODEL) - def list_trial_components(self, trial_name=None): + def list_trial_components(self, trial_name: Optional[str] = None) -> List["FakeTrialComponent"]: # type: ignore[misc] trial_components_fetched = list(self.trial_components.values()) return [ @@ -1518,10 +1591,9 @@ class SageMakerModelBackend(BaseBackend): if trial_name is None or trial_component_data.trial_name == trial_name ] - def associate_trial_component(self, params): - trial_name = params["TrialName"] - trial_component_name = params["TrialComponentName"] - + def associate_trial_component( + self, trial_name: str, trial_component_name: str + ) -> Dict[str, str]: if trial_name in self.trials.keys(): self.trials[trial_name].trial_components.extend([trial_component_name]) else: @@ -1539,17 +1611,16 @@ class SageMakerModelBackend(BaseBackend): "TrialArn": self.trials[trial_name].trial_arn, } - def disassociate_trial_component(self, params): - trial_component_name = params["TrialComponentName"] - trial_name = params["TrialName"] - + def disassociate_trial_component( + self, trial_name: str, trial_component_name: str + ) -> Dict[str, str]: if trial_component_name in self.trial_components.keys(): self.trial_components[trial_component_name].trial_name = None if trial_name in self.trials.keys(): self.trials[trial_name].trial_components = list( filter( - lambda x: x != trial_component_name, + lambda x: x != trial_component_name, # type: ignore self.trials[trial_name].trial_components, ) ) @@ -1561,21 +1632,21 @@ class SageMakerModelBackend(BaseBackend): def create_notebook_instance( self, - notebook_instance_name, - instance_type, - role_arn, - subnet_id=None, - security_group_ids=None, - kms_key_id=None, - tags=None, - lifecycle_config_name=None, - direct_internet_access="Enabled", - volume_size_in_gb=5, - accelerator_types=None, - default_code_repository=None, - additional_code_repositories=None, - root_access=None, - ): + notebook_instance_name: str, + instance_type: str, + role_arn: str, + subnet_id: Optional[str] = None, + security_group_ids: Optional[List[str]] = None, + kms_key_id: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + lifecycle_config_name: Optional[str] = None, + direct_internet_access: str = "Enabled", + volume_size_in_gb: int = 5, + accelerator_types: Optional[List[str]] = None, + default_code_repository: Optional[str] = None, + additional_code_repositories: Optional[List[str]] = None, + root_access: Optional[str] = None, + ) -> FakeSagemakerNotebookInstance: self._validate_unique_notebook_instance_name(notebook_instance_name) notebook_instance = FakeSagemakerNotebookInstance( @@ -1601,27 +1672,31 @@ class SageMakerModelBackend(BaseBackend): self.notebook_instances[notebook_instance_name] = notebook_instance return notebook_instance - def _validate_unique_notebook_instance_name(self, notebook_instance_name): + def _validate_unique_notebook_instance_name( + self, notebook_instance_name: str + ) -> None: if notebook_instance_name in self.notebook_instances: duplicate_arn = self.notebook_instances[notebook_instance_name].arn message = f"Cannot create a duplicate Notebook Instance ({duplicate_arn})" raise ValidationError(message=message) - def get_notebook_instance(self, notebook_instance_name): + def get_notebook_instance( + self, notebook_instance_name: str + ) -> FakeSagemakerNotebookInstance: try: return self.notebook_instances[notebook_instance_name] except KeyError: raise ValidationError(message="RecordNotFound") - def start_notebook_instance(self, notebook_instance_name): + def start_notebook_instance(self, notebook_instance_name: str) -> None: notebook_instance = self.get_notebook_instance(notebook_instance_name) notebook_instance.start() - def stop_notebook_instance(self, notebook_instance_name): + def stop_notebook_instance(self, notebook_instance_name: str) -> None: notebook_instance = self.get_notebook_instance(notebook_instance_name) notebook_instance.stop() - def delete_notebook_instance(self, notebook_instance_name): + def delete_notebook_instance(self, notebook_instance_name: str) -> None: notebook_instance = self.get_notebook_instance(notebook_instance_name) if not notebook_instance.is_deletable: message = f"Status ({notebook_instance.status}) not in ([Stopped, Failed]). Unable to transition to (Deleting) for Notebook Instance ({notebook_instance.arn})" @@ -1629,8 +1704,11 @@ class SageMakerModelBackend(BaseBackend): del self.notebook_instances[notebook_instance_name] def create_notebook_instance_lifecycle_config( - self, notebook_instance_lifecycle_config_name, on_create, on_start - ): + self, + notebook_instance_lifecycle_config_name: str, + on_create: List[Dict[str, str]], + on_start: List[Dict[str, str]], + ) -> FakeSageMakerNotebookInstanceLifecycleConfig: if ( notebook_instance_lifecycle_config_name in self.notebook_instance_lifecycle_configurations @@ -1655,8 +1733,8 @@ class SageMakerModelBackend(BaseBackend): return lifecycle_config def describe_notebook_instance_lifecycle_config( - self, notebook_instance_lifecycle_config_name - ): + self, notebook_instance_lifecycle_config_name: str + ) -> Dict[str, Any]: try: return self.notebook_instance_lifecycle_configurations[ notebook_instance_lifecycle_config_name @@ -1671,8 +1749,8 @@ class SageMakerModelBackend(BaseBackend): raise ValidationError(message=message) def delete_notebook_instance_lifecycle_config( - self, notebook_instance_lifecycle_config_name - ): + self, notebook_instance_lifecycle_config_name: str + ) -> None: try: del self.notebook_instance_lifecycle_configurations[ notebook_instance_lifecycle_config_name @@ -1688,12 +1766,12 @@ class SageMakerModelBackend(BaseBackend): def create_endpoint_config( self, - endpoint_config_name, - production_variants, - data_capture_config, - tags, - kms_key_id, - ): + endpoint_config_name: str, + production_variants: List[Dict[str, Any]], + data_capture_config: Dict[str, Any], + tags: List[Dict[str, str]], + kms_key_id: str, + ) -> FakeEndpointConfig: endpoint_config = FakeEndpointConfig( account_id=self.account_id, region_name=self.region_name, @@ -1708,7 +1786,9 @@ class SageMakerModelBackend(BaseBackend): self.endpoint_configs[endpoint_config_name] = endpoint_config return endpoint_config - def validate_production_variants(self, production_variants): + def validate_production_variants( + self, production_variants: List[Dict[str, Any]] + ) -> None: for production_variant in production_variants: if production_variant["ModelName"] not in self._models: arn = arn_formatter( @@ -1719,7 +1799,7 @@ class SageMakerModelBackend(BaseBackend): ) raise ValidationError(message=f"Could not find model '{arn}'.") - def describe_endpoint_config(self, endpoint_config_name): + def describe_endpoint_config(self, endpoint_config_name: str) -> Dict[str, Any]: try: return self.endpoint_configs[endpoint_config_name].response_object except KeyError: @@ -1730,7 +1810,7 @@ class SageMakerModelBackend(BaseBackend): message=f"Could not find endpoint configuration '{arn}'." ) - def delete_endpoint_config(self, endpoint_config_name): + def delete_endpoint_config(self, endpoint_config_name: str) -> None: try: del self.endpoint_configs[endpoint_config_name] except KeyError: @@ -1741,7 +1821,9 @@ class SageMakerModelBackend(BaseBackend): message=f"Could not find endpoint configuration '{arn}'." ) - def create_endpoint(self, endpoint_name, endpoint_config_name, tags): + def create_endpoint( + self, endpoint_name: str, endpoint_config_name: str, tags: List[Dict[str, str]] + ) -> FakeEndpoint: try: endpoint_config = self.describe_endpoint_config(endpoint_config_name) except KeyError: @@ -1763,7 +1845,7 @@ class SageMakerModelBackend(BaseBackend): self.endpoints[endpoint_name] = endpoint return endpoint - def describe_endpoint(self, endpoint_name): + def describe_endpoint(self, endpoint_name: str) -> Dict[str, Any]: try: return self.endpoints[endpoint_name].response_object except KeyError: @@ -1772,7 +1854,7 @@ class SageMakerModelBackend(BaseBackend): ) raise ValidationError(message=f"Could not find endpoint '{arn}'.") - def delete_endpoint(self, endpoint_name): + def delete_endpoint(self, endpoint_name: str) -> None: try: del self.endpoints[endpoint_name] except KeyError: @@ -1783,16 +1865,16 @@ class SageMakerModelBackend(BaseBackend): def create_processing_job( self, - app_specification, - experiment_config, - network_config, - processing_inputs, - processing_job_name, - processing_output_config, - role_arn, - tags, - stopping_condition, - ): + app_specification: Dict[str, Any], + experiment_config: Dict[str, str], + network_config: Dict[str, Any], + processing_inputs: List[Dict[str, Any]], + processing_job_name: str, + processing_output_config: Dict[str, Any], + role_arn: str, + tags: List[Dict[str, str]], + stopping_condition: Dict[str, int], + ) -> FakeProcessingJob: processing_job = FakeProcessingJob( app_specification=app_specification, experiment_config=experiment_config, @@ -1809,7 +1891,7 @@ class SageMakerModelBackend(BaseBackend): self.processing_jobs[processing_job_name] = processing_job return processing_job - def describe_processing_job(self, processing_job_name): + def describe_processing_job(self, processing_job_name: str) -> Dict[str, Any]: try: return self.processing_jobs[processing_job_name].response_object except KeyError: @@ -1820,15 +1902,15 @@ class SageMakerModelBackend(BaseBackend): def create_pipeline( self, - pipeline_name, - pipeline_display_name, - pipeline_definition, - pipeline_definition_s3_location, - pipeline_description, - role_arn, - tags, - parallelism_configuration, - ): + pipeline_name: str, + pipeline_display_name: str, + pipeline_definition: str, + pipeline_definition_s3_location: Dict[str, Any], + pipeline_description: str, + role_arn: str, + tags: List[Dict[str, str]], + parallelism_configuration: Dict[str, int], + ) -> FakePipeline: if not any([pipeline_definition, pipeline_definition_s3_location]): raise ValidationError( "An error occurred (ValidationException) when calling the CreatePipeline operation: Either " @@ -1847,7 +1929,7 @@ class SageMakerModelBackend(BaseBackend): ) if pipeline_definition_s3_location: - pipeline_definition = load_pipeline_definition_from_s3( + pipeline_definition = load_pipeline_definition_from_s3( # type: ignore pipeline_definition_s3_location, self.account_id ) @@ -1866,19 +1948,12 @@ class SageMakerModelBackend(BaseBackend): self.pipelines[pipeline_name] = pipeline return pipeline - def delete_pipeline( - self, - pipeline_name, - ): + def delete_pipeline(self, pipeline_name: str) -> str: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) del self.pipelines[pipeline.pipeline_name] return pipeline.pipeline_arn - def update_pipeline( - self, - pipeline_name, - **kwargs, - ): + def update_pipeline(self, pipeline_name: str, **kwargs: Any) -> str: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) if all( [ @@ -1896,7 +1971,7 @@ class SageMakerModelBackend(BaseBackend): if attr_key == "pipeline_definition_s3_location": self.pipelines[ pipeline_name - ].pipeline_definition = load_pipeline_definition_from_s3( + ].pipeline_definition = load_pipeline_definition_from_s3( # type: ignore attr_value, self.account_id ) continue @@ -1906,13 +1981,13 @@ class SageMakerModelBackend(BaseBackend): def start_pipeline_execution( self, - pipeline_name, - pipeline_execution_display_name, - pipeline_parameters, - pipeline_execution_description, - parallelism_configuration, - client_request_token, - ): + pipeline_name: str, + pipeline_execution_display_name: str, + pipeline_parameters: List[Dict[str, Any]], + pipeline_execution_description: str, + parallelism_configuration: Dict[str, int], + client_request_token: str, + ) -> Dict[str, str]: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) execution_id = "".join( random.choices(string.ascii_lowercase + string.digits, k=12) @@ -1942,15 +2017,11 @@ class SageMakerModelBackend(BaseBackend): pipeline_name ].last_execution_time = fake_pipeline_execution.start_time - response = {"PipelineExecutionArn": pipeline_execution_arn} - return response + return {"PipelineExecutionArn": pipeline_execution_arn} - def list_pipeline_executions( - self, - pipeline_name, - ): + def list_pipeline_executions(self, pipeline_name: str) -> Dict[str, Any]: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) - response = { + return { "PipelineExecutionSummaries": [ { "PipelineExecutionArn": pipeline_execution_arn, @@ -1965,46 +2036,40 @@ class SageMakerModelBackend(BaseBackend): for pipeline_execution_arn, pipeline_execution in pipeline.pipeline_executions.items() ] } - return response def describe_pipeline_definition_for_execution( - self, - pipeline_execution_arn, - ): + self, pipeline_execution_arn: str + ) -> Dict[str, Any]: pipeline_execution = get_pipeline_execution_from_arn( self.pipelines, pipeline_execution_arn ) - response = { + return { "PipelineDefinition": str( pipeline_execution.pipeline_definition_for_execution ), "CreationTime": pipeline_execution.creation_time, } - return response def list_pipeline_parameters_for_execution( - self, - pipeline_execution_arn, - ): + self, pipeline_execution_arn: str + ) -> Dict[str, Any]: pipeline_execution = get_pipeline_execution_from_arn( self.pipelines, pipeline_execution_arn ) - response = { + return { "PipelineParameters": pipeline_execution.pipeline_parameters, } - return response def describe_pipeline_execution( - self, - pipeline_execution_arn, - ): + self, pipeline_execution_arn: str + ) -> Dict[str, Any]: pipeline_execution = get_pipeline_execution_from_arn( self.pipelines, pipeline_execution_arn ) pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn) pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) - pipeline_execution_summaries = { + return { "PipelineArn": pipeline.pipeline_arn, "PipelineExecutionArn": pipeline_execution.pipeline_execution_arn, "PipelineExecutionDisplayName": pipeline_execution.pipeline_execution_display_name, @@ -2018,14 +2083,10 @@ class SageMakerModelBackend(BaseBackend): "LastModifiedBy": pipeline_execution.last_modified_by, "ParallelismConfiguration": pipeline_execution.parallelism_configuration, } - return pipeline_execution_summaries - def describe_pipeline( - self, - pipeline_name, - ): + def describe_pipeline(self, pipeline_name: str) -> Dict[str, Any]: pipeline = get_pipeline_from_name(self.pipelines, pipeline_name) - response = { + return { "PipelineArn": pipeline.pipeline_arn, "PipelineName": pipeline.pipeline_name, "PipelineDisplayName": pipeline.pipeline_display_name, @@ -2041,18 +2102,16 @@ class SageMakerModelBackend(BaseBackend): "ParallelismConfiguration": pipeline.parallelism_configuration, } - return response - def list_pipelines( self, - pipeline_name_prefix, - created_after, - created_before, - next_token, - max_results, - sort_by, - sort_order, - ): + pipeline_name_prefix: str, + created_after: str, + created_before: str, + next_token: str, + max_results: int, + sort_by: str, + sort_order: str, + ) -> Dict[str, Any]: if next_token: try: starting_index = int(next_token) @@ -2065,7 +2124,9 @@ class SageMakerModelBackend(BaseBackend): if max_results: end_index = max_results + starting_index - pipelines_fetched = list(self.pipelines.values())[starting_index:end_index] + pipelines_fetched: Iterable[FakePipeline] = list(self.pipelines.values())[ + starting_index:end_index + ] if end_index >= len(self.pipelines): next_index = None else: @@ -2080,7 +2141,7 @@ class SageMakerModelBackend(BaseBackend): pipelines_fetched, ) - def format_time(x): + def format_time(x: Any) -> str: return ( x if isinstance(x, str) @@ -2100,11 +2161,10 @@ class SageMakerModelBackend(BaseBackend): ) sort_key = "pipeline_name" if sort_by == "Name" else "creation_time" - sort_order = False if sort_order == "Ascending" else True pipelines_fetched = sorted( pipelines_fetched, key=lambda pipeline_fetched: getattr(pipeline_fetched, sort_key), - reverse=sort_order, + reverse=sort_order != "Ascending", ) pipeline_summaries = [ @@ -2128,15 +2188,15 @@ class SageMakerModelBackend(BaseBackend): def list_processing_jobs( self, - next_token, - max_results, - creation_time_after, - creation_time_before, - last_modified_time_after, - last_modified_time_before, - name_contains, - status_equals, - ): + next_token: str, + max_results: int, + creation_time_after: str, + creation_time_before: str, + last_modified_time_after: str, + last_modified_time_before: str, + name_contains: str, + status_equals: str, + ) -> Dict[str, Any]: if next_token: try: starting_index = int(next_token) @@ -2149,9 +2209,9 @@ class SageMakerModelBackend(BaseBackend): if max_results: end_index = max_results + starting_index - processing_jobs_fetched = list(self.processing_jobs.values())[ - starting_index:end_index - ] + processing_jobs_fetched: Iterable[FakeProcessingJob] = list( + self.processing_jobs.values() + )[starting_index:end_index] if end_index >= len(self.processing_jobs): next_index = None else: @@ -2190,7 +2250,7 @@ class SageMakerModelBackend(BaseBackend): ) if status_equals is not None: processing_jobs_fetched = filter( - lambda x: x.training_job_status == status_equals, + lambda x: x.processing_job_status == status_equals, processing_jobs_fetched, ) @@ -2213,25 +2273,25 @@ class SageMakerModelBackend(BaseBackend): def create_training_job( self, - training_job_name, - hyper_parameters, - algorithm_specification, - role_arn, - input_data_config, - output_data_config, - resource_config, - vpc_config, - stopping_condition, - tags, - enable_network_isolation, - enable_inter_container_traffic_encryption, - enable_managed_spot_training, - checkpoint_config, - debug_hook_config, - debug_rule_configurations, - tensor_board_output_config, - experiment_config, - ): + training_job_name: str, + hyper_parameters: Dict[str, str], + algorithm_specification: Dict[str, Any], + role_arn: str, + input_data_config: List[Dict[str, Any]], + output_data_config: Dict[str, str], + resource_config: Dict[str, Any], + vpc_config: Dict[str, List[str]], + stopping_condition: Dict[str, int], + tags: List[Dict[str, str]], + enable_network_isolation: bool, + enable_inter_container_traffic_encryption: bool, + enable_managed_spot_training: bool, + checkpoint_config: Dict[str, str], + debug_hook_config: Dict[str, Any], + debug_rule_configurations: List[Dict[str, Any]], + tensor_board_output_config: Dict[str, str], + experiment_config: Dict[str, str], + ) -> FakeTrainingJob: training_job = FakeTrainingJob( account_id=self.account_id, region_name=self.region_name, @@ -2257,7 +2317,7 @@ class SageMakerModelBackend(BaseBackend): self.training_jobs[training_job_name] = training_job return training_job - def describe_training_job(self, training_job_name): + def describe_training_job(self, training_job_name: str) -> Dict[str, Any]: try: return self.training_jobs[training_job_name].response_object except KeyError: @@ -2269,15 +2329,15 @@ class SageMakerModelBackend(BaseBackend): def list_training_jobs( self, - next_token, - max_results, - creation_time_after, - creation_time_before, - last_modified_time_after, - last_modified_time_before, - name_contains, - status_equals, - ): + next_token: str, + max_results: int, + creation_time_after: str, + creation_time_before: str, + last_modified_time_after: str, + last_modified_time_before: str, + name_contains: str, + status_equals: str, + ) -> Dict[str, Any]: if next_token: try: starting_index = int(next_token) @@ -2290,9 +2350,9 @@ class SageMakerModelBackend(BaseBackend): if max_results: end_index = max_results + starting_index - training_jobs_fetched = list(self.training_jobs.values())[ - starting_index:end_index - ] + training_jobs_fetched: Iterable[FakeTrainingJob] = list( + self.training_jobs.values() + )[starting_index:end_index] if end_index >= len(self.training_jobs): next_index = None else: @@ -2350,8 +2410,8 @@ class SageMakerModelBackend(BaseBackend): } def update_endpoint_weights_and_capacities( - self, endpoint_name, desired_weights_and_capacities - ): + self, endpoint_name: str, desired_weights_and_capacities: List[Dict[str, Any]] + ) -> str: # Validate inputs endpoint = self.endpoints.get(endpoint_name, None) if not endpoint: @@ -2400,7 +2460,13 @@ class SageMakerModelBackend(BaseBackend): class FakeExperiment(BaseObject): - def __init__(self, account_id, region_name, experiment_name, tags): + def __init__( + self, + account_id: str, + region_name: str, + experiment_name: str, + tags: List[Dict[str, str]], + ): self.experiment_name = experiment_name self.experiment_arn = arn_formatter( "experiment", experiment_name, account_id, region_name @@ -2411,31 +2477,29 @@ class FakeExperiment(BaseObject): ) @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property - def response_create(self): + def response_create(self) -> Dict[str, str]: return {"ExperimentArn": self.experiment_arn} class FakeTrial(BaseObject): def __init__( self, - account_id, - region_name, - trial_name, - experiment_name, - tags, - trial_components, + account_id: str, + region_name: str, + trial_name: str, + experiment_name: str, + tags: List[Dict[str, str]], + trial_components: List[str], ): self.trial_name = trial_name - self.trial_arn = arn_formatter( - "experiment-trial", trial_name, account_id, region_name - ) + self.trial_arn = FakeTrial.arn_formatter(trial_name, account_id, region_name) self.tags = tags self.trial_components = trial_components self.experiment_name = experiment_name @@ -2444,19 +2508,30 @@ class FakeTrial(BaseObject): ) @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property - def response_create(self): + def response_create(self) -> Dict[str, str]: return {"TrialArn": self.trial_arn} + @staticmethod + def arn_formatter(name: str, account_id: str, region: str) -> str: + return arn_formatter("experiment-trial", name, account_id, region) + class FakeTrialComponent(BaseObject): - def __init__(self, account_id, region_name, trial_component_name, trial_name, tags): + def __init__( + self, + account_id: str, + region_name: str, + trial_component_name: str, + trial_name: Optional[str], + tags: List[Dict[str, str]], + ): self.trial_component_name = trial_component_name self.trial_component_arn = FakeTrialComponent.arn_formatter( trial_component_name, account_id, region_name @@ -2467,18 +2542,20 @@ class FakeTrialComponent(BaseObject): self.creation_time = self.last_modified_time = now_string @property - def response_object(self): + def response_object(self) -> Dict[str, Any]: # type: ignore[misc] response_object = self.gen_response_object() return { k: v for k, v in response_object.items() if v is not None and v != [None] } @property - def response_create(self): + def response_create(self) -> Dict[str, str]: return {"TrialComponentArn": self.trial_component_arn} @staticmethod - def arn_formatter(trial_component_name, account_id, region_name): + def arn_formatter( + trial_component_name: str, account_id: str, region_name: str + ) -> str: return arn_formatter( "experiment-trial-component", trial_component_name, account_id, region_name ) diff --git a/moto/sagemaker/responses.py b/moto/sagemaker/responses.py index c679cb5dd..a294e3e3b 100644 --- a/moto/sagemaker/responses.py +++ b/moto/sagemaker/responses.py @@ -1,54 +1,59 @@ import json +from typing import Any from moto.sagemaker.exceptions import AWSValidationException +from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse from moto.utilities.aws_headers import amzn_request_id -from .models import sagemaker_backends +from .models import sagemaker_backends, SageMakerModelBackend -def format_enum_error(value, attribute, allowed): +def format_enum_error(value: str, attribute: str, allowed: Any) -> str: return f"Value '{value}' at '{attribute}' failed to satisfy constraint: Member must satisfy enum value set: {allowed}" class SageMakerResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="sagemaker") @property - def sagemaker_backend(self): + def sagemaker_backend(self) -> SageMakerModelBackend: return sagemaker_backends[self.current_account][self.region] - @property - def request_params(self): - try: - return json.loads(self.body) - except ValueError: - return {} - - def describe_model(self): + def describe_model(self) -> str: model_name = self._get_param("ModelName") model = self.sagemaker_backend.describe_model(model_name) return json.dumps(model.response_object) - def create_model(self): - model = self.sagemaker_backend.create_model(**self.request_params) + def create_model(self) -> str: + model_name = self._get_param("ModelName") + execution_role_arn = self._get_param("ExecutionRoleArn") + primary_container = self._get_param("PrimaryContainer") + vpc_config = self._get_param("VpcConfig") + containers = self._get_param("Containers") + tags = self._get_param("Tags") + model = self.sagemaker_backend.create_model( + model_name=model_name, + execution_role_arn=execution_role_arn, + primary_container=primary_container, + vpc_config=vpc_config, + containers=containers, + tags=tags, + ) return json.dumps(model.response_create) - def delete_model(self): + def delete_model(self) -> str: model_name = self._get_param("ModelName") - response = self.sagemaker_backend.delete_model(model_name) - return json.dumps(response) + self.sagemaker_backend.delete_model(model_name) + return "{}" - def list_models(self): - models = self.sagemaker_backend.list_models(**self.request_params) + def list_models(self) -> str: + models = self.sagemaker_backend.list_models() return json.dumps({"Models": [model.response_object for model in models]}) - def _get_param(self, param_name, if_none=None): - return self.request_params.get(param_name, if_none) - @amzn_request_id - def create_notebook_instance(self): + def create_notebook_instance(self) -> TYPE_RESPONSE: sagemaker_notebook = self.sagemaker_backend.create_notebook_instance( notebook_instance_name=self._get_param("NotebookInstanceName"), instance_type=self._get_param("InstanceType"), @@ -65,13 +70,10 @@ class SageMakerResponse(BaseResponse): additional_code_repositories=self._get_param("AdditionalCodeRepositories"), root_access=self._get_param("RootAccess"), ) - response = { - "NotebookInstanceArn": sagemaker_notebook.arn, - } - return 200, {}, json.dumps(response) + return 200, {}, json.dumps({"NotebookInstanceArn": sagemaker_notebook.arn}) @amzn_request_id - def describe_notebook_instance(self): + def describe_notebook_instance(self) -> TYPE_RESPONSE: notebook_instance_name = self._get_param("NotebookInstanceName") notebook_instance = self.sagemaker_backend.get_notebook_instance( notebook_instance_name @@ -100,25 +102,25 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def start_notebook_instance(self): + def start_notebook_instance(self) -> TYPE_RESPONSE: notebook_instance_name = self._get_param("NotebookInstanceName") self.sagemaker_backend.start_notebook_instance(notebook_instance_name) return 200, {}, json.dumps("{}") @amzn_request_id - def stop_notebook_instance(self): + def stop_notebook_instance(self) -> TYPE_RESPONSE: notebook_instance_name = self._get_param("NotebookInstanceName") self.sagemaker_backend.stop_notebook_instance(notebook_instance_name) return 200, {}, json.dumps("{}") @amzn_request_id - def delete_notebook_instance(self): + def delete_notebook_instance(self) -> TYPE_RESPONSE: notebook_instance_name = self._get_param("NotebookInstanceName") self.sagemaker_backend.delete_notebook_instance(notebook_instance_name) return 200, {}, json.dumps("{}") @amzn_request_id - def list_tags(self): + def list_tags(self) -> TYPE_RESPONSE: arn = self._get_param("ResourceArn") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken") @@ -131,22 +133,21 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def add_tags(self): + def add_tags(self) -> TYPE_RESPONSE: arn = self._get_param("ResourceArn") tags = self._get_param("Tags") tags = self.sagemaker_backend.add_tags(arn, tags) - response = {"Tags": tags} - return 200, {}, json.dumps(response) + return 200, {}, json.dumps({"Tags": tags}) @amzn_request_id - def delete_tags(self): + def delete_tags(self) -> TYPE_RESPONSE: arn = self._get_param("ResourceArn") tag_keys = self._get_param("TagKeys") self.sagemaker_backend.delete_tags(arn, tag_keys) return 200, {}, json.dumps({}) @amzn_request_id - def create_endpoint_config(self): + def create_endpoint_config(self) -> TYPE_RESPONSE: endpoint_config = self.sagemaker_backend.create_endpoint_config( endpoint_config_name=self._get_param("EndpointConfigName"), production_variants=self._get_param("ProductionVariants"), @@ -154,49 +155,47 @@ class SageMakerResponse(BaseResponse): tags=self._get_param("Tags"), kms_key_id=self._get_param("KmsKeyId"), ) - response = { - "EndpointConfigArn": endpoint_config.endpoint_config_arn, - } - return 200, {}, json.dumps(response) + return ( + 200, + {}, + json.dumps({"EndpointConfigArn": endpoint_config.endpoint_config_arn}), + ) @amzn_request_id - def describe_endpoint_config(self): + def describe_endpoint_config(self) -> str: endpoint_config_name = self._get_param("EndpointConfigName") response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name) return json.dumps(response) @amzn_request_id - def delete_endpoint_config(self): + def delete_endpoint_config(self) -> TYPE_RESPONSE: endpoint_config_name = self._get_param("EndpointConfigName") self.sagemaker_backend.delete_endpoint_config(endpoint_config_name) return 200, {}, json.dumps("{}") @amzn_request_id - def create_endpoint(self): + def create_endpoint(self) -> TYPE_RESPONSE: endpoint = self.sagemaker_backend.create_endpoint( endpoint_name=self._get_param("EndpointName"), endpoint_config_name=self._get_param("EndpointConfigName"), tags=self._get_param("Tags"), ) - response = { - "EndpointArn": endpoint.endpoint_arn, - } - return 200, {}, json.dumps(response) + return 200, {}, json.dumps({"EndpointArn": endpoint.endpoint_arn}) @amzn_request_id - def describe_endpoint(self): + def describe_endpoint(self) -> str: endpoint_name = self._get_param("EndpointName") response = self.sagemaker_backend.describe_endpoint(endpoint_name) return json.dumps(response) @amzn_request_id - def delete_endpoint(self): + def delete_endpoint(self) -> TYPE_RESPONSE: endpoint_name = self._get_param("EndpointName") self.sagemaker_backend.delete_endpoint(endpoint_name) return 200, {}, json.dumps("{}") @amzn_request_id - def create_processing_job(self): + def create_processing_job(self) -> TYPE_RESPONSE: processing_job = self.sagemaker_backend.create_processing_job( app_specification=self._get_param("AppSpecification"), experiment_config=self._get_param("ExperimentConfig"), @@ -208,19 +207,17 @@ class SageMakerResponse(BaseResponse): stopping_condition=self._get_param("StoppingCondition"), tags=self._get_param("Tags"), ) - response = { - "ProcessingJobArn": processing_job.processing_job_arn, - } + response = {"ProcessingJobArn": processing_job.processing_job_arn} return 200, {}, json.dumps(response) @amzn_request_id - def describe_processing_job(self): + def describe_processing_job(self) -> str: processing_job_name = self._get_param("ProcessingJobName") response = self.sagemaker_backend.describe_processing_job(processing_job_name) return json.dumps(response) @amzn_request_id - def create_training_job(self): + def create_training_job(self) -> TYPE_RESPONSE: training_job = self.sagemaker_backend.create_training_job( training_job_name=self._get_param("TrainingJobName"), hyper_parameters=self._get_param("HyperParameters"), @@ -251,13 +248,13 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def describe_training_job(self): + def describe_training_job(self) -> str: training_job_name = self._get_param("TrainingJobName") response = self.sagemaker_backend.describe_training_job(training_job_name) return json.dumps(response) @amzn_request_id - def create_notebook_instance_lifecycle_config(self): + def create_notebook_instance_lifecycle_config(self) -> TYPE_RESPONSE: lifecycle_configuration = ( self.sagemaker_backend.create_notebook_instance_lifecycle_config( notebook_instance_lifecycle_config_name=self._get_param( @@ -273,7 +270,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def describe_notebook_instance_lifecycle_config(self): + def describe_notebook_instance_lifecycle_config(self) -> str: response = self.sagemaker_backend.describe_notebook_instance_lifecycle_config( notebook_instance_lifecycle_config_name=self._get_param( "NotebookInstanceLifecycleConfigName" @@ -282,7 +279,7 @@ class SageMakerResponse(BaseResponse): return json.dumps(response) @amzn_request_id - def delete_notebook_instance_lifecycle_config(self): + def delete_notebook_instance_lifecycle_config(self) -> TYPE_RESPONSE: self.sagemaker_backend.delete_notebook_instance_lifecycle_config( notebook_instance_lifecycle_config_name=self._get_param( "NotebookInstanceLifecycleConfigName" @@ -291,7 +288,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps("{}") @amzn_request_id - def search(self): + def search(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.search( resource=self._get_param("Resource"), search_expression=self._get_param("SearchExpression"), @@ -299,7 +296,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def list_experiments(self): + def list_experiments(self) -> TYPE_RESPONSE: MaxResults = self._get_param("MaxResults") NextToken = self._get_param("NextToken") @@ -327,28 +324,28 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def delete_experiment(self): + def delete_experiment(self) -> TYPE_RESPONSE: self.sagemaker_backend.delete_experiment( experiment_name=self._get_param("ExperimentName") ) return 200, {}, json.dumps({}) @amzn_request_id - def create_experiment(self): + def create_experiment(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.create_experiment( experiment_name=self._get_param("ExperimentName") ) return 200, {}, json.dumps(response) @amzn_request_id - def describe_experiment(self): + def describe_experiment(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.describe_experiment( experiment_name=self._get_param("ExperimentName") ) return 200, {}, json.dumps(response) @amzn_request_id - def list_trials(self): + def list_trials(self) -> TYPE_RESPONSE: MaxResults = self._get_param("MaxResults") NextToken = self._get_param("NextToken") @@ -379,7 +376,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def create_trial(self): + def create_trial(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.create_trial( trial_name=self._get_param("TrialName"), experiment_name=self._get_param("ExperimentName"), @@ -387,7 +384,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def list_trial_components(self): + def list_trial_components(self) -> TYPE_RESPONSE: MaxResults = self._get_param("MaxResults") NextToken = self._get_param("NextToken") @@ -417,7 +414,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def create_trial_component(self): + def create_trial_component(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.create_trial_component( trial_component_name=self._get_param("TrialComponentName"), trial_name=self._get_param("TrialName"), @@ -425,55 +422,56 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def describe_trial(self): + def describe_trial(self) -> str: trial_name = self._get_param("TrialName") response = self.sagemaker_backend.describe_trial(trial_name) return json.dumps(response) @amzn_request_id - def delete_trial(self): + def delete_trial(self) -> TYPE_RESPONSE: trial_name = self._get_param("TrialName") self.sagemaker_backend.delete_trial(trial_name) return 200, {}, json.dumps({}) @amzn_request_id - def delete_trial_component(self): + def delete_trial_component(self) -> TYPE_RESPONSE: trial_component_name = self._get_param("TrialComponentName") self.sagemaker_backend.delete_trial_component(trial_component_name) return 200, {}, json.dumps({}) @amzn_request_id - def describe_trial_component(self): + def describe_trial_component(self) -> str: trial_component_name = self._get_param("TrialComponentName") response = self.sagemaker_backend.describe_trial_component(trial_component_name) return json.dumps(response) @amzn_request_id - def associate_trial_component(self): - response = self.sagemaker_backend.associate_trial_component(self.request_params) - return 200, {}, json.dumps(response) - - @amzn_request_id - def disassociate_trial_component(self): - response = self.sagemaker_backend.disassociate_trial_component( - self.request_params + def associate_trial_component(self) -> TYPE_RESPONSE: + trial_name = self._get_param("TrialName") + trial_component_name = self._get_param("TrialComponentName") + response = self.sagemaker_backend.associate_trial_component( + trial_name, trial_component_name ) return 200, {}, json.dumps(response) @amzn_request_id - def list_associations(self, *args, **kwargs): # pylint: disable=unused-argument - response = self.sagemaker_backend.list_associations(self.request_params) + def disassociate_trial_component(self) -> TYPE_RESPONSE: + trial_component_name = self._get_param("TrialComponentName") + trial_name = self._get_param("TrialName") + response = self.sagemaker_backend.disassociate_trial_component( + trial_name, trial_component_name + ) return 200, {}, json.dumps(response) @amzn_request_id - def describe_pipeline(self): + def describe_pipeline(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.describe_pipeline( self._get_param("PipelineName") ) return 200, {}, json.dumps(response) @amzn_request_id - def start_pipeline_execution(self): + def start_pipeline_execution(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.start_pipeline_execution( self._get_param("PipelineName"), self._get_param("PipelineExecutionDisplayName"), @@ -485,35 +483,35 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def describe_pipeline_execution(self): + def describe_pipeline_execution(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.describe_pipeline_execution( self._get_param("PipelineExecutionArn") ) return 200, {}, json.dumps(response) @amzn_request_id - def describe_pipeline_definition_for_execution(self): + def describe_pipeline_definition_for_execution(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.describe_pipeline_definition_for_execution( self._get_param("PipelineExecutionArn") ) return 200, {}, json.dumps(response) @amzn_request_id - def list_pipeline_parameters_for_execution(self): + def list_pipeline_parameters_for_execution(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.list_pipeline_parameters_for_execution( self._get_param("PipelineExecutionArn") ) return 200, {}, json.dumps(response) @amzn_request_id - def list_pipeline_executions(self): + def list_pipeline_executions(self) -> TYPE_RESPONSE: response = self.sagemaker_backend.list_pipeline_executions( self._get_param("PipelineName") ) return 200, {}, json.dumps(response) @amzn_request_id - def create_pipeline(self): + def create_pipeline(self) -> TYPE_RESPONSE: pipeline = self.sagemaker_backend.create_pipeline( pipeline_name=self._get_param("PipelineName"), pipeline_display_name=self._get_param("PipelineDisplayName"), @@ -533,7 +531,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def delete_pipeline(self): + def delete_pipeline(self) -> TYPE_RESPONSE: pipeline_arn = self.sagemaker_backend.delete_pipeline( pipeline_name=self._get_param("PipelineName"), ) @@ -541,7 +539,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def update_pipeline(self): + def update_pipeline(self) -> TYPE_RESPONSE: pipeline_arn = self.sagemaker_backend.update_pipeline( pipeline_name=self._get_param("PipelineName"), pipeline_display_name=self._get_param("PipelineDisplayName"), @@ -558,7 +556,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def list_pipelines(self): + def list_pipelines(self) -> TYPE_RESPONSE: max_results_range = range(1, 101) allowed_sort_by = ("Name", "CreationTime") allowed_sort_order = ("Ascending", "Descending") @@ -601,7 +599,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def list_processing_jobs(self): + def list_processing_jobs(self) -> TYPE_RESPONSE: max_results_range = range(1, 101) allowed_sort_by = ["Name", "CreationTime", "Status"] allowed_sort_order = ["Ascending", "Descending"] @@ -654,7 +652,7 @@ class SageMakerResponse(BaseResponse): return 200, {}, json.dumps(response) @amzn_request_id - def list_training_jobs(self): + def list_training_jobs(self) -> TYPE_RESPONSE: max_results_range = range(1, 101) allowed_sort_by = ["Name", "CreationTime", "Status"] allowed_sort_order = ["Ascending", "Descending"] @@ -706,12 +704,11 @@ class SageMakerResponse(BaseResponse): ) return 200, {}, json.dumps(response) - def update_endpoint_weights_and_capacities(self): + def update_endpoint_weights_and_capacities(self) -> TYPE_RESPONSE: endpoint_name = self._get_param("EndpointName") desired_weights_and_capacities = self._get_param("DesiredWeightsAndCapacities") endpoint_arn = self.sagemaker_backend.update_endpoint_weights_and_capacities( endpoint_name=endpoint_name, desired_weights_and_capacities=desired_weights_and_capacities, ) - response = {"EndpointArn": endpoint_arn} - return 200, {}, json.dumps(response) + return 200, {}, json.dumps({"EndpointArn": endpoint_arn}) diff --git a/moto/sagemaker/utils.py b/moto/sagemaker/utils.py index 130ce41d2..6b6b9175e 100644 --- a/moto/sagemaker/utils.py +++ b/moto/sagemaker/utils.py @@ -1,35 +1,46 @@ +import typing + from moto.s3.models import s3_backends import json +from typing import Any, Dict from .exceptions import ValidationError -def get_pipeline_from_name(pipelines, pipeline_name): +if typing.TYPE_CHECKING: + from .models import FakePipeline, FakePipelineExecution + + +def get_pipeline_from_name( + pipelines: Dict[str, "FakePipeline"], pipeline_name: str +) -> "FakePipeline": try: - pipeline = pipelines[pipeline_name] - return pipeline + return pipelines[pipeline_name] except KeyError: raise ValidationError( message=f"Could not find pipeline with PipelineName {pipeline_name}." ) -def get_pipeline_name_from_execution_arn(pipeline_execution_arn): +def get_pipeline_name_from_execution_arn(pipeline_execution_arn: str) -> str: return pipeline_execution_arn.split("/")[1].split(":")[-1] -def get_pipeline_execution_from_arn(pipelines, pipeline_execution_arn): +def get_pipeline_execution_from_arn( + pipelines: Dict[str, "FakePipeline"], pipeline_execution_arn: str +) -> "FakePipelineExecution": try: pipeline_name = get_pipeline_name_from_execution_arn(pipeline_execution_arn) pipeline = get_pipeline_from_name(pipelines, pipeline_name) - pipeline_execution = pipeline.pipeline_executions[pipeline_execution_arn] - return pipeline_execution + return pipeline.pipeline_executions[pipeline_execution_arn] except KeyError: raise ValidationError( message=f"Could not find pipeline execution with PipelineExecutionArn {pipeline_execution_arn}." ) -def load_pipeline_definition_from_s3(pipeline_definition_s3_location, account_id): +def load_pipeline_definition_from_s3( + pipeline_definition_s3_location: Dict[str, Any], account_id: str +) -> Dict[str, Any]: s3_backend = s3_backends[account_id]["global"] result = s3_backend.get_object( bucket_name=pipeline_definition_s3_location["Bucket"], @@ -38,5 +49,5 @@ def load_pipeline_definition_from_s3(pipeline_definition_s3_location, account_id return json.loads(result.value) -def arn_formatter(_type, _id, account_id, region_name): +def arn_formatter(_type: str, _id: str, account_id: str, region_name: str) -> str: return f"arn:aws:sagemaker:{region_name}:{account_id}:{_type}/{_id}" diff --git a/moto/sagemaker/validators.py b/moto/sagemaker/validators.py index 69cbee2a5..7a501a063 100644 --- a/moto/sagemaker/validators.py +++ b/moto/sagemaker/validators.py @@ -1,4 +1,9 @@ -def is_integer_between(x, mn=None, mx=None, optional=False): +from typing import Any, Optional + + +def is_integer_between( + x: int, mn: Optional[int] = None, mx: Optional[int] = None, optional: bool = False +) -> bool: if optional and x is None: return True try: @@ -14,7 +19,7 @@ def is_integer_between(x, mn=None, mx=None, optional=False): return False -def is_one_of(x, choices, optional=False): +def is_one_of(x: Any, choices: Any, optional: bool = False) -> bool: if optional and x is None: return True return x in choices diff --git a/setup.cfg b/setup.cfg index 89748a342..50374d3aa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -239,7 +239,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/scheduler +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/r*,moto/s3*,moto/sagemaker,moto/scheduler show_column_numbers=True show_error_codes = True disable_error_code=abstract