TechDebt: MyPy AWSLambda (#5586)

This commit is contained in:
Bert Blommers 2022-10-22 11:40:20 +00:00 committed by GitHub
parent 6f710189ce
commit 8c88a93d7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 536 additions and 392 deletions

View File

@ -51,19 +51,20 @@ class GraphqlSchema(BaseModel):
class GraphqlAPIKey(BaseModel):
def __init__(self, description: str, expires: Optional[datetime]):
def __init__(self, description: str, expires: Optional[int]):
self.key_id = str(mock_random.uuid4())[0:6]
self.description = description
self.expires = expires
if not self.expires:
if not expires:
default_expiry = datetime.now(timezone.utc)
default_expiry = default_expiry.replace(
minute=0, second=0, microsecond=0, tzinfo=None
)
default_expiry = default_expiry + timedelta(days=7)
self.expires = unix_time(default_expiry)
else:
self.expires = expires
def update(self, description: Optional[str], expires: Optional[datetime]) -> None:
def update(self, description: Optional[str], expires: Optional[int]) -> None:
if description:
self.description = description
if expires:
@ -138,9 +139,7 @@ class GraphqlAPI(BaseModel):
if xray_enabled is not None:
self.xray_enabled = xray_enabled
def create_api_key(
self, description: str, expires: Optional[datetime]
) -> GraphqlAPIKey:
def create_api_key(self, description: str, expires: Optional[int]) -> GraphqlAPIKey:
api_key = GraphqlAPIKey(description, expires)
self.api_keys[api_key.key_id] = api_key
return api_key
@ -152,7 +151,7 @@ class GraphqlAPI(BaseModel):
self.api_keys.pop(api_key_id)
def update_api_key(
self, api_key_id: str, description: str, expires: Optional[datetime]
self, api_key_id: str, description: str, expires: Optional[int]
) -> GraphqlAPIKey:
api_key = self.api_keys[api_key_id]
api_key.update(description, expires)
@ -265,7 +264,7 @@ class AppSyncBackend(BaseBackend):
return self.graphql_apis.values()
def create_api_key(
self, api_id: str, description: str, expires: Optional[datetime]
self, api_id: str, description: str, expires: Optional[int]
) -> GraphqlAPIKey:
return self.graphql_apis[api_id].create_api_key(description, expires)
@ -286,7 +285,7 @@ class AppSyncBackend(BaseBackend):
api_id: str,
api_key_id: str,
description: str,
expires: Optional[datetime],
expires: Optional[int],
) -> GraphqlAPIKey:
return self.graphql_apis[api_id].update_api_key(
api_key_id, description, expires

View File

@ -2,26 +2,26 @@ from moto.core.exceptions import JsonRESTError
class LambdaClientError(JsonRESTError):
def __init__(self, error, message):
def __init__(self, error: str, message: str):
super().__init__(error, message)
class CrossAccountNotAllowed(LambdaClientError):
def __init__(self):
def __init__(self) -> None:
super().__init__(
"AccessDeniedException", "Cross-account pass role is not allowed."
)
class InvalidParameterValueException(LambdaClientError):
def __init__(self, message):
def __init__(self, message: str):
super().__init__("InvalidParameterValueException", message)
class InvalidRoleFormat(LambdaClientError):
pattern = r"arn:(aws[a-zA-Z-]*)?:iam::(\d{12}):role/?[a-zA-Z_0-9+=,.@\-_/]+"
def __init__(self, role):
def __init__(self, role: str):
message = "1 validation error detected: Value '{0}' at 'role' failed to satisfy constraint: Member must satisfy regular expression pattern: {1}".format(
role, InvalidRoleFormat.pattern
)
@ -31,28 +31,28 @@ class InvalidRoleFormat(LambdaClientError):
class PreconditionFailedException(JsonRESTError):
code = 412
def __init__(self, message):
def __init__(self, message: str):
super().__init__("PreconditionFailedException", message)
class UnknownAliasException(LambdaClientError):
code = 404
def __init__(self, arn):
def __init__(self, arn: str):
super().__init__("ResourceNotFoundException", f"Cannot find alias arn: {arn}")
class UnknownFunctionException(LambdaClientError):
code = 404
def __init__(self, arn):
def __init__(self, arn: str):
super().__init__("ResourceNotFoundException", f"Function not found: {arn}")
class FunctionUrlConfigNotFound(LambdaClientError):
code = 404
def __init__(self):
def __init__(self) -> None:
super().__init__(
"ResourceNotFoundException", "The resource you requested does not exist."
)
@ -61,14 +61,14 @@ class FunctionUrlConfigNotFound(LambdaClientError):
class UnknownLayerException(LambdaClientError):
code = 404
def __init__(self):
def __init__(self) -> None:
super().__init__("ResourceNotFoundException", "Cannot find layer")
class UnknownPolicyException(LambdaClientError):
code = 404
def __init__(self):
def __init__(self) -> None:
super().__init__(
"ResourceNotFoundException",
"No policy is associated with the given resource.",

File diff suppressed because it is too large Load Diff

View File

@ -5,20 +5,24 @@ from moto.awslambda.exceptions import (
UnknownPolicyException,
)
from moto.moto_api._internal import mock_random
from typing import Any, Callable, Dict, List, Optional, TypeVar
TYPE_IDENTITY = TypeVar("TYPE_IDENTITY")
class Policy:
def __init__(self, parent):
def __init__(self, parent: Any): # Parent should be a LambdaFunction
self.revision = str(mock_random.uuid4())
self.statements = []
self.statements: List[Dict[str, Any]] = []
self.parent = parent
def wire_format(self):
def wire_format(self) -> str:
p = self.get_policy()
p["Policy"] = json.dumps(p["Policy"])
return json.dumps(p)
def get_policy(self):
def get_policy(self) -> Dict[str, Any]:
return {
"Policy": {
"Version": "2012-10-17",
@ -29,7 +33,9 @@ class Policy:
}
# adds the raw JSON statement to the policy
def add_statement(self, raw, qualifier=None):
def add_statement(
self, raw: str, qualifier: Optional[str] = None
) -> Dict[str, Any]:
policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException(
@ -49,7 +55,7 @@ class Policy:
return policy.statements[0]
# removes the statement that matches 'sid' from the policy
def del_statement(self, sid, revision=""):
def del_statement(self, sid: str, revision: str = "") -> None:
if len(revision) > 0 and self.revision != revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
@ -65,7 +71,7 @@ class Policy:
# converts AddPermission request to PolicyStatement
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
def decode_policy(self, obj):
def decode_policy(self, obj: Dict[str, Any]) -> "Policy":
# import pydevd
# pydevd.settrace("localhost", port=5678)
policy = Policy(self.parent)
@ -101,14 +107,14 @@ class Policy:
return policy
def nop_formatter(self, obj):
def nop_formatter(self, obj: TYPE_IDENTITY) -> TYPE_IDENTITY:
return obj
def ensure_set(self, obj, key, value):
def ensure_set(self, obj: Dict[str, Any], key: str, value: Any) -> None:
if key not in obj:
obj[key] = value
def principal_formatter(self, obj):
def principal_formatter(self, obj: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(obj, str):
if obj.endswith(".amazonaws.com"):
return {"Service": obj}
@ -116,27 +122,39 @@ class Policy:
return {"AWS": obj}
return obj
def source_account_formatter(self, obj):
def source_account_formatter(
self, obj: TYPE_IDENTITY
) -> Dict[str, Dict[str, TYPE_IDENTITY]]:
return {"StringEquals": {"AWS:SourceAccount": obj}}
def source_arn_formatter(self, obj):
def source_arn_formatter(
self, obj: TYPE_IDENTITY
) -> Dict[str, Dict[str, TYPE_IDENTITY]]:
return {"ArnLike": {"AWS:SourceArn": obj}}
def principal_org_id_formatter(self, obj):
def principal_org_id_formatter(
self, obj: TYPE_IDENTITY
) -> Dict[str, Dict[str, TYPE_IDENTITY]]:
return {"StringEquals": {"aws:PrincipalOrgID": obj}}
def transform_property(self, obj, old_name, new_name, formatter):
def transform_property(
self,
obj: Dict[str, Any],
old_name: str,
new_name: str,
formatter: Callable[..., Any],
) -> None:
if old_name in obj:
obj[new_name] = formatter(obj[old_name])
if new_name != old_name:
del obj[old_name]
def remove_if_set(self, obj, keys):
def remove_if_set(self, obj: Dict[str, Any], keys: List[str]) -> None:
for key in keys:
if key in obj:
del obj[key]
def condition_merge(self, obj):
def condition_merge(self, obj: Dict[str, Any]) -> None:
if "SourceArn" in obj:
if "Condition" not in obj:
obj["Condition"] = {}

View File

@ -1,31 +1,27 @@
import json
import sys
from typing import Any, Dict, List, Tuple, Union
from urllib.parse import unquote
from moto.core.utils import path_url
from moto.utilities.aws_headers import amz_crc32, amzn_request_id
from moto.core.responses import BaseResponse
from .models import lambda_backends
from moto.core.responses import BaseResponse, TYPE_RESPONSE
from .models import lambda_backends, LambdaBackend
class LambdaResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="awslambda")
@property
def json_body(self):
"""
:return: JSON
:rtype: dict
"""
def json_body(self) -> Dict[str, Any]: # type: ignore[misc]
return json.loads(self.body)
@property
def backend(self):
def backend(self) -> LambdaBackend:
return lambda_backends[self.current_account][self.region]
def root(self, request, full_url, headers):
def root(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._list_functions()
@ -34,7 +30,9 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
def event_source_mappings(self, request, full_url, headers):
def event_source_mappings(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
if request.method == "GET":
querystring = self.querystring
@ -46,12 +44,12 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
def aliases(self, request, full_url, headers):
def aliases(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self._create_alias()
def alias(self, request, full_url, headers):
def alias(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "DELETE":
return self._delete_alias()
@ -60,7 +58,9 @@ class LambdaResponse(BaseResponse):
elif request.method == "PUT":
return self._update_alias()
def event_source_mapping(self, request, full_url, headers):
def event_source_mapping(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
path = request.path if hasattr(request, "path") else path_url(request.url)
uuid = path.split("/")[-1]
@ -73,26 +73,26 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
def list_layers(self, request, full_url, headers):
def list_layers(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._list_layers()
def layers_version(self, request, full_url, headers):
def layers_version(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "DELETE":
return self._delete_layer_version()
elif request.method == "GET":
return self._get_layer_version()
def layers_versions(self, request, full_url, headers):
def layers_versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._get_layer_versions()
if request.method == "POST":
return self._publish_layer_version()
def function(self, request, full_url, headers):
def function(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._get_function()
@ -101,7 +101,7 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
def versions(self, request, full_url, headers):
def versions(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
if request.method == "GET":
# This is ListVersionByFunction
@ -117,7 +117,7 @@ class LambdaResponse(BaseResponse):
@amz_crc32
@amzn_request_id
def invoke(self, request, full_url, headers):
def invoke(self, request: Any, full_url: str, headers: Any) -> Tuple[int, Dict[str, str], Union[str, bytes]]: # type: ignore[misc]
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self._invoke(request)
@ -126,14 +126,14 @@ class LambdaResponse(BaseResponse):
@amz_crc32
@amzn_request_id
def invoke_async(self, request, full_url, headers):
def invoke_async(self, request: Any, full_url: str, headers: Any) -> Tuple[int, Dict[str, str], Union[str, bytes]]: # type: ignore[misc]
self.setup_class(request, full_url, headers)
if request.method == "POST":
return self._invoke_async()
else:
raise ValueError("Cannot handle request")
def tag(self, request, full_url, headers):
def tag(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._list_tags()
@ -144,7 +144,7 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def policy(self, request, full_url, headers):
def policy(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._get_policy(request)
@ -155,7 +155,7 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def configuration(self, request, full_url, headers):
def configuration(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self._put_configuration()
@ -164,19 +164,21 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
def code(self, request, full_url, headers):
def code(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE:
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self._put_code()
else:
raise ValueError("Cannot handle request")
def code_signing_config(self, request, full_url, headers):
def code_signing_config(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._get_code_signing_config()
def function_concurrency(self, request, full_url, headers):
def function_concurrency(
self, request: Any, full_url: str, headers: Any
) -> TYPE_RESPONSE:
http_method = request.method
self.setup_class(request, full_url, headers)
@ -189,7 +191,7 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
def function_url_config(self, request, full_url, headers):
def function_url_config(self, request: Any, full_url: str, headers: Any) -> TYPE_RESPONSE: # type: ignore[return]
http_method = request.method
self.setup_class(request, full_url, headers)
@ -202,7 +204,7 @@ class LambdaResponse(BaseResponse):
elif http_method == "PUT":
return self._update_function_url_config()
def _add_policy(self, request):
def _add_policy(self, request: Any) -> TYPE_RESPONSE:
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = unquote(path.split("/")[-2])
qualifier = self.querystring.get("Qualifier", [None])[0]
@ -210,13 +212,13 @@ class LambdaResponse(BaseResponse):
statement = self.backend.add_permission(function_name, qualifier, statement)
return 200, {}, json.dumps({"Statement": json.dumps(statement)})
def _get_policy(self, request):
def _get_policy(self, request: Any) -> TYPE_RESPONSE:
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = unquote(path.split("/")[-2])
out = self.backend.get_policy(function_name)
return 200, {}, out
def _del_policy(self, request, querystring):
def _del_policy(self, request: Any, querystring: Dict[str, Any]) -> TYPE_RESPONSE:
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = unquote(path.split("/")[-3])
statement_id = path.split("/")[-1].split("?")[0]
@ -227,8 +229,8 @@ class LambdaResponse(BaseResponse):
else:
return 404, {}, "{}"
def _invoke(self, request):
response_headers = {}
def _invoke(self, request: Any) -> Tuple[int, Dict[str, str], Union[str, bytes]]:
response_headers: Dict[str, str] = {}
# URL Decode in case it's a ARN:
function_name = unquote(self.path.rsplit("/", 2)[-2])
@ -261,8 +263,8 @@ class LambdaResponse(BaseResponse):
else:
return 404, response_headers, "{}"
def _invoke_async(self):
response_headers = {}
def _invoke_async(self) -> Tuple[int, Dict[str, str], Union[str, bytes]]:
response_headers: Dict[str, Any] = {}
function_name = unquote(self.path.rsplit("/", 3)[-3])
@ -271,10 +273,10 @@ class LambdaResponse(BaseResponse):
response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload
def _list_functions(self):
def _list_functions(self) -> TYPE_RESPONSE:
querystring = self.querystring
func_version = querystring.get("FunctionVersion", [None])[0]
result = {"Functions": []}
result: Dict[str, List[Dict[str, Any]]] = {"Functions": []}
for fn in self.backend.list_functions(func_version):
json_data = fn.get_configuration()
@ -282,67 +284,68 @@ class LambdaResponse(BaseResponse):
return 200, {}, json.dumps(result)
def _list_versions_by_function(self, function_name):
result = {"Versions": []}
def _list_versions_by_function(self, function_name: str) -> TYPE_RESPONSE:
result: Dict[str, Any] = {"Versions": []}
functions = self.backend.list_versions_by_function(function_name)
if functions:
for fn in functions:
json_data = fn.get_configuration()
result["Versions"].append(json_data)
return 200, {}, json.dumps(result)
def _create_function(self):
def _create_function(self) -> TYPE_RESPONSE:
fn = self.backend.create_function(self.json_body)
config = fn.get_configuration(on_create=True)
return 201, {}, json.dumps(config)
def _create_function_url_config(self):
def _create_function_url_config(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.split("/")[-2])
config = self.backend.create_function_url_config(function_name, self.json_body)
return 201, {}, json.dumps(config.to_dict())
def _delete_function_url_config(self):
def _delete_function_url_config(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.split("/")[-2])
self.backend.delete_function_url_config(function_name)
return 204, {}, "{}"
def _get_function_url_config(self):
def _get_function_url_config(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.split("/")[-2])
config = self.backend.get_function_url_config(function_name)
return 201, {}, json.dumps(config.to_dict())
def _update_function_url_config(self):
def _update_function_url_config(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.split("/")[-2])
config = self.backend.update_function_url_config(function_name, self.json_body)
return 200, {}, json.dumps(config.to_dict())
def _create_event_source_mapping(self):
def _create_event_source_mapping(self) -> TYPE_RESPONSE:
fn = self.backend.create_event_source_mapping(self.json_body)
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _list_event_source_mappings(self, event_source_arn, function_name):
def _list_event_source_mappings(
self, event_source_arn: str, function_name: str
) -> TYPE_RESPONSE:
esms = self.backend.list_event_source_mappings(event_source_arn, function_name)
result = {"EventSourceMappings": [esm.get_configuration() for esm in esms]}
return 200, {}, json.dumps(result)
def _get_event_source_mapping(self, uuid):
def _get_event_source_mapping(self, uuid: str) -> TYPE_RESPONSE:
result = self.backend.get_event_source_mapping(uuid)
if result:
return 200, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"
def _update_event_source_mapping(self, uuid):
def _update_event_source_mapping(self, uuid: str) -> TYPE_RESPONSE:
result = self.backend.update_event_source_mapping(uuid, self.json_body)
if result:
return 202, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"
def _delete_event_source_mapping(self, uuid):
def _delete_event_source_mapping(self, uuid: str) -> TYPE_RESPONSE:
esm = self.backend.delete_event_source_mapping(uuid)
if esm:
json_result = esm.get_configuration()
@ -351,15 +354,15 @@ class LambdaResponse(BaseResponse):
else:
return 404, {}, "{}"
def _publish_function(self):
def _publish_function(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.split("/")[-2])
description = self._get_param("Description")
fn = self.backend.publish_function(function_name, description)
config = fn.get_configuration()
config = fn.get_configuration() # type: ignore[union-attr]
return 201, {}, json.dumps(config)
def _delete_function(self):
def _delete_function(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param("Qualifier", None)
@ -367,14 +370,14 @@ class LambdaResponse(BaseResponse):
return 204, {}, ""
@staticmethod
def _set_configuration_qualifier(configuration, qualifier):
def _set_configuration_qualifier(configuration: Dict[str, Any], qualifier: str) -> Dict[str, Any]: # type: ignore[misc]
if qualifier is None or qualifier == "$LATEST":
configuration["Version"] = "$LATEST"
if qualifier == "$LATEST":
configuration["FunctionArn"] += ":$LATEST"
return configuration
def _get_function(self):
def _get_function(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param("Qualifier", None)
@ -386,7 +389,7 @@ class LambdaResponse(BaseResponse):
)
return 200, {}, json.dumps(code)
def _get_function_configuration(self):
def _get_function_configuration(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/", 2)[-2])
qualifier = self._get_param("Qualifier", None)
@ -397,33 +400,33 @@ class LambdaResponse(BaseResponse):
)
return 200, {}, json.dumps(configuration)
def _get_aws_region(self, full_url):
def _get_aws_region(self, full_url: str) -> str:
region = self.region_regex.search(full_url)
if region:
return region.group(1)
else:
return self.default_region
def _list_tags(self):
def _list_tags(self) -> TYPE_RESPONSE:
function_arn = unquote(self.path.rsplit("/", 1)[-1])
tags = self.backend.list_tags(function_arn)
return 200, {}, json.dumps({"Tags": tags})
def _tag_resource(self):
def _tag_resource(self) -> TYPE_RESPONSE:
function_arn = unquote(self.path.rsplit("/", 1)[-1])
self.backend.tag_resource(function_arn, self.json_body["Tags"])
return 200, {}, "{}"
def _untag_resource(self):
def _untag_resource(self) -> TYPE_RESPONSE:
function_arn = unquote(self.path.rsplit("/", 1)[-1])
tag_keys = self.querystring["tagKeys"]
self.backend.untag_resource(function_arn, tag_keys)
return 204, {}, "{}"
def _put_configuration(self):
def _put_configuration(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/", 2)[-2])
qualifier = self._get_param("Qualifier", None)
resp = self.backend.update_function_configuration(
@ -435,7 +438,7 @@ class LambdaResponse(BaseResponse):
else:
return 404, {}, "{}"
def _put_code(self):
def _put_code(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/", 2)[-2])
qualifier = self._get_param("Qualifier", None)
resp = self.backend.update_function_code(
@ -447,12 +450,12 @@ class LambdaResponse(BaseResponse):
else:
return 404, {}, "{}"
def _get_code_signing_config(self):
def _get_code_signing_config(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/", 2)[-2])
resp = self.backend.get_code_signing_config(function_name)
return 200, {}, json.dumps(resp)
def _get_function_concurrency(self):
def _get_function_concurrency(self) -> TYPE_RESPONSE:
path_function_name = unquote(self.path.rsplit("/", 2)[-2])
function_name = self.backend.get_function(path_function_name)
@ -462,7 +465,7 @@ class LambdaResponse(BaseResponse):
resp = self.backend.get_function_concurrency(path_function_name)
return 200, {}, json.dumps({"ReservedConcurrentExecutions": resp})
def _delete_function_concurrency(self):
def _delete_function_concurrency(self) -> TYPE_RESPONSE:
path_function_name = unquote(self.path.rsplit("/", 2)[-2])
function_name = self.backend.get_function(path_function_name)
@ -473,7 +476,7 @@ class LambdaResponse(BaseResponse):
return 204, {}, "{}"
def _put_function_concurrency(self):
def _put_function_concurrency(self) -> TYPE_RESPONSE:
path_function_name = unquote(self.path.rsplit("/", 2)[-2])
function = self.backend.get_function(path_function_name)
@ -485,25 +488,25 @@ class LambdaResponse(BaseResponse):
return 200, {}, json.dumps({"ReservedConcurrentExecutions": resp})
def _list_layers(self):
def _list_layers(self) -> TYPE_RESPONSE:
layers = self.backend.list_layers()
return 200, {}, json.dumps({"Layers": layers})
def _delete_layer_version(self):
def _delete_layer_version(self) -> TYPE_RESPONSE:
layer_name = self.path.split("/")[-3]
layer_version = self.path.split("/")[-1]
self.backend.delete_layer_version(layer_name, layer_version)
return 200, {}, "{}"
def _get_layer_version(self):
def _get_layer_version(self) -> TYPE_RESPONSE:
layer_name = self.path.split("/")[-3]
layer_version = self.path.split("/")[-1]
layer = self.backend.get_layer_version(layer_name, layer_version)
return 200, {}, json.dumps(layer.get_layer_version())
def _get_layer_versions(self):
def _get_layer_versions(self) -> TYPE_RESPONSE:
layer_name = self.path.rsplit("/", 2)[-2]
layer_versions = self.backend.get_layer_versions(layer_name)
return (
@ -514,7 +517,7 @@ class LambdaResponse(BaseResponse):
),
)
def _publish_layer_version(self):
def _publish_layer_version(self) -> TYPE_RESPONSE:
spec = self.json_body
if "LayerName" not in spec:
spec["LayerName"] = self.path.rsplit("/", 2)[-2]
@ -522,7 +525,7 @@ class LambdaResponse(BaseResponse):
config = layer_version.get_layer_version()
return 201, {}, json.dumps(config)
def _create_alias(self):
def _create_alias(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/", 2)[-2])
params = json.loads(self.body)
alias_name = params.get("Name")
@ -538,19 +541,19 @@ class LambdaResponse(BaseResponse):
)
return 201, {}, json.dumps(alias.to_json())
def _delete_alias(self):
def _delete_alias(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/")[-3])
alias_name = unquote(self.path.rsplit("/", 2)[-1])
self.backend.delete_alias(name=alias_name, function_name=function_name)
return 201, {}, "{}"
def _get_alias(self):
def _get_alias(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/")[-3])
alias_name = unquote(self.path.rsplit("/", 2)[-1])
alias = self.backend.get_alias(name=alias_name, function_name=function_name)
return 201, {}, json.dumps(alias.to_json())
def _update_alias(self):
def _update_alias(self) -> TYPE_RESPONSE:
function_name = unquote(self.path.rsplit("/")[-3])
alias_name = unquote(self.path.rsplit("/", 2)[-1])
params = json.loads(self.body)

View File

@ -1,11 +1,12 @@
from collections import namedtuple
from functools import partial
from typing import Any, Callable
ARN = namedtuple("ARN", ["region", "account", "function_name", "version"])
LAYER_ARN = namedtuple("LAYER_ARN", ["region", "account", "layer_name", "version"])
def make_arn(resource_type, region, account, name):
def make_arn(resource_type: str, region: str, account: str, name: str) -> str:
return "arn:aws:lambda:{0}:{1}:{2}:{3}".format(region, account, resource_type, name)
@ -13,7 +14,9 @@ make_function_arn = partial(make_arn, "function")
make_layer_arn = partial(make_arn, "layer")
def make_ver_arn(resource_type, region, account, name, version="1"):
def make_ver_arn(
resource_type: str, region: str, account: str, name: str, version: str = "1"
) -> str:
arn = make_arn(resource_type, region, account, name)
return "{0}:{1}".format(arn, version)
@ -22,7 +25,7 @@ make_function_ver_arn = partial(make_ver_arn, "function")
make_layer_ver_arn = partial(make_ver_arn, "layer")
def split_arn(arn_type, arn):
def split_arn(arn_type: Callable[[str, str, str, str], str], arn: str) -> Any:
arn = arn.replace("arn:aws:lambda:", "")
region, account, _, name, version = arn.split(":")

View File

@ -175,14 +175,14 @@ def str_to_rfc_1123_datetime(value):
return datetime.datetime.strptime(value, RFC1123)
def unix_time(dt: datetime.datetime = None) -> datetime.datetime:
def unix_time(dt: datetime.datetime = None) -> int:
dt = dt or datetime.datetime.utcnow()
epoch = datetime.datetime.utcfromtimestamp(0)
delta = dt - epoch
return (delta.days * 86400) + (delta.seconds + (delta.microseconds / 1e6))
def unix_time_millis(dt=None):
def unix_time_millis(dt: datetime = None) -> int:
return unix_time(dt) * 1000.0

View File

@ -1917,7 +1917,7 @@ class IAMBackend(BaseBackend):
return role
raise IAMNotFoundException("Role {0} not found".format(role_name))
def get_role_by_arn(self, arn):
def get_role_by_arn(self, arn: str) -> Role:
for role in self.get_roles():
if role.arn == arn:
return role

View File

@ -68,37 +68,37 @@ def allow_unknown_region():
return os.environ.get("MOTO_ALLOW_NONEXISTENT_REGION", "false").lower() == "true"
def moto_server_port():
def moto_server_port() -> str:
return os.environ.get("MOTO_PORT") or "5000"
@lru_cache()
def moto_server_host():
def moto_server_host() -> str:
if is_docker():
return get_docker_host()
else:
return "http://host.docker.internal"
def moto_lambda_image():
def moto_lambda_image() -> str:
return os.environ.get("MOTO_DOCKER_LAMBDA_IMAGE", "lambci/lambda")
def moto_network_name():
def moto_network_name() -> str:
return os.environ.get("MOTO_DOCKER_NETWORK_NAME")
def moto_network_mode():
def moto_network_mode() -> str:
return os.environ.get("MOTO_DOCKER_NETWORK_MODE")
def test_server_mode_endpoint():
def test_server_mode_endpoint() -> str:
return os.environ.get(
"TEST_SERVER_MODE_ENDPOINT", f"http://localhost:{moto_server_port()}"
)
def is_docker():
def is_docker() -> bool:
path = pathlib.Path("/proc/self/cgroup")
return (
os.path.exists("/.dockerenv")
@ -107,7 +107,7 @@ def is_docker():
)
def get_docker_host():
def get_docker_host() -> str:
try:
cmd = "curl -s --unix-socket /run/docker.sock http://docker/containers/$HOSTNAME/json"
container_info = os.popen(cmd).read()

View File

@ -1,5 +1,6 @@
import functools
import requests.adapters
from typing import Tuple
from moto import settings
@ -8,7 +9,7 @@ _orig_adapter_send = requests.adapters.HTTPAdapter.send
class DockerModel:
def __init__(self):
def __init__(self) -> None:
self.__docker_client = None
@property
@ -36,7 +37,7 @@ class DockerModel:
return self.__docker_client
def parse_image_ref(image_name):
def parse_image_ref(image_name: str) -> Tuple[str, str]:
# podman does not support short container image name out of box - try to make a full name
# See ParseDockerRef() in https://github.com/distribution/distribution/blob/main/reference/normalize.go
parts = image_name.split("/")

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/acm,moto/amp,moto/apigateway,moto/apigatewayv2,moto/applicationautoscaling/,moto/appsync,moto/athena,moto/autoscaling
files= moto/a*
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract