From 8c88a93d7c90edec04dd14919585fd04f0015762 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sat, 22 Oct 2022 11:40:20 +0000 Subject: [PATCH] TechDebt: MyPy AWSLambda (#5586) --- moto/appsync/models.py | 19 +- moto/awslambda/exceptions.py | 20 +- moto/awslambda/models.py | 640 +++++++++++++++++------------ moto/awslambda/policy.py | 50 ++- moto/awslambda/responses.py | 161 ++++---- moto/awslambda/utils.py | 9 +- moto/core/utils.py | 4 +- moto/iam/models.py | 2 +- moto/settings.py | 16 +- moto/utilities/docker_utilities.py | 5 +- setup.cfg | 2 +- 11 files changed, 536 insertions(+), 392 deletions(-) diff --git a/moto/appsync/models.py b/moto/appsync/models.py index f2da8dd33..f56b9fea3 100644 --- a/moto/appsync/models.py +++ b/moto/appsync/models.py @@ -51,19 +51,20 @@ class GraphqlSchema(BaseModel): class GraphqlAPIKey(BaseModel): - def __init__(self, description: str, expires: Optional[datetime]): + def __init__(self, description: str, expires: Optional[int]): self.key_id = str(mock_random.uuid4())[0:6] self.description = description - self.expires = expires - if not self.expires: + if not expires: default_expiry = datetime.now(timezone.utc) default_expiry = default_expiry.replace( minute=0, second=0, microsecond=0, tzinfo=None ) default_expiry = default_expiry + timedelta(days=7) self.expires = unix_time(default_expiry) + else: + self.expires = expires - def update(self, description: Optional[str], expires: Optional[datetime]) -> None: + def update(self, description: Optional[str], expires: Optional[int]) -> None: if description: self.description = description if expires: @@ -138,9 +139,7 @@ class GraphqlAPI(BaseModel): if xray_enabled is not None: self.xray_enabled = xray_enabled - def create_api_key( - self, description: str, expires: Optional[datetime] - ) -> GraphqlAPIKey: + def create_api_key(self, description: str, expires: Optional[int]) -> GraphqlAPIKey: api_key = GraphqlAPIKey(description, expires) self.api_keys[api_key.key_id] = api_key return api_key @@ -152,7 +151,7 @@ class GraphqlAPI(BaseModel): self.api_keys.pop(api_key_id) def update_api_key( - self, api_key_id: str, description: str, expires: Optional[datetime] + self, api_key_id: str, description: str, expires: Optional[int] ) -> GraphqlAPIKey: api_key = self.api_keys[api_key_id] api_key.update(description, expires) @@ -265,7 +264,7 @@ class AppSyncBackend(BaseBackend): return self.graphql_apis.values() def create_api_key( - self, api_id: str, description: str, expires: Optional[datetime] + self, api_id: str, description: str, expires: Optional[int] ) -> GraphqlAPIKey: return self.graphql_apis[api_id].create_api_key(description, expires) @@ -286,7 +285,7 @@ class AppSyncBackend(BaseBackend): api_id: str, api_key_id: str, description: str, - expires: Optional[datetime], + expires: Optional[int], ) -> GraphqlAPIKey: return self.graphql_apis[api_id].update_api_key( api_key_id, description, expires diff --git a/moto/awslambda/exceptions.py b/moto/awslambda/exceptions.py index 1f4808cf1..cb3438d0e 100644 --- a/moto/awslambda/exceptions.py +++ b/moto/awslambda/exceptions.py @@ -2,26 +2,26 @@ from moto.core.exceptions import JsonRESTError class LambdaClientError(JsonRESTError): - def __init__(self, error, message): + def __init__(self, error: str, message: str): super().__init__(error, message) class CrossAccountNotAllowed(LambdaClientError): - def __init__(self): + def __init__(self) -> None: super().__init__( "AccessDeniedException", "Cross-account pass role is not allowed." ) class InvalidParameterValueException(LambdaClientError): - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidParameterValueException", message) class InvalidRoleFormat(LambdaClientError): pattern = r"arn:(aws[a-zA-Z-]*)?:iam::(\d{12}):role/?[a-zA-Z_0-9+=,.@\-_/]+" - def __init__(self, role): + def __init__(self, role: str): message = "1 validation error detected: Value '{0}' at 'role' failed to satisfy constraint: Member must satisfy regular expression pattern: {1}".format( role, InvalidRoleFormat.pattern ) @@ -31,28 +31,28 @@ class InvalidRoleFormat(LambdaClientError): class PreconditionFailedException(JsonRESTError): code = 412 - def __init__(self, message): + def __init__(self, message: str): super().__init__("PreconditionFailedException", message) class UnknownAliasException(LambdaClientError): code = 404 - def __init__(self, arn): + def __init__(self, arn: str): super().__init__("ResourceNotFoundException", f"Cannot find alias arn: {arn}") class UnknownFunctionException(LambdaClientError): code = 404 - def __init__(self, arn): + def __init__(self, arn: str): super().__init__("ResourceNotFoundException", f"Function not found: {arn}") class FunctionUrlConfigNotFound(LambdaClientError): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__( "ResourceNotFoundException", "The resource you requested does not exist." ) @@ -61,14 +61,14 @@ class FunctionUrlConfigNotFound(LambdaClientError): class UnknownLayerException(LambdaClientError): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("ResourceNotFoundException", "Cannot find layer") class UnknownPolicyException(LambdaClientError): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__( "ResourceNotFoundException", "No policy is associated with the given resource.", diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py index 98578fe6e..40341c870 100644 --- a/moto/awslambda/models.py +++ b/moto/awslambda/models.py @@ -4,7 +4,7 @@ from collections import defaultdict import copy import datetime from gzip import GzipFile -from typing import Mapping +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from sys import platform import docker @@ -30,7 +30,7 @@ from moto.iam.models import iam_backends from moto.iam.exceptions import IAMNotFoundException from moto.logs.models import logs_backends from moto.moto_api._internal import mock_random as random -from moto.s3.models import s3_backends +from moto.s3.models import s3_backends, FakeKey from moto.s3.exceptions import MissingBucket, MissingKey from moto import settings from .exceptions import ( @@ -61,7 +61,7 @@ logger = logging.getLogger(__name__) docker_3 = docker.__version__[0] >= "3" -def zip2tar(zip_bytes): +def zip2tar(zip_bytes: bytes) -> bytes: with TemporaryDirectory() as td: tarname = os.path.join(td, "data.tar") timeshift = int( @@ -88,26 +88,27 @@ def zip2tar(zip_bytes): class _VolumeRefCount: __slots__ = "refcount", "volume" - def __init__(self, refcount, volume): + def __init__(self, refcount: int, volume: Any): self.refcount = refcount self.volume = volume class _DockerDataVolumeContext: - _data_vol_map = defaultdict( + # {sha256: _VolumeRefCount} + _data_vol_map: Dict[str, _VolumeRefCount] = defaultdict( lambda: _VolumeRefCount(0, None) - ) # {sha256: _VolumeRefCount} + ) _lock = threading.Lock() - def __init__(self, lambda_func): + def __init__(self, lambda_func: "LambdaFunction"): self._lambda_func = lambda_func - self._vol_ref = None + self._vol_ref: Optional[_VolumeRefCount] = None @property - def name(self): - return self._vol_ref.volume.name + def name(self) -> str: + return self._vol_ref.volume.name # type: ignore[union-attr] - def __enter__(self): + def __enter__(self) -> "_DockerDataVolumeContext": # See if volume is already known with self.__class__._lock: self._vol_ref = self.__class__._data_vol_map[self._lambda_func.code_digest] @@ -125,10 +126,11 @@ class _DockerDataVolumeContext: self._vol_ref.volume = self._lambda_func.docker_client.volumes.create( self._lambda_func.code_digest ) - if docker_3: - volumes = {self.name: {"bind": "/tmp/data", "mode": "rw"}} - else: - volumes = {self.name: "/tmp/data"} + volumes = { + self.name: {"bind": "/tmp/data", "mode": "rw"} + if docker_3 + else "/tmp/data" + } self._lambda_func.docker_client.images.pull( ":".join(parse_image_ref("alpine")) @@ -144,12 +146,12 @@ class _DockerDataVolumeContext: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: with self.__class__._lock: - self._vol_ref.refcount -= 1 - if self._vol_ref.refcount == 0: + self._vol_ref.refcount -= 1 # type: ignore[union-attr] + if self._vol_ref.refcount == 0: # type: ignore[union-attr] try: - self._vol_ref.volume.remove() + self._vol_ref.volume.remove() # type: ignore[union-attr] except docker.errors.APIError as e: if e.status_code != 409: raise @@ -157,9 +159,9 @@ class _DockerDataVolumeContext: raise # multiple processes trying to use same volume? -def _zipfile_content(zipfile_content): +def _zipfile_content(zipfile_content: Union[str, bytes]) -> Tuple[bytes, int, str, str]: try: - to_unzip_code = base64.b64decode(bytes(zipfile_content, "utf-8")) + to_unzip_code = base64.b64decode(bytes(zipfile_content, "utf-8")) # type: ignore[arg-type] except Exception: to_unzip_code = base64.b64decode(zipfile_content) @@ -169,14 +171,16 @@ def _zipfile_content(zipfile_content): return to_unzip_code, len(to_unzip_code), base64ed_sha, sha_hex_digest -def _s3_content(key): +def _s3_content(key: Any) -> Tuple[bytes, int, str, str]: sha_code = hashlib.sha256(key.value) base64ed_sha = base64.b64encode(sha_code.digest()).decode("utf-8") sha_hex_digest = sha_code.hexdigest() return key.value, key.size, base64ed_sha, sha_hex_digest -def _validate_s3_bucket_and_key(account_id, data): +def _validate_s3_bucket_and_key( + account_id: str, data: Dict[str, Any] +) -> Optional[FakeKey]: key = None try: # FIXME: does not validate bucket region @@ -198,21 +202,26 @@ def _validate_s3_bucket_and_key(account_id, data): class Permission(CloudFormationModel): - def __init__(self, region): + def __init__(self, region: str): self.region = region @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "Permission" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::Lambda::Permission" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Permission": properties = cloudformation_json["Properties"] backend = lambda_backends[account_id][region_name] fn = backend.get_function(properties["FunctionName"]) @@ -221,7 +230,7 @@ class Permission(CloudFormationModel): class LayerVersion(CloudFormationModel): - def __init__(self, spec, account_id, region): + def __init__(self, spec: Dict[str, Any], account_id: str, region: str): # required self.account_id = account_id self.region = region @@ -236,9 +245,9 @@ class LayerVersion(CloudFormationModel): # auto-generated self.created_date = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") - self.version = None + self.version: Optional[int] = None self._attached = False - self._layer = None + self._layer: Optional["Layer"] = None if "ZipFile" in self.content: ( @@ -255,22 +264,24 @@ class LayerVersion(CloudFormationModel): self.code_size, self.code_sha_256, self.code_digest, - ) = _s3_content(key) + ) = _s3_content( + key + ) # type: ignore[assignment] @property - def arn(self): + def arn(self) -> str: if self.version: return make_layer_ver_arn( self.region, self.account_id, self.name, self.version ) raise ValueError("Layer version is not set") - def attach(self, layer, version): + def attach(self, layer: "Layer", version: int) -> None: self._attached = True self._layer = layer self.version = version - def get_layer_version(self): + def get_layer_version(self) -> Dict[str, Any]: return { "Content": { "Location": "s3://", @@ -278,7 +289,7 @@ class LayerVersion(CloudFormationModel): "CodeSize": self.code_size, }, "Version": self.version, - "LayerArn": self._layer.layer_arn, + "LayerArn": self._layer.layer_arn, # type: ignore[union-attr] "LayerVersionArn": self.arn, "CreatedDate": self.created_date, "CompatibleArchitectures": self.compatible_architectures, @@ -288,17 +299,22 @@ class LayerVersion(CloudFormationModel): } @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "LayerVersion" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::Lambda::LayerVersion" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "LayerVersion": properties = cloudformation_json["Properties"] optional_properties = ("Description", "CompatibleRuntimes", "LicenseInfo") @@ -319,13 +335,13 @@ class LayerVersion(CloudFormationModel): class LambdaAlias(BaseModel): def __init__( self, - account_id, - region, - name, - function_name, - function_version, - description, - routing_config, + account_id: str, + region: str, + name: str, + function_name: str, + function_version: str, + description: str, + routing_config: str, ): self.arn = ( f"arn:aws:lambda:{region}:{account_id}:function:{function_name}:{name}" @@ -336,7 +352,12 @@ class LambdaAlias(BaseModel): self.routing_config = routing_config self.revision_id = str(random.uuid4()) - def update(self, description, function_version, routing_config): + def update( + self, + description: Optional[str], + function_version: Optional[str], + routing_config: Optional[str], + ) -> None: if description is not None: self.description = description if function_version is not None: @@ -344,7 +365,7 @@ class LambdaAlias(BaseModel): if routing_config is not None: self.routing_config = routing_config - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "AliasArn": self.arn, "Description": self.description, @@ -364,17 +385,17 @@ class Layer(object): self.region, layer_version.account_id, self.name ) self._latest_version = 0 - self.layer_versions = {} + self.layer_versions: Dict[str, LayerVersion] = {} - def attach_version(self, layer_version): + def attach_version(self, layer_version: LayerVersion) -> None: self._latest_version += 1 layer_version.attach(self, self._latest_version) self.layer_versions[str(self._latest_version)] = layer_version - def delete_version(self, layer_version): + def delete_version(self, layer_version: str) -> None: self.layer_versions.pop(str(layer_version), None) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: if not self.layer_versions: return {} @@ -389,7 +410,13 @@ class Layer(object): class LambdaFunction(CloudFormationModel, DockerModel): - def __init__(self, account_id, spec, region, version=1): + def __init__( + self, + account_id: str, + spec: Dict[str, Any], + region: str, + version: Union[str, int] = 1, + ): DockerModel.__init__(self) # required self.account_id = account_id @@ -401,8 +428,8 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.run_time = spec.get("Runtime") self.logs_backend = logs_backends[account_id][self.region] self.environment_vars = spec.get("Environment", {}).get("Variables", {}) - self.policy = None - self.url_config = None + self.policy: Optional[Policy] = None + self.url_config: Optional[FunctionUrlConfig] = None self.state = "Active" self.reserved_concurrency = spec.get("ReservedConcurrentExecutions", None) @@ -450,7 +477,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.code_digest, ) = _s3_content(key) else: - self.code_bytes = "" + self.code_bytes = b"" self.code_size = 0 self.code_sha_256 = "" elif "ImageUri" in self.code: @@ -463,14 +490,11 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.region, self.account_id, self.function_name ) - if spec.get("Tags"): - self.tags = spec.get("Tags") - else: - self.tags = dict() + self.tags = spec.get("Tags") or dict() - self._aliases = dict() + self._aliases: Dict[str, LambdaAlias] = dict() - def set_version(self, version): + def set_version(self, version: int) -> None: self.function_arn = make_function_ver_arn( self.region, self.account_id, self.function_name, version ) @@ -478,20 +502,20 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") @property - def vpc_config(self): + def vpc_config(self) -> Dict[str, Any]: # type: ignore[misc] config = self._vpc_config.copy() if config["SecurityGroupIds"]: config.update({"VpcId": "vpc-123abc"}) return config @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.function_name - def __repr__(self): + def __repr__(self) -> str: return json.dumps(self.get_configuration()) - def _get_layers_data(self, layers_versions_arns): + def _get_layers_data(self, layers_versions_arns: List[str]) -> List[Dict[str, str]]: backend = lambda_backends[self.account_id][self.region] layer_versions = [ backend.layers_versions_by_arn(layer_version) @@ -506,13 +530,13 @@ class LambdaFunction(CloudFormationModel, DockerModel): ) return [{"Arn": lv.arn, "CodeSize": lv.code_size} for lv in layer_versions] - def get_code_signing_config(self): + def get_code_signing_config(self) -> Dict[str, Any]: return { "CodeSigningConfigArn": self.code_signing_config_arn, "FunctionName": self.function_name, } - def get_configuration(self, on_create=False): + def get_configuration(self, on_create: bool = False) -> Dict[str, Any]: config = { "CodeSha256": self.code_sha_256, "CodeSize": self.code_size, @@ -542,7 +566,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): return config - def get_code(self): + def get_code(self) -> Dict[str, Any]: resp = {"Configuration": self.get_configuration()} if "S3Key" in self.code: resp["Code"] = { @@ -571,7 +595,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): ) return resp - def update_configuration(self, config_updates): + def update_configuration(self, config_updates: Dict[str, Any]) -> Dict[str, Any]: for key, value in config_updates.items(): if key == "Description": self.description = value @@ -594,7 +618,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): return self.get_configuration() - def update_function_code(self, updated_spec): + def update_function_code(self, updated_spec: Dict[str, Any]) -> Dict[str, Any]: if "DryRun" in updated_spec and updated_spec["DryRun"]: return self.get_configuration() @@ -636,25 +660,27 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.code_size, self.code_sha_256, self.code_digest, - ) = _s3_content(key) + ) = _s3_content( + key + ) # type: ignore[assignment] self.code["S3Bucket"] = updated_spec["S3Bucket"] self.code["S3Key"] = updated_spec["S3Key"] return self.get_configuration() @staticmethod - def convert(s): + def convert(s: Any) -> str: # type: ignore[misc] try: return str(s, encoding="utf-8") except Exception: return s - def _invoke_lambda(self, event=None): + def _invoke_lambda(self, event: Optional[str] = None) -> Tuple[str, bool, str]: # Create the LogGroup if necessary, to write the result to self.logs_backend.ensure_log_group(self.logs_group_name, []) # TODO: context not yet implemented if event is None: - event = dict() + event = dict() # type: ignore[assignment] output = None try: @@ -686,7 +712,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): with _DockerDataVolumeContext(self) as data_vol: try: - run_kwargs = dict() + run_kwargs: Dict[str, Any] = dict() network_name = settings.moto_network_name() network_mode = settings.moto_network_mode() if network_name: @@ -736,7 +762,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): output += container.logs(stdout=True, stderr=False) container.remove() - output = output.decode("utf-8") + output = output.decode("utf-8") # type: ignore[union-attr] self.save_logs(output) @@ -754,7 +780,7 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.save_logs(msg) return msg, True, "" - def save_logs(self, output): + def save_logs(self, output: str) -> None: # Send output to "logs" backend invoke_id = random.uuid4().hex log_stream_name = ( @@ -773,7 +799,9 @@ class LambdaFunction(CloudFormationModel, DockerModel): self.logs_group_name, log_stream_name, log_events ) - def invoke(self, body, request_headers, response_headers): + def invoke( + self, body: str, request_headers: Any, response_headers: Any + ) -> Union[str, bytes]: if body: body = json.loads(body) else: @@ -781,31 +809,35 @@ class LambdaFunction(CloudFormationModel, DockerModel): # Get the invocation type: res, errored, logs = self._invoke_lambda(event=body) + if errored: + response_headers["x-amz-function-error"] = "Handled" + inv_type = request_headers.get("x-amz-invocation-type", "RequestResponse") if inv_type == "RequestResponse": encoded = base64.b64encode(logs.encode("utf-8")) response_headers["x-amz-log-result"] = encoded.decode("utf-8") - result = res.encode("utf-8") + return res.encode("utf-8") else: - result = res - if errored: - response_headers["x-amz-function-error"] = "Handled" - - return result + return res @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "FunctionName" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-lambda-function.html return "AWS::Lambda::Function" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "LambdaFunction": properties = cloudformation_json["Properties"] optional_properties = ( "Description", @@ -845,10 +877,10 @@ class LambdaFunction(CloudFormationModel, DockerModel): return fn @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Arn"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": @@ -856,21 +888,21 @@ class LambdaFunction(CloudFormationModel, DockerModel): raise UnformattedGetAttTemplateException() @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: "LambdaFunction", + new_resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + ) -> "LambdaFunction": updated_props = cloudformation_json["Properties"] original_resource.update_configuration(updated_props) original_resource.update_function_code(updated_props["Code"]) return original_resource @staticmethod - def _create_zipfile_from_plaintext_code(code): + def _create_zipfile_from_plaintext_code(code: str) -> bytes: zip_output = io.BytesIO() zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED) zip_file.writestr("index.py", code) @@ -883,25 +915,27 @@ class LambdaFunction(CloudFormationModel, DockerModel): zip_output.seek(0) return zip_output.read() - def delete(self, account_id, region): + def delete(self, account_id: str, region: str) -> None: lambda_backends[account_id][region].delete_function(self.function_name) - def delete_alias(self, name): + def delete_alias(self, name: str) -> None: self._aliases.pop(name, None) - def get_alias(self, name): + def get_alias(self, name: str) -> LambdaAlias: if name in self._aliases: return self._aliases[name] arn = f"arn:aws:lambda:{self.region}:{self.account_id}:function:{self.function_name}:{name}" raise UnknownAliasException(arn) - def has_alias(self, alias_name) -> bool: + def has_alias(self, alias_name: str) -> bool: try: return self.get_alias(alias_name) is not None except UnknownAliasException: return False - def put_alias(self, name, description, function_version, routing_config): + def put_alias( + self, name: str, description: str, function_version: str, routing_config: str + ) -> LambdaAlias: alias = LambdaAlias( account_id=self.account_id, region=self.region, @@ -914,37 +948,39 @@ class LambdaFunction(CloudFormationModel, DockerModel): self._aliases[name] = alias return alias - def update_alias(self, name, description, function_version, routing_config): + def update_alias( + self, name: str, description: str, function_version: str, routing_config: str + ) -> LambdaAlias: alias = self.get_alias(name) alias.update(description, function_version, routing_config) return alias - def create_url_config(self, config): + def create_url_config(self, config: Dict[str, Any]) -> "FunctionUrlConfig": self.url_config = FunctionUrlConfig(function=self, config=config) - return self.url_config + return self.url_config # type: ignore[return-value] - def delete_url_config(self): + def delete_url_config(self) -> None: self.url_config = None - def get_url_config(self): + def get_url_config(self) -> "FunctionUrlConfig": if not self.url_config: raise FunctionUrlConfigNotFound() return self.url_config - def update_url_config(self, config): - self.url_config.update(config) - return self.url_config + def update_url_config(self, config: Dict[str, Any]) -> "FunctionUrlConfig": + self.url_config.update(config) # type: ignore[union-attr] + return self.url_config # type: ignore[return-value] class FunctionUrlConfig: - def __init__(self, function: LambdaFunction, config): + def __init__(self, function: LambdaFunction, config: Dict[str, Any]): self.function = function self.config = config self.url = f"https://{random.uuid4().hex}.lambda-url.{function.region}.on.aws" self.created = datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S") self.last_modified = self.created - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "FunctionUrl": self.url, "FunctionArn": self.function.function_arn, @@ -954,7 +990,7 @@ class FunctionUrlConfig: "LastModifiedTime": self.last_modified, } - def update(self, new_config): + def update(self, new_config: Dict[str, Any]) -> None: if new_config.get("Cors"): self.config["Cors"] = new_config["Cors"] if new_config.get("AuthType"): @@ -963,13 +999,13 @@ class FunctionUrlConfig: class EventSourceMapping(CloudFormationModel): - def __init__(self, spec): + def __init__(self, spec: Dict[str, Any]): # required self.function_name = spec["FunctionName"] self.event_source_arn = spec["EventSourceArn"] # optional - self.batch_size = spec.get("BatchSize") + self.batch_size = spec.get("BatchSize") # type: ignore[assignment] self.starting_position = spec.get("StartingPosition", "TRIM_HORIZON") self.enabled = spec.get("Enabled", True) self.starting_position_timestamp = spec.get("StartingPositionTimestamp", None) @@ -978,20 +1014,20 @@ class EventSourceMapping(CloudFormationModel): self.uuid = str(random.uuid4()) self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple()) - def _get_service_source_from_arn(self, event_source_arn): + def _get_service_source_from_arn(self, event_source_arn: str) -> str: return event_source_arn.split(":")[2].lower() - def _validate_event_source(self, event_source_arn): + def _validate_event_source(self, event_source_arn: str) -> bool: valid_services = ("dynamodb", "kinesis", "sqs") service = self._get_service_source_from_arn(event_source_arn) - return True if service in valid_services else False + return service in valid_services @property - def event_source_arn(self): + def event_source_arn(self) -> str: return self._event_source_arn @event_source_arn.setter - def event_source_arn(self, event_source_arn): + def event_source_arn(self, event_source_arn: str) -> None: if not self._validate_event_source(event_source_arn): raise ValueError( "InvalidParameterValueException", "Unsupported event source type" @@ -999,11 +1035,11 @@ class EventSourceMapping(CloudFormationModel): self._event_source_arn = event_source_arn @property - def batch_size(self): + def batch_size(self) -> int: return self._batch_size @batch_size.setter - def batch_size(self, batch_size): + def batch_size(self, batch_size: Optional[int]) -> None: batch_size_service_map = { "kinesis": (100, 10000), "dynamodb": (100, 1000), @@ -1023,7 +1059,7 @@ class EventSourceMapping(CloudFormationModel): else: self._batch_size = int(batch_size) - def get_configuration(self): + def get_configuration(self) -> Dict[str, Any]: return { "UUID": self.uuid, "BatchSize": self.batch_size, @@ -1035,45 +1071,54 @@ class EventSourceMapping(CloudFormationModel): "StateTransitionReason": "User initiated", } - def delete(self, account_id, region_name): + def delete(self, account_id: str, region_name: str) -> None: lambda_backend = lambda_backends[account_id][region_name] lambda_backend.delete_event_source_mapping(self.uuid) @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> None: return None @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-lambda-eventsourcemapping.html return "AWS::Lambda::EventSourceMapping" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "EventSourceMapping": properties = cloudformation_json["Properties"] lambda_backend = lambda_backends[account_id][region_name] return lambda_backend.create_event_source_mapping(properties) @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + ) -> "EventSourceMapping": properties = cloudformation_json["Properties"] event_source_uuid = original_resource.uuid lambda_backend = lambda_backends[account_id][region_name] return lambda_backend.update_event_source_mapping(event_source_uuid, properties) @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + ) -> None: properties = cloudformation_json["Properties"] lambda_backend = lambda_backends[account_id][region_name] esms = lambda_backend.list_event_source_mappings( @@ -1086,30 +1131,35 @@ class EventSourceMapping(CloudFormationModel): esm.delete(account_id, region_name) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.uuid class LambdaVersion(CloudFormationModel): - def __init__(self, spec): + def __init__(self, spec: Dict[str, Any]): self.version = spec["Version"] - def __repr__(self): - return str(self.logical_resource_id) + def __repr__(self) -> str: + return str(self.logical_resource_id) # type: ignore[attr-defined] @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> None: return None @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-lambda-version.html return "AWS::Lambda::Version" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "LambdaVersion": properties = cloudformation_json["Properties"] function_name = properties["FunctionName"] func = lambda_backends[account_id][region_name].publish_function(function_name) @@ -1118,43 +1168,57 @@ class LambdaVersion(CloudFormationModel): class LambdaStorage(object): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): # Format 'func_name' {'versions': []} - self._functions = {} - self._arns = weakref.WeakValueDictionary() + self._functions: Dict[str, Any] = {} + self._arns: weakref.WeakValueDictionary[ + str, LambdaFunction + ] = weakref.WeakValueDictionary() self.region_name = region_name self.account_id = account_id - def _get_latest(self, name): + def _get_latest(self, name: str) -> LambdaFunction: return self._functions[name]["latest"] - def _get_version(self, name: str, version: str): + def _get_version(self, name: str, version: str) -> Optional[LambdaFunction]: for config in self._functions[name]["versions"]: if str(config.version) == version or config.has_alias(version): return config return None - def delete_alias(self, name, function_name): + def delete_alias(self, name: str, function_name: str) -> None: fn = self.get_function_by_name_or_arn(function_name) return fn.delete_alias(name) - def get_alias(self, name, function_name): + def get_alias(self, name: str, function_name: str) -> LambdaAlias: fn = self.get_function_by_name_or_arn(function_name) return fn.get_alias(name) def put_alias( - self, name, function_name, function_version, description, routing_config - ): + self, + name: str, + function_name: str, + function_version: str, + description: str, + routing_config: str, + ) -> LambdaAlias: fn = self.get_function_by_name_or_arn(function_name) return fn.put_alias(name, description, function_version, routing_config) def update_alias( - self, name, function_name, function_version, description, routing_config - ): + self, + name: str, + function_name: str, + function_version: str, + description: str, + routing_config: str, + ) -> LambdaAlias: fn = self.get_function_by_name_or_arn(function_name) return fn.update_alias(name, description, function_version, routing_config) - def get_function_by_name(self, name, qualifier=None): + def get_function_by_name( + self, name: str, qualifier: Optional[str] = None + ) -> Optional[LambdaFunction]: if name not in self._functions: return None @@ -1166,15 +1230,15 @@ class LambdaStorage(object): return self._get_version(name, qualifier) - def list_versions_by_function(self, name): + def list_versions_by_function(self, name: str) -> Iterable[LambdaFunction]: if name not in self._functions: - return None + return [] latest = copy.copy(self._functions[name]["latest"]) latest.function_arn += ":$LATEST" return [latest] + self._functions[name]["versions"] - def get_arn(self, arn): + def get_arn(self, arn: str) -> Optional[LambdaFunction]: # Function ARN may contain an alias # arn:aws:lambda:region:account_id:function:: if ":" in arn.split(":function:")[-1]: @@ -1183,7 +1247,7 @@ class LambdaStorage(object): return self._arns.get(arn, None) def get_function_by_name_or_arn( - self, name_or_arn, qualifier=None + self, name_or_arn: str, qualifier: Optional[str] = None ) -> LambdaFunction: fn = self.get_function_by_name(name_or_arn, qualifier) or self.get_arn( name_or_arn @@ -1198,11 +1262,7 @@ class LambdaStorage(object): raise UnknownFunctionException(arn) return fn - def put_function(self, fn): - """ - :param fn: Function - :type fn: LambdaFunction - """ + def put_function(self, fn: LambdaFunction) -> None: valid_role = re.match(InvalidRoleFormat.pattern, fn.role) if valid_role: account = valid_role.group(2) @@ -1225,7 +1285,9 @@ class LambdaStorage(object): fn.policy = Policy(fn) self._arns[fn.function_arn] = fn - def publish_function(self, name_or_arn, description=""): + def publish_function( + self, name_or_arn: str, description: str = "" + ) -> Optional[LambdaFunction]: function = self.get_function_by_name_or_arn(name_or_arn) name = function.function_name if name not in self._functions: @@ -1243,7 +1305,7 @@ class LambdaStorage(object): self._arns[fn.function_arn] = fn return fn - def del_function(self, name_or_arn, qualifier=None): + def del_function(self, name_or_arn: str, qualifier: Optional[str] = None) -> None: function = self.get_function_by_name_or_arn(name_or_arn, qualifier) name = function.function_name if not qualifier: @@ -1278,7 +1340,7 @@ class LambdaStorage(object): ): del self._functions[name] - def all(self): + def all(self) -> Iterable[LambdaFunction]: result = [] for function_group in self._functions.values(): @@ -1290,7 +1352,7 @@ class LambdaStorage(object): return result - def latest(self): + def latest(self) -> Iterable[LambdaFunction]: """ Return the list of functions with version @LATEST :return: @@ -1304,11 +1366,13 @@ class LambdaStorage(object): class LayerStorage(object): - def __init__(self): - self._layers = {} - self._arns = weakref.WeakValueDictionary() + def __init__(self) -> None: + self._layers: Dict[str, Layer] = {} + self._arns: weakref.WeakValueDictionary[ + str, LambdaFunction + ] = weakref.WeakValueDictionary() - def put_layer_version(self, layer_version): + def put_layer_version(self, layer_version: LayerVersion) -> None: """ :param layer_version: LayerVersion """ @@ -1316,15 +1380,15 @@ class LayerStorage(object): self._layers[layer_version.name] = Layer(layer_version) self._layers[layer_version.name].attach_version(layer_version) - def list_layers(self): + def list_layers(self) -> Iterable[Dict[str, Any]]: return [ layer.to_dict() for layer in self._layers.values() if layer.layer_versions ] - def delete_layer_version(self, layer_name, layer_version): + def delete_layer_version(self, layer_name: str, layer_version: str) -> None: self._layers[layer_name].delete_version(layer_version) - def get_layer_version(self, layer_name, layer_version): + def get_layer_version(self, layer_name: str, layer_version: str) -> LayerVersion: if layer_name not in self._layers: raise UnknownLayerException() for lv in self._layers[layer_name].layer_versions.values(): @@ -1332,12 +1396,14 @@ class LayerStorage(object): return lv raise UnknownLayerException() - def get_layer_versions(self, layer_name): + def get_layer_versions(self, layer_name: str) -> List[LayerVersion]: if layer_name in self._layers: return list(iter(self._layers[layer_name].layer_versions.values())) return [] - def get_layer_version_by_arn(self, layer_version_arn): + def get_layer_version_by_arn( + self, layer_version_arn: str + ) -> Optional[LayerVersion]: split_arn = split_layer_arn(layer_version_arn) if split_arn.layer_name in self._layers: return self._layers[split_arn.layer_name].layer_versions.get( @@ -1397,35 +1463,45 @@ class LambdaBackend(BaseBackend): .. note:: When using the decorators, a Docker container cannot reach Moto, as it does not run as a server. Any boto3-invocations used within your Lambda will try to connect to AWS. """ - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self._lambdas = LambdaStorage(region_name=region_name, account_id=account_id) - self._event_source_mappings = {} + self._event_source_mappings: Dict[str, EventSourceMapping] = {} self._layers = LayerStorage() @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, str]]: # type: ignore[misc] """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "lambda" ) def create_alias( - self, name, function_name, function_version, description, routing_config - ): + self, + name: str, + function_name: str, + function_version: str, + description: str, + routing_config: str, + ) -> LambdaAlias: return self._lambdas.put_alias( name, function_name, function_version, description, routing_config ) - def delete_alias(self, name, function_name): + def delete_alias(self, name: str, function_name: str) -> None: return self._lambdas.delete_alias(name, function_name) - def get_alias(self, name, function_name): + def get_alias(self, name: str, function_name: str) -> LambdaAlias: return self._lambdas.get_alias(name, function_name) def update_alias( - self, name, function_name, function_version, description, routing_config - ): + self, + name: str, + function_name: str, + function_version: str, + description: str, + routing_config: str, + ) -> LambdaAlias: """ The RevisionId parameter is not yet implemented """ @@ -1433,7 +1509,7 @@ class LambdaBackend(BaseBackend): name, function_name, function_version, description, routing_config ) - def create_function(self, spec): + def create_function(self, spec: Dict[str, Any]) -> LambdaFunction: function_name = spec.get("FunctionName", None) if function_name is None: raise RESTError("InvalidParameterValueException", "Missing FunctionName") @@ -1452,10 +1528,12 @@ class LambdaBackend(BaseBackend): fn = copy.deepcopy( fn ) # We don't want to change the actual version - just the return value - fn.version = ver.version + fn.version = ver.version # type: ignore[union-attr] return fn - def create_function_url_config(self, name_or_arn, config): + def create_function_url_config( + self, name_or_arn: str, config: Dict[str, Any] + ) -> FunctionUrlConfig: """ The Qualifier-parameter is not yet implemented. Function URLs are not yet mocked, so invoking them will fail @@ -1463,14 +1541,14 @@ class LambdaBackend(BaseBackend): function = self._lambdas.get_function_by_name_or_arn(name_or_arn) return function.create_url_config(config) - def delete_function_url_config(self, name_or_arn): + def delete_function_url_config(self, name_or_arn: str) -> None: """ The Qualifier-parameter is not yet implemented """ function = self._lambdas.get_function_by_name_or_arn(name_or_arn) function.delete_url_config() - def get_function_url_config(self, name_or_arn): + def get_function_url_config(self, name_or_arn: str) -> FunctionUrlConfig: """ The Qualifier-parameter is not yet implemented """ @@ -1479,14 +1557,16 @@ class LambdaBackend(BaseBackend): raise UnknownFunctionException(arn=name_or_arn) return function.get_url_config() - def update_function_url_config(self, name_or_arn, config): + def update_function_url_config( + self, name_or_arn: str, config: Dict[str, Any] + ) -> FunctionUrlConfig: """ The Qualifier-parameter is not yet implemented """ function = self._lambdas.get_function_by_name_or_arn(name_or_arn) return function.update_url_config(config) - def create_event_source_mapping(self, spec): + def create_event_source_mapping(self, spec: Dict[str, Any]) -> EventSourceMapping: required = ["EventSourceArn", "FunctionName"] for param in required: if not spec.get(param): @@ -1535,7 +1615,7 @@ class LambdaBackend(BaseBackend): return esm raise RESTError("ResourceNotFoundException", "Invalid EventSourceArn") - def publish_layer_version(self, spec): + def publish_layer_version(self, spec: Dict[str, Any]) -> LayerVersion: required = ["LayerName", "Content"] for param in required: if not spec.get(param): @@ -1546,42 +1626,48 @@ class LambdaBackend(BaseBackend): self._layers.put_layer_version(layer_version) return layer_version - def list_layers(self): + def list_layers(self) -> Iterable[Dict[str, Any]]: return self._layers.list_layers() - def delete_layer_version(self, layer_name, layer_version): + def delete_layer_version(self, layer_name: str, layer_version: str) -> None: return self._layers.delete_layer_version(layer_name, layer_version) - def get_layer_version(self, layer_name, layer_version): + def get_layer_version(self, layer_name: str, layer_version: str) -> LayerVersion: return self._layers.get_layer_version(layer_name, layer_version) - def get_layer_versions(self, layer_name): + def get_layer_versions(self, layer_name: str) -> Iterable[LayerVersion]: return self._layers.get_layer_versions(layer_name) - def layers_versions_by_arn(self, layer_version_arn): + def layers_versions_by_arn(self, layer_version_arn: str) -> Optional[LayerVersion]: return self._layers.get_layer_version_by_arn(layer_version_arn) - def publish_function(self, function_name, description=""): + def publish_function( + self, function_name: str, description: str = "" + ) -> Optional[LambdaFunction]: return self._lambdas.publish_function(function_name, description) - def get_function(self, function_name_or_arn, qualifier=None): + def get_function( + self, function_name_or_arn: str, qualifier: Optional[str] = None + ) -> LambdaFunction: return self._lambdas.get_function_by_name_or_arn( function_name_or_arn, qualifier ) - def list_versions_by_function(self, function_name): + def list_versions_by_function(self, function_name: str) -> Iterable[LambdaFunction]: return self._lambdas.list_versions_by_function(function_name) - def get_event_source_mapping(self, uuid): + def get_event_source_mapping(self, uuid: str) -> Optional[EventSourceMapping]: return self._event_source_mappings.get(uuid) - def delete_event_source_mapping(self, uuid): - return self._event_source_mappings.pop(uuid) + def delete_event_source_mapping(self, uuid: str) -> Optional[EventSourceMapping]: + return self._event_source_mappings.pop(uuid, None) - def update_event_source_mapping(self, uuid, spec): + def update_event_source_mapping( + self, uuid: str, spec: Dict[str, Any] + ) -> Optional[EventSourceMapping]: esm = self.get_event_source_mapping(uuid) if not esm: - return False + return None for key in spec.keys(): if key == "FunctionName": @@ -1595,7 +1681,9 @@ class LambdaBackend(BaseBackend): esm.last_modified = time.mktime(datetime.datetime.utcnow().timetuple()) return esm - def list_event_source_mappings(self, event_source_arn, function_name): + def list_event_source_mappings( + self, event_source_arn: str, function_name: str + ) -> Iterable[EventSourceMapping]: esms = list(self._event_source_mappings.values()) if event_source_arn: esms = list(filter(lambda x: x.event_source_arn == event_source_arn, esms)) @@ -1603,27 +1691,33 @@ class LambdaBackend(BaseBackend): esms = list(filter(lambda x: x.function_name == function_name, esms)) return esms - def get_function_by_arn(self, function_arn): + def get_function_by_arn(self, function_arn: str) -> Optional[LambdaFunction]: return self._lambdas.get_arn(function_arn) - def delete_function(self, function_name, qualifier=None): + def delete_function( + self, function_name: str, qualifier: Optional[str] = None + ) -> None: self._lambdas.del_function(function_name, qualifier) - def list_functions(self, func_version=None): + def list_functions( + self, func_version: Optional[str] = None + ) -> Iterable[LambdaFunction]: if func_version == "ALL": return self._lambdas.all() return self._lambdas.latest() - def send_sqs_batch(self, function_arn, messages, queue_arn): + def send_sqs_batch(self, function_arn: str, messages: Any, queue_arn: str) -> bool: success = True for message in messages: func = self.get_function_by_arn(function_arn) - result = self._send_sqs_message(func, message, queue_arn) + result = self._send_sqs_message(func, message, queue_arn) # type: ignore[arg-type] if not result: success = False return success - def _send_sqs_message(self, func, message, queue_arn): + def _send_sqs_message( + self, func: LambdaFunction, message: Any, queue_arn: str + ) -> bool: event = { "Records": [ { @@ -1645,12 +1739,18 @@ class LambdaBackend(BaseBackend): ] } - request_headers = {} - response_headers = {} + request_headers: Dict[str, Any] = {} + response_headers: Dict[str, Any] = {} func.invoke(json.dumps(event), request_headers, response_headers) return "x-amz-function-error" not in response_headers - def send_sns_message(self, function_name, message, subject=None, qualifier=None): + def send_sns_message( + self, + function_name: str, + message: str, + subject: Optional[str] = None, + qualifier: Optional[str] = None, + ) -> None: event = { "Records": [ { @@ -1677,9 +1777,11 @@ class LambdaBackend(BaseBackend): ] } func = self._lambdas.get_function_by_name_or_arn(function_name, qualifier) - func.invoke(json.dumps(event), {}, {}) + func.invoke(json.dumps(event), {}, {}) # type: ignore[union-attr] - def send_dynamodb_items(self, function_arn, items, source): + def send_dynamodb_items( + self, function_arn: str, items: List[Any], source: str + ) -> Union[str, bytes]: event = { "Records": [ { @@ -1695,11 +1797,16 @@ class LambdaBackend(BaseBackend): ] } func = self._lambdas.get_arn(function_arn) - return func.invoke(json.dumps(event), {}, {}) + return func.invoke(json.dumps(event), {}, {}) # type: ignore[union-attr] def send_log_event( - self, function_arn, filter_name, log_group_name, log_stream_name, log_events - ): + self, + function_arn: str, + filter_name: str, + log_group_name: str, + log_stream_name: str, + log_events: Any, + ) -> None: data = { "messageType": "DATA_MESSAGE", "owner": self.account_id, @@ -1717,56 +1824,67 @@ class LambdaBackend(BaseBackend): event = {"awslogs": {"data": payload_gz_encoded}} func = self._lambdas.get_arn(function_arn) - return func.invoke(json.dumps(event), {}, {}) + func.invoke(json.dumps(event), {}, {}) # type: ignore[union-attr] - def list_tags(self, resource): + def list_tags(self, resource: str) -> Dict[str, str]: return self._lambdas.get_function_by_name_or_arn(resource).tags - def tag_resource(self, resource, tags): + def tag_resource(self, resource: str, tags: Dict[str, str]) -> None: fn = self._lambdas.get_function_by_name_or_arn(resource) fn.tags.update(tags) - def untag_resource(self, resource, tagKeys): + def untag_resource(self, resource: str, tagKeys: List[str]) -> None: fn = self._lambdas.get_function_by_name_or_arn(resource) for key in tagKeys: fn.tags.pop(key, None) - def add_permission(self, function_name, qualifier, raw): + def add_permission( + self, function_name: str, qualifier: str, raw: str + ) -> Dict[str, Any]: fn = self.get_function(function_name, qualifier) - return fn.policy.add_statement(raw, qualifier) + return fn.policy.add_statement(raw, qualifier) # type: ignore[union-attr] - def remove_permission(self, function_name, sid, revision=""): + def remove_permission( + self, function_name: str, sid: str, revision: str = "" + ) -> None: fn = self.get_function(function_name) - fn.policy.del_statement(sid, revision) + fn.policy.del_statement(sid, revision) # type: ignore[union-attr] - def get_code_signing_config(self, function_name): + def get_code_signing_config(self, function_name: str) -> Dict[str, Any]: fn = self.get_function(function_name) return fn.get_code_signing_config() - def get_policy(self, function_name): + def get_policy(self, function_name: str) -> str: fn = self.get_function(function_name) if not fn: raise UnknownFunctionException(function_name) - return fn.policy.wire_format() + return fn.policy.wire_format() # type: ignore[union-attr] - def update_function_code(self, function_name, qualifier, body): - fn = self.get_function(function_name, qualifier) + def update_function_code( + self, function_name: str, qualifier: str, body: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + fn: LambdaFunction = self.get_function(function_name, qualifier) - if fn: - if body.get("Publish", False): - fn = self.publish_function(function_name) + if body.get("Publish", False): + fn = self.publish_function(function_name) # type: ignore[assignment] - config = fn.update_function_code(body) - return config - else: - return None + return fn.update_function_code(body) - def update_function_configuration(self, function_name, qualifier, body): + def update_function_configuration( + self, function_name: str, qualifier: str, body: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: fn = self.get_function(function_name, qualifier) return fn.update_configuration(body) if fn else None - def invoke(self, function_name, qualifier, body, headers, response_headers): + def invoke( + self, + function_name: str, + qualifier: str, + body: Any, + headers: Any, + response_headers: Any, + ) -> Optional[Union[str, bytes]]: """ Invoking a Function with PackageType=Image is not yet supported. """ @@ -1778,23 +1896,25 @@ class LambdaBackend(BaseBackend): else: return None - def put_function_concurrency(self, function_name, reserved_concurrency): + def put_function_concurrency( + self, function_name: str, reserved_concurrency: str + ) -> str: fn = self.get_function(function_name) fn.reserved_concurrency = reserved_concurrency return fn.reserved_concurrency - def delete_function_concurrency(self, function_name): + def delete_function_concurrency(self, function_name: str) -> Optional[str]: fn = self.get_function(function_name) fn.reserved_concurrency = None return fn.reserved_concurrency - def get_function_concurrency(self, function_name): + def get_function_concurrency(self, function_name: str) -> str: fn = self.get_function(function_name) return fn.reserved_concurrency -def do_validate_s3(): +def do_validate_s3() -> bool: return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"] -lambda_backends: Mapping[str, LambdaBackend] = BackendDict(LambdaBackend, "lambda") +lambda_backends = BackendDict(LambdaBackend, "lambda") diff --git a/moto/awslambda/policy.py b/moto/awslambda/policy.py index 750f50b6f..2490e43f9 100644 --- a/moto/awslambda/policy.py +++ b/moto/awslambda/policy.py @@ -5,20 +5,24 @@ from moto.awslambda.exceptions import ( UnknownPolicyException, ) from moto.moto_api._internal import mock_random +from typing import Any, Callable, Dict, List, Optional, TypeVar + + +TYPE_IDENTITY = TypeVar("TYPE_IDENTITY") class Policy: - def __init__(self, parent): + def __init__(self, parent: Any): # Parent should be a LambdaFunction self.revision = str(mock_random.uuid4()) - self.statements = [] + self.statements: List[Dict[str, Any]] = [] self.parent = parent - def wire_format(self): + def wire_format(self) -> str: p = self.get_policy() p["Policy"] = json.dumps(p["Policy"]) return json.dumps(p) - def get_policy(self): + def get_policy(self) -> Dict[str, Any]: return { "Policy": { "Version": "2012-10-17", @@ -29,7 +33,9 @@ class Policy: } # adds the raw JSON statement to the policy - def add_statement(self, raw, qualifier=None): + def add_statement( + self, raw: str, qualifier: Optional[str] = None + ) -> Dict[str, Any]: policy = json.loads(raw, object_hook=self.decode_policy) if len(policy.revision) > 0 and self.revision != policy.revision: raise PreconditionFailedException( @@ -49,7 +55,7 @@ class Policy: return policy.statements[0] # removes the statement that matches 'sid' from the policy - def del_statement(self, sid, revision=""): + def del_statement(self, sid: str, revision: str = "") -> None: if len(revision) > 0 and self.revision != revision: raise PreconditionFailedException( "The RevisionId provided does not match the latest RevisionId" @@ -65,7 +71,7 @@ class Policy: # converts AddPermission request to PolicyStatement # https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html - def decode_policy(self, obj): + def decode_policy(self, obj: Dict[str, Any]) -> "Policy": # import pydevd # pydevd.settrace("localhost", port=5678) policy = Policy(self.parent) @@ -101,14 +107,14 @@ class Policy: return policy - def nop_formatter(self, obj): + def nop_formatter(self, obj: TYPE_IDENTITY) -> TYPE_IDENTITY: return obj - def ensure_set(self, obj, key, value): + def ensure_set(self, obj: Dict[str, Any], key: str, value: Any) -> None: if key not in obj: obj[key] = value - def principal_formatter(self, obj): + def principal_formatter(self, obj: Dict[str, Any]) -> Dict[str, Any]: if isinstance(obj, str): if obj.endswith(".amazonaws.com"): return {"Service": obj} @@ -116,27 +122,39 @@ class Policy: return {"AWS": obj} return obj - def source_account_formatter(self, obj): + def source_account_formatter( + self, obj: TYPE_IDENTITY + ) -> Dict[str, Dict[str, TYPE_IDENTITY]]: return {"StringEquals": {"AWS:SourceAccount": obj}} - def source_arn_formatter(self, obj): + def source_arn_formatter( + self, obj: TYPE_IDENTITY + ) -> Dict[str, Dict[str, TYPE_IDENTITY]]: return {"ArnLike": {"AWS:SourceArn": obj}} - def principal_org_id_formatter(self, obj): + def principal_org_id_formatter( + self, obj: TYPE_IDENTITY + ) -> Dict[str, Dict[str, TYPE_IDENTITY]]: return {"StringEquals": {"aws:PrincipalOrgID": obj}} - def transform_property(self, obj, old_name, new_name, formatter): + def transform_property( + self, + obj: Dict[str, Any], + old_name: str, + new_name: str, + formatter: Callable[..., Any], + ) -> None: if old_name in obj: obj[new_name] = formatter(obj[old_name]) if new_name != old_name: del obj[old_name] - def remove_if_set(self, obj, keys): + def remove_if_set(self, obj: Dict[str, Any], keys: List[str]) -> None: for key in keys: if key in obj: del obj[key] - def condition_merge(self, obj): + def condition_merge(self, obj: Dict[str, Any]) -> None: if "SourceArn" in obj: if "Condition" not in obj: obj["Condition"] = {} diff --git a/moto/awslambda/responses.py b/moto/awslambda/responses.py index 88dc81c66..e20a26471 100644 --- a/moto/awslambda/responses.py +++ b/moto/awslambda/responses.py @@ -1,31 +1,27 @@ import json import sys - +from typing import Any, Dict, List, Tuple, Union from urllib.parse import unquote from moto.core.utils import path_url from moto.utilities.aws_headers import amz_crc32, amzn_request_id -from moto.core.responses import BaseResponse -from .models import lambda_backends +from moto.core.responses import BaseResponse, TYPE_RESPONSE +from .models import lambda_backends, LambdaBackend class LambdaResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="awslambda") @property - def json_body(self): - """ - :return: JSON - :rtype: dict - """ + def json_body(self) -> Dict[str, Any]: # type: ignore[misc] return json.loads(self.body) @property - def backend(self): + def backend(self) -> LambdaBackend: return lambda_backends[self.current_account][self.region] - def root(self, request, full_url, headers): + def root(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) if request.method == "GET": return self._list_functions() @@ -34,7 +30,9 @@ class LambdaResponse(BaseResponse): else: raise ValueError("Cannot handle request") - def event_source_mappings(self, request, full_url, headers): + def event_source_mappings( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) if request.method == "GET": querystring = self.querystring @@ -46,12 +44,12 @@ class LambdaResponse(BaseResponse): else: raise ValueError("Cannot handle request") - def aliases(self, request, full_url, headers): + def aliases(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "POST": return self._create_alias() - def alias(self, request, full_url, headers): + def alias(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "DELETE": return self._delete_alias() @@ -60,7 +58,9 @@ class LambdaResponse(BaseResponse): elif request.method == "PUT": return self._update_alias() - def event_source_mapping(self, request, full_url, headers): + def event_source_mapping( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) path = request.path if hasattr(request, "path") else path_url(request.url) uuid = path.split("/")[-1] @@ -73,26 +73,26 @@ class LambdaResponse(BaseResponse): else: raise ValueError("Cannot handle request") - def list_layers(self, request, full_url, headers): + def list_layers(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self._list_layers() - def layers_version(self, request, full_url, headers): + def layers_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "DELETE": return self._delete_layer_version() elif request.method == "GET": return self._get_layer_version() - def layers_versions(self, request, full_url, headers): + def layers_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self._get_layer_versions() if request.method == "POST": return self._publish_layer_version() - def function(self, request, full_url, headers): + def function(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) if request.method == "GET": return self._get_function() @@ -101,7 +101,7 @@ class LambdaResponse(BaseResponse): else: raise ValueError("Cannot handle request") - def versions(self, request, full_url, headers): + def versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) if request.method == "GET": # This is ListVersionByFunction @@ -117,7 +117,7 @@ class LambdaResponse(BaseResponse): @amz_crc32 @amzn_request_id - def invoke(self, request, full_url, headers): + def invoke(self, request: Any, full_url: str, headers: Any) -> Tuple[int, Dict[str, str], Union[str, bytes]]: # type: ignore[misc] self.setup_class(request, full_url, headers) if request.method == "POST": return self._invoke(request) @@ -126,14 +126,14 @@ class LambdaResponse(BaseResponse): @amz_crc32 @amzn_request_id - def invoke_async(self, request, full_url, headers): + def invoke_async(self, request: Any, full_url: str, headers: Any) -> Tuple[int, Dict[str, str], Union[str, bytes]]: # type: ignore[misc] self.setup_class(request, full_url, headers) if request.method == "POST": return self._invoke_async() else: raise ValueError("Cannot handle request") - def tag(self, request, full_url, headers): + def tag(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) if request.method == "GET": return self._list_tags() @@ -144,7 +144,7 @@ class LambdaResponse(BaseResponse): else: raise ValueError("Cannot handle {0} request".format(request.method)) - def policy(self, request, full_url, headers): + def policy(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) if request.method == "GET": return self._get_policy(request) @@ -155,7 +155,7 @@ class LambdaResponse(BaseResponse): else: raise ValueError("Cannot handle {0} request".format(request.method)) - def configuration(self, request, full_url, headers): + def configuration(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) if request.method == "PUT": return self._put_configuration() @@ -164,19 +164,21 @@ class LambdaResponse(BaseResponse): else: raise ValueError("Cannot handle request") - def code(self, request, full_url, headers): + def code(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: self.setup_class(request, full_url, headers) if request.method == "PUT": return self._put_code() else: raise ValueError("Cannot handle request") - def code_signing_config(self, request, full_url, headers): + def code_signing_config(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self._get_code_signing_config() - def function_concurrency(self, request, full_url, headers): + def function_concurrency( + self, request: Any, full_url: str, headers: Any + ) -> TYPE_RESPONSE: http_method = request.method self.setup_class(request, full_url, headers) @@ -189,7 +191,7 @@ class LambdaResponse(BaseResponse): else: raise ValueError("Cannot handle request") - def function_url_config(self, request, full_url, headers): + def function_url_config(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return] http_method = request.method self.setup_class(request, full_url, headers) @@ -202,7 +204,7 @@ class LambdaResponse(BaseResponse): elif http_method == "PUT": return self._update_function_url_config() - def _add_policy(self, request): + def _add_policy(self, request: Any) -> TYPE_RESPONSE: path = request.path if hasattr(request, "path") else path_url(request.url) function_name = unquote(path.split("/")[-2]) qualifier = self.querystring.get("Qualifier", [None])[0] @@ -210,13 +212,13 @@ class LambdaResponse(BaseResponse): statement = self.backend.add_permission(function_name, qualifier, statement) return 200, {}, json.dumps({"Statement": json.dumps(statement)}) - def _get_policy(self, request): + def _get_policy(self, request: Any) -> TYPE_RESPONSE: path = request.path if hasattr(request, "path") else path_url(request.url) function_name = unquote(path.split("/")[-2]) out = self.backend.get_policy(function_name) return 200, {}, out - def _del_policy(self, request, querystring): + def _del_policy(self, request: Any, querystring: Dict[str, Any]) -> TYPE_RESPONSE: path = request.path if hasattr(request, "path") else path_url(request.url) function_name = unquote(path.split("/")[-3]) statement_id = path.split("/")[-1].split("?")[0] @@ -227,8 +229,8 @@ class LambdaResponse(BaseResponse): else: return 404, {}, "{}" - def _invoke(self, request): - response_headers = {} + def _invoke(self, request: Any) -> Tuple[int, Dict[str, str], Union[str, bytes]]: + response_headers: Dict[str, str] = {} # URL Decode in case it's a ARN: function_name = unquote(self.path.rsplit("/", 2)[-2]) @@ -261,8 +263,8 @@ class LambdaResponse(BaseResponse): else: return 404, response_headers, "{}" - def _invoke_async(self): - response_headers = {} + def _invoke_async(self) -> Tuple[int, Dict[str, str], Union[str, bytes]]: + response_headers: Dict[str, Any] = {} function_name = unquote(self.path.rsplit("/", 3)[-3]) @@ -271,10 +273,10 @@ class LambdaResponse(BaseResponse): response_headers["Content-Length"] = str(len(payload)) return 202, response_headers, payload - def _list_functions(self): + def _list_functions(self) -> TYPE_RESPONSE: querystring = self.querystring func_version = querystring.get("FunctionVersion", [None])[0] - result = {"Functions": []} + result: Dict[str, List[Dict[str, Any]]] = {"Functions": []} for fn in self.backend.list_functions(func_version): json_data = fn.get_configuration() @@ -282,67 +284,68 @@ class LambdaResponse(BaseResponse): return 200, {}, json.dumps(result) - def _list_versions_by_function(self, function_name): - result = {"Versions": []} + def _list_versions_by_function(self, function_name: str) -> TYPE_RESPONSE: + result: Dict[str, Any] = {"Versions": []} functions = self.backend.list_versions_by_function(function_name) - if functions: - for fn in functions: - json_data = fn.get_configuration() - result["Versions"].append(json_data) + for fn in functions: + json_data = fn.get_configuration() + result["Versions"].append(json_data) return 200, {}, json.dumps(result) - def _create_function(self): + def _create_function(self) -> TYPE_RESPONSE: fn = self.backend.create_function(self.json_body) config = fn.get_configuration(on_create=True) return 201, {}, json.dumps(config) - def _create_function_url_config(self): + def _create_function_url_config(self) -> TYPE_RESPONSE: function_name = unquote(self.path.split("/")[-2]) config = self.backend.create_function_url_config(function_name, self.json_body) return 201, {}, json.dumps(config.to_dict()) - def _delete_function_url_config(self): + def _delete_function_url_config(self) -> TYPE_RESPONSE: function_name = unquote(self.path.split("/")[-2]) self.backend.delete_function_url_config(function_name) return 204, {}, "{}" - def _get_function_url_config(self): + def _get_function_url_config(self) -> TYPE_RESPONSE: function_name = unquote(self.path.split("/")[-2]) config = self.backend.get_function_url_config(function_name) return 201, {}, json.dumps(config.to_dict()) - def _update_function_url_config(self): + def _update_function_url_config(self) -> TYPE_RESPONSE: function_name = unquote(self.path.split("/")[-2]) config = self.backend.update_function_url_config(function_name, self.json_body) return 200, {}, json.dumps(config.to_dict()) - def _create_event_source_mapping(self): + def _create_event_source_mapping(self) -> TYPE_RESPONSE: fn = self.backend.create_event_source_mapping(self.json_body) config = fn.get_configuration() return 201, {}, json.dumps(config) - def _list_event_source_mappings(self, event_source_arn, function_name): + def _list_event_source_mappings( + self, event_source_arn: str, function_name: str + ) -> TYPE_RESPONSE: esms = self.backend.list_event_source_mappings(event_source_arn, function_name) result = {"EventSourceMappings": [esm.get_configuration() for esm in esms]} return 200, {}, json.dumps(result) - def _get_event_source_mapping(self, uuid): + def _get_event_source_mapping(self, uuid: str) -> TYPE_RESPONSE: result = self.backend.get_event_source_mapping(uuid) if result: return 200, {}, json.dumps(result.get_configuration()) else: return 404, {}, "{}" - def _update_event_source_mapping(self, uuid): + def _update_event_source_mapping(self, uuid: str) -> TYPE_RESPONSE: result = self.backend.update_event_source_mapping(uuid, self.json_body) if result: return 202, {}, json.dumps(result.get_configuration()) else: return 404, {}, "{}" - def _delete_event_source_mapping(self, uuid): + def _delete_event_source_mapping(self, uuid: str) -> TYPE_RESPONSE: esm = self.backend.delete_event_source_mapping(uuid) if esm: json_result = esm.get_configuration() @@ -351,15 +354,15 @@ class LambdaResponse(BaseResponse): else: return 404, {}, "{}" - def _publish_function(self): + def _publish_function(self) -> TYPE_RESPONSE: function_name = unquote(self.path.split("/")[-2]) description = self._get_param("Description") fn = self.backend.publish_function(function_name, description) - config = fn.get_configuration() + config = fn.get_configuration() # type: ignore[union-attr] return 201, {}, json.dumps(config) - def _delete_function(self): + def _delete_function(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/", 1)[-1]) qualifier = self._get_param("Qualifier", None) @@ -367,14 +370,14 @@ class LambdaResponse(BaseResponse): return 204, {}, "" @staticmethod - def _set_configuration_qualifier(configuration, qualifier): + def _set_configuration_qualifier(configuration: Dict[str, Any], qualifier: str) -> Dict[str, Any]: # type: ignore[misc] if qualifier is None or qualifier == "$LATEST": configuration["Version"] = "$LATEST" if qualifier == "$LATEST": configuration["FunctionArn"] += ":$LATEST" return configuration - def _get_function(self): + def _get_function(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/", 1)[-1]) qualifier = self._get_param("Qualifier", None) @@ -386,7 +389,7 @@ class LambdaResponse(BaseResponse): ) return 200, {}, json.dumps(code) - def _get_function_configuration(self): + def _get_function_configuration(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/", 2)[-2]) qualifier = self._get_param("Qualifier", None) @@ -397,33 +400,33 @@ class LambdaResponse(BaseResponse): ) return 200, {}, json.dumps(configuration) - def _get_aws_region(self, full_url): + def _get_aws_region(self, full_url: str) -> str: region = self.region_regex.search(full_url) if region: return region.group(1) else: return self.default_region - def _list_tags(self): + def _list_tags(self) -> TYPE_RESPONSE: function_arn = unquote(self.path.rsplit("/", 1)[-1]) tags = self.backend.list_tags(function_arn) return 200, {}, json.dumps({"Tags": tags}) - def _tag_resource(self): + def _tag_resource(self) -> TYPE_RESPONSE: function_arn = unquote(self.path.rsplit("/", 1)[-1]) self.backend.tag_resource(function_arn, self.json_body["Tags"]) return 200, {}, "{}" - def _untag_resource(self): + def _untag_resource(self) -> TYPE_RESPONSE: function_arn = unquote(self.path.rsplit("/", 1)[-1]) tag_keys = self.querystring["tagKeys"] self.backend.untag_resource(function_arn, tag_keys) return 204, {}, "{}" - def _put_configuration(self): + def _put_configuration(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/", 2)[-2]) qualifier = self._get_param("Qualifier", None) resp = self.backend.update_function_configuration( @@ -435,7 +438,7 @@ class LambdaResponse(BaseResponse): else: return 404, {}, "{}" - def _put_code(self): + def _put_code(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/", 2)[-2]) qualifier = self._get_param("Qualifier", None) resp = self.backend.update_function_code( @@ -447,12 +450,12 @@ class LambdaResponse(BaseResponse): else: return 404, {}, "{}" - def _get_code_signing_config(self): + def _get_code_signing_config(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/", 2)[-2]) resp = self.backend.get_code_signing_config(function_name) return 200, {}, json.dumps(resp) - def _get_function_concurrency(self): + def _get_function_concurrency(self) -> TYPE_RESPONSE: path_function_name = unquote(self.path.rsplit("/", 2)[-2]) function_name = self.backend.get_function(path_function_name) @@ -462,7 +465,7 @@ class LambdaResponse(BaseResponse): resp = self.backend.get_function_concurrency(path_function_name) return 200, {}, json.dumps({"ReservedConcurrentExecutions": resp}) - def _delete_function_concurrency(self): + def _delete_function_concurrency(self) -> TYPE_RESPONSE: path_function_name = unquote(self.path.rsplit("/", 2)[-2]) function_name = self.backend.get_function(path_function_name) @@ -473,7 +476,7 @@ class LambdaResponse(BaseResponse): return 204, {}, "{}" - def _put_function_concurrency(self): + def _put_function_concurrency(self) -> TYPE_RESPONSE: path_function_name = unquote(self.path.rsplit("/", 2)[-2]) function = self.backend.get_function(path_function_name) @@ -485,25 +488,25 @@ class LambdaResponse(BaseResponse): return 200, {}, json.dumps({"ReservedConcurrentExecutions": resp}) - def _list_layers(self): + def _list_layers(self) -> TYPE_RESPONSE: layers = self.backend.list_layers() return 200, {}, json.dumps({"Layers": layers}) - def _delete_layer_version(self): + def _delete_layer_version(self) -> TYPE_RESPONSE: layer_name = self.path.split("/")[-3] layer_version = self.path.split("/")[-1] self.backend.delete_layer_version(layer_name, layer_version) return 200, {}, "{}" - def _get_layer_version(self): + def _get_layer_version(self) -> TYPE_RESPONSE: layer_name = self.path.split("/")[-3] layer_version = self.path.split("/")[-1] layer = self.backend.get_layer_version(layer_name, layer_version) return 200, {}, json.dumps(layer.get_layer_version()) - def _get_layer_versions(self): + def _get_layer_versions(self) -> TYPE_RESPONSE: layer_name = self.path.rsplit("/", 2)[-2] layer_versions = self.backend.get_layer_versions(layer_name) return ( @@ -514,7 +517,7 @@ class LambdaResponse(BaseResponse): ), ) - def _publish_layer_version(self): + def _publish_layer_version(self) -> TYPE_RESPONSE: spec = self.json_body if "LayerName" not in spec: spec["LayerName"] = self.path.rsplit("/", 2)[-2] @@ -522,7 +525,7 @@ class LambdaResponse(BaseResponse): config = layer_version.get_layer_version() return 201, {}, json.dumps(config) - def _create_alias(self): + def _create_alias(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/", 2)[-2]) params = json.loads(self.body) alias_name = params.get("Name") @@ -538,19 +541,19 @@ class LambdaResponse(BaseResponse): ) return 201, {}, json.dumps(alias.to_json()) - def _delete_alias(self): + def _delete_alias(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/")[-3]) alias_name = unquote(self.path.rsplit("/", 2)[-1]) self.backend.delete_alias(name=alias_name, function_name=function_name) return 201, {}, "{}" - def _get_alias(self): + def _get_alias(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/")[-3]) alias_name = unquote(self.path.rsplit("/", 2)[-1]) alias = self.backend.get_alias(name=alias_name, function_name=function_name) return 201, {}, json.dumps(alias.to_json()) - def _update_alias(self): + def _update_alias(self) -> TYPE_RESPONSE: function_name = unquote(self.path.rsplit("/")[-3]) alias_name = unquote(self.path.rsplit("/", 2)[-1]) params = json.loads(self.body) diff --git a/moto/awslambda/utils.py b/moto/awslambda/utils.py index 99fbae3f9..c0c0775ee 100644 --- a/moto/awslambda/utils.py +++ b/moto/awslambda/utils.py @@ -1,11 +1,12 @@ from collections import namedtuple from functools import partial +from typing import Any, Callable ARN = namedtuple("ARN", ["region", "account", "function_name", "version"]) LAYER_ARN = namedtuple("LAYER_ARN", ["region", "account", "layer_name", "version"]) -def make_arn(resource_type, region, account, name): +def make_arn(resource_type: str, region: str, account: str, name: str) -> str: return "arn:aws:lambda:{0}:{1}:{2}:{3}".format(region, account, resource_type, name) @@ -13,7 +14,9 @@ make_function_arn = partial(make_arn, "function") make_layer_arn = partial(make_arn, "layer") -def make_ver_arn(resource_type, region, account, name, version="1"): +def make_ver_arn( + resource_type: str, region: str, account: str, name: str, version: str = "1" +) -> str: arn = make_arn(resource_type, region, account, name) return "{0}:{1}".format(arn, version) @@ -22,7 +25,7 @@ make_function_ver_arn = partial(make_ver_arn, "function") make_layer_ver_arn = partial(make_ver_arn, "layer") -def split_arn(arn_type, arn): +def split_arn(arn_type: Callable[[str, str, str, str], str], arn: str) -> Any: arn = arn.replace("arn:aws:lambda:", "") region, account, _, name, version = arn.split(":") diff --git a/moto/core/utils.py b/moto/core/utils.py index 4fa4416df..f1959669f 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -175,14 +175,14 @@ def str_to_rfc_1123_datetime(value): return datetime.datetime.strptime(value, RFC1123) -def unix_time(dt: datetime.datetime = None) -> datetime.datetime: +def unix_time(dt: datetime.datetime = None) -> int: dt = dt or datetime.datetime.utcnow() epoch = datetime.datetime.utcfromtimestamp(0) delta = dt - epoch return (delta.days * 86400) + (delta.seconds + (delta.microseconds / 1e6)) -def unix_time_millis(dt=None): +def unix_time_millis(dt: datetime = None) -> int: return unix_time(dt) * 1000.0 diff --git a/moto/iam/models.py b/moto/iam/models.py index 2d84a9681..27fd89ec0 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -1917,7 +1917,7 @@ class IAMBackend(BaseBackend): return role raise IAMNotFoundException("Role {0} not found".format(role_name)) - def get_role_by_arn(self, arn): + def get_role_by_arn(self, arn: str) -> Role: for role in self.get_roles(): if role.arn == arn: return role diff --git a/moto/settings.py b/moto/settings.py index b6ece4171..9d9a926f5 100644 --- a/moto/settings.py +++ b/moto/settings.py @@ -68,37 +68,37 @@ def allow_unknown_region(): return os.environ.get("MOTO_ALLOW_NONEXISTENT_REGION", "false").lower() == "true" -def moto_server_port(): +def moto_server_port() -> str: return os.environ.get("MOTO_PORT") or "5000" @lru_cache() -def moto_server_host(): +def moto_server_host() -> str: if is_docker(): return get_docker_host() else: return "http://host.docker.internal" -def moto_lambda_image(): +def moto_lambda_image() -> str: return os.environ.get("MOTO_DOCKER_LAMBDA_IMAGE", "lambci/lambda") -def moto_network_name(): +def moto_network_name() -> str: return os.environ.get("MOTO_DOCKER_NETWORK_NAME") -def moto_network_mode(): +def moto_network_mode() -> str: return os.environ.get("MOTO_DOCKER_NETWORK_MODE") -def test_server_mode_endpoint(): +def test_server_mode_endpoint() -> str: return os.environ.get( "TEST_SERVER_MODE_ENDPOINT", f"http://localhost:{moto_server_port()}" ) -def is_docker(): +def is_docker() -> bool: path = pathlib.Path("/proc/self/cgroup") return ( os.path.exists("/.dockerenv") @@ -107,7 +107,7 @@ def is_docker(): ) -def get_docker_host(): +def get_docker_host() -> str: try: cmd = "curl -s --unix-socket /run/docker.sock http://docker/containers/$HOSTNAME/json" container_info = os.popen(cmd).read() diff --git a/moto/utilities/docker_utilities.py b/moto/utilities/docker_utilities.py index e049d8902..4e152194a 100644 --- a/moto/utilities/docker_utilities.py +++ b/moto/utilities/docker_utilities.py @@ -1,5 +1,6 @@ import functools import requests.adapters +from typing import Tuple from moto import settings @@ -8,7 +9,7 @@ _orig_adapter_send = requests.adapters.HTTPAdapter.send class DockerModel: - def __init__(self): + def __init__(self) -> None: self.__docker_client = None @property @@ -36,7 +37,7 @@ class DockerModel: return self.__docker_client -def parse_image_ref(image_name): +def parse_image_ref(image_name: str) -> Tuple[str, str]: # podman does not support short container image name out of box - try to make a full name # See ParseDockerRef() in https://github.com/distribution/distribution/blob/main/reference/normalize.go parts = image_name.split("/") diff --git a/setup.cfg b/setup.cfg index d7206a327..a70a36960 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/acm,moto/amp,moto/apigateway,moto/apigatewayv2,moto/applicationautoscaling/,moto/appsync,moto/athena,moto/autoscaling +files= moto/a* show_column_numbers=True show_error_codes = True disable_error_code=abstract