From 56ca48cfdd983711decdb8e175328a33fd84a364 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Tue, 11 Oct 2022 13:16:27 +0000 Subject: [PATCH] TechDebt: Enable MyPy on APIGateway-module (#5549) --- IMPLEMENTATION_COVERAGE.md | 8 +- Makefile | 2 +- docs/docs/services/apigateway.rst | 6 +- moto/apigateway/exceptions.py | 87 +- .../integration_parsers/__init__.py | 7 +- .../integration_parsers/aws_parser.py | 6 +- .../integration_parsers/http_parser.py | 6 +- .../integration_parsers/unknown_parser.py | 7 +- moto/apigateway/models.py | 993 ++++++++++-------- moto/apigateway/responses.py | 168 ++- moto/apigateway/utils.py | 9 +- moto/core/common_models.py | 2 +- moto/core/utils.py | 2 +- moto/utilities/utils.py | 4 +- setup.cfg | 3 +- tests/test_apigateway/test_apigateway.py | 30 +- 16 files changed, 718 insertions(+), 622 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index 1f38f27c0..9bb99f522 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -49,7 +49,7 @@ ## apigateway
-65% implemented +62% implemented - [X] create_api_key - [X] create_authorizer @@ -157,12 +157,12 @@ - [ ] update_deployment - [ ] update_documentation_part - [ ] update_documentation_version -- [X] update_domain_name +- [ ] update_domain_name - [ ] update_gateway_response - [ ] update_integration - [ ] update_integration_response -- [X] update_method -- [X] update_method_response +- [ ] update_method +- [ ] update_method_response - [ ] update_model - [X] update_request_validator - [ ] update_resource diff --git a/Makefile b/Makefile index 7adc3c1a5..11a3cd6f6 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ lint: @echo "Running pylint..." pylint -j 0 moto tests @echo "Running MyPy..." - mypy --install-types --non-interactive moto/acm moto/amp moto/applicationautoscaling/ + mypy --install-types --non-interactive format: black moto/ tests/ diff --git a/docs/docs/services/apigateway.rst b/docs/docs/services/apigateway.rst index a78ecf0db..e539a9ed7 100644 --- a/docs/docs/services/apigateway.rst +++ b/docs/docs/services/apigateway.rst @@ -149,12 +149,12 @@ apigateway - [ ] update_deployment - [ ] update_documentation_part - [ ] update_documentation_version -- [X] update_domain_name +- [ ] update_domain_name - [ ] update_gateway_response - [ ] update_integration - [ ] update_integration_response -- [X] update_method -- [X] update_method_response +- [ ] update_method +- [ ] update_method_response - [ ] update_model - [X] update_request_validator - [ ] update_resource diff --git a/moto/apigateway/exceptions.py b/moto/apigateway/exceptions.py index d8373bb8d..759366cfc 100644 --- a/moto/apigateway/exceptions.py +++ b/moto/apigateway/exceptions.py @@ -1,4 +1,5 @@ from moto.core.exceptions import JsonRESTError +from typing import Any class ApiGatewayException(JsonRESTError): @@ -6,12 +7,12 @@ class ApiGatewayException(JsonRESTError): class BadRequestException(ApiGatewayException): - def __init__(self, message): + def __init__(self, message: str): super().__init__("BadRequestException", message) class NotFoundException(ApiGatewayException): - def __init__(self, message): + def __init__(self, message: str): super().__init__("NotFoundException", message) @@ -22,199 +23,199 @@ class AccessDeniedException(ApiGatewayException): class ConflictException(ApiGatewayException): code = 409 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ConflictException", message) class AwsProxyNotAllowed(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__( "Integrations of type 'AWS_PROXY' currently only supports Lambda function and Firehose stream invocations." ) class CrossAccountNotAllowed(AccessDeniedException): - def __init__(self): + def __init__(self) -> None: super().__init__( "AccessDeniedException", "Cross-account pass role is not allowed." ) class RoleNotSpecified(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__("Role ARN must be specified for AWS integrations") class IntegrationMethodNotDefined(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__("Enumeration value for HttpMethod must be non-empty") class InvalidOpenAPIDocumentException(BadRequestException): - def __init__(self, cause): + def __init__(self, cause: Any): super().__init__( f"Failed to parse the uploaded OpenAPI document due to: {cause.message}" ) class InvalidOpenApiDocVersionException(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__("Only OpenAPI 3.x.x are currently supported") class InvalidOpenApiModeException(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__( 'Enumeration value of OpenAPI import mode must be "overwrite" or "merge"', ) class InvalidResourcePathException(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__( "Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace." ) class InvalidHttpEndpoint(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid HTTP endpoint specified for URI") class InvalidArn(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid ARN specified in the request") class InvalidIntegrationArn(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__("AWS ARN for integration must contain path or action") class InvalidRequestInput(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid request input") class NoIntegrationDefined(NotFoundException): - def __init__(self): + def __init__(self) -> None: super().__init__("No integration defined for method") class NoIntegrationResponseDefined(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid Response status code specified") class NoMethodDefined(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__("The REST API doesn't contain any methods") class AuthorizerNotFoundException(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid Authorizer identifier specified") class StageNotFoundException(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid stage identifier specified") class ApiKeyNotFoundException(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid API Key identifier specified") class UsagePlanNotFoundException(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid Usage Plan ID specified") class ApiKeyAlreadyExists(ApiGatewayException): code = 409 - def __init__(self): + def __init__(self) -> None: super().__init__("ConflictException", "API Key already exists") class InvalidDomainName(BadRequestException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("No Domain Name specified") class DomainNameNotFound(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid domain name identifier specified") class InvalidRestApiId(BadRequestException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("No Rest API Id specified") class InvalidModelName(BadRequestException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("No Model Name specified") class RestAPINotFound(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid Rest API Id specified") class RequestValidatorNotFound(BadRequestException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid Request Validator Id specified") class ModelNotFound(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid Model Name specified") class ApiKeyValueMinLength(BadRequestException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("API Key value should be at least 20 characters") class MethodNotFoundException(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid Method identifier specified") class InvalidBasePathException(BadRequestException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__( "API Gateway V1 doesn't support the slash character (/) in base path mappings. " "To create a multi-level base path mapping, use API Gateway V2." @@ -222,64 +223,64 @@ class InvalidBasePathException(BadRequestException): class DeploymentNotFoundException(NotFoundException): - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid Deployment identifier specified") class InvalidRestApiIdForBasePathMappingException(BadRequestException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid REST API identifier specified") class InvalidStageException(BadRequestException): code = 400 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid stage identifier specified") class BasePathConflictException(ConflictException): - def __init__(self): + def __init__(self) -> None: super().__init__("Base path already exists for this domain name") class BasePathNotFoundException(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid base path mapping identifier specified") class ResourceIdNotFoundException(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("Invalid resource identifier specified") class VpcLinkNotFound(NotFoundException): code = 404 - def __init__(self): + def __init__(self) -> None: super().__init__("VPCLink not found") class ValidationException(ApiGatewayException): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ValidationException", message) class StageStillActive(BadRequestException): - def __init__(self): + def __init__(self) -> None: super().__init__( "Active stages pointing to this deployment must be moved or deleted" ) class GatewayResponseNotFound(NotFoundException): - def __init__(self): + def __init__(self) -> None: super().__init__("GatewayResponse not found") diff --git a/moto/apigateway/integration_parsers/__init__.py b/moto/apigateway/integration_parsers/__init__.py index f9fa1db79..a185d4eca 100644 --- a/moto/apigateway/integration_parsers/__init__.py +++ b/moto/apigateway/integration_parsers/__init__.py @@ -1,7 +1,12 @@ import abc +from typing import Tuple, Union +from requests.models import PreparedRequest +from ..models import Integration class IntegrationParser: @abc.abstractmethod - def invoke(self, request, integration): + def invoke( + self, request: PreparedRequest, integration: Integration + ) -> Tuple[int, Union[str, bytes]]: pass diff --git a/moto/apigateway/integration_parsers/aws_parser.py b/moto/apigateway/integration_parsers/aws_parser.py index 2aa644197..50404cd3a 100644 --- a/moto/apigateway/integration_parsers/aws_parser.py +++ b/moto/apigateway/integration_parsers/aws_parser.py @@ -1,10 +1,14 @@ import requests from . import IntegrationParser +from ..models import Integration +from typing import Tuple, Union class TypeAwsParser(IntegrationParser): - def invoke(self, request, integration): + def invoke( + self, request: requests.PreparedRequest, integration: Integration + ) -> Tuple[int, Union[str, bytes]]: # integration.uri = arn:aws:apigateway:{region}:{subdomain.service|service}:path|action/{service_api} # example value = 'arn:aws:apigateway:us-west-2:dynamodb:action/PutItem' try: diff --git a/moto/apigateway/integration_parsers/http_parser.py b/moto/apigateway/integration_parsers/http_parser.py index 574558465..d43bc70fb 100644 --- a/moto/apigateway/integration_parsers/http_parser.py +++ b/moto/apigateway/integration_parsers/http_parser.py @@ -1,6 +1,8 @@ import requests +from typing import Tuple, Union from . import IntegrationParser +from ..models import Integration class TypeHttpParser(IntegrationParser): @@ -8,7 +10,9 @@ class TypeHttpParser(IntegrationParser): Parse invocations to a APIGateway resource with integration type HTTP """ - def invoke(self, request, integration): + def invoke( + self, request: requests.PreparedRequest, integration: Integration + ) -> Tuple[int, Union[str, bytes]]: uri = integration["uri"] requests_func = getattr(requests, integration["httpMethod"].lower()) response = requests_func(uri) diff --git a/moto/apigateway/integration_parsers/unknown_parser.py b/moto/apigateway/integration_parsers/unknown_parser.py index c008daff2..8de33ff26 100644 --- a/moto/apigateway/integration_parsers/unknown_parser.py +++ b/moto/apigateway/integration_parsers/unknown_parser.py @@ -1,4 +1,7 @@ +import requests +from typing import Tuple, Union from . import IntegrationParser +from ..models import Integration class TypeUnknownParser(IntegrationParser): @@ -6,6 +9,8 @@ class TypeUnknownParser(IntegrationParser): Parse invocations to a APIGateway resource with an unknown integration type """ - def invoke(self, request, integration): + def invoke( + self, request: requests.PreparedRequest, integration: Integration + ) -> Tuple[int, Union[str, bytes]]: _type = integration["type"] raise NotImplementedError("The {0} type has not been implemented".format(_type)) diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 544174e7e..259bcdc83 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -2,14 +2,14 @@ from __future__ import absolute_import import string import re +import responses +import requests +import time from collections import defaultdict from copy import copy - from openapi_spec_validator import validate_spec -import time - +from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse -import responses try: from openapi_spec_validator.validation.exceptions import OpenAPIValidationError @@ -19,9 +19,6 @@ except ImportError: from moto.core import BaseBackend, BaseModel, CloudFormationModel from .utils import create_id, to_path from moto.core.utils import path_url, BackendDict -from .integration_parsers.aws_parser import TypeAwsParser -from .integration_parsers.http_parser import TypeHttpParser -from .integration_parsers.unknown_parser import TypeUnknownParser from .exceptions import ( ConflictException, DeploymentNotFoundException, @@ -70,8 +67,8 @@ from moto.moto_api._internal import mock_random as random STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}" -class Deployment(CloudFormationModel, dict): - def __init__(self, deployment_id, name, description=""): +class Deployment(CloudFormationModel, dict): # type: ignore[type-arg] + def __init__(self, deployment_id: str, name: str, description: str = ""): super().__init__() self["id"] = deployment_id self["stageName"] = name @@ -79,34 +76,39 @@ class Deployment(CloudFormationModel, dict): self["createdDate"] = int(time.time()) @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "Deployment" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::ApiGateway::Deployment" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any + ) -> "Deployment": properties = cloudformation_json["Properties"] rest_api_id = properties["RestApiId"] name = properties.get("StageName") desc = properties.get("Description", "") - backend = apigateway_backends[account_id][region_name] + backend: "APIGatewayBackend" = apigateway_backends[account_id][region_name] return backend.create_deployment( function_id=rest_api_id, name=name, description=desc ) -class IntegrationResponse(BaseModel, dict): +class IntegrationResponse(BaseModel, dict): # type: ignore[type-arg] def __init__( self, - status_code, - selection_pattern=None, - response_templates=None, - content_handling=None, + status_code: Union[str, int], + selection_pattern: Optional[str] = None, + response_templates: Optional[Dict[str, Any]] = None, + content_handling: Optional[Any] = None, ): if response_templates is None: # response_templates = {"application/json": None} # Note: removed for compatibility with TF @@ -123,19 +125,19 @@ class IntegrationResponse(BaseModel, dict): self["contentHandling"] = content_handling -class Integration(BaseModel, dict): +class Integration(BaseModel, dict): # type: ignore[type-arg] def __init__( self, - integration_type, - uri, - http_method, - request_templates=None, - passthrough_behavior="WHEN_NO_MATCH", - cache_key_parameters=None, - tls_config=None, - cache_namespace=None, - timeout_in_millis=None, - request_parameters=None, + integration_type: str, + uri: str, + http_method: str, + request_templates: Optional[Dict[str, Any]] = None, + passthrough_behavior: Optional[str] = "WHEN_NO_MATCH", + cache_key_parameters: Optional[str] = None, + tls_config: Optional[str] = None, + cache_namespace: Optional[str] = None, + timeout_in_millis: Optional[str] = None, + request_parameters: Optional[Dict[str, Any]] = None, ): super().__init__() self["type"] = integration_type @@ -154,38 +156,47 @@ class Integration(BaseModel, dict): self["requestParameters"] = request_parameters def create_integration_response( - self, status_code, selection_pattern, response_templates, content_handling - ): - if response_templates == {}: - response_templates = None + self, + status_code: str, + selection_pattern: str, + response_templates: Dict[str, str], + content_handling: str, + ) -> IntegrationResponse: integration_response = IntegrationResponse( - status_code, selection_pattern, response_templates, content_handling + status_code, selection_pattern, response_templates or None, content_handling ) if self.get("integrationResponses") is None: self["integrationResponses"] = {} self["integrationResponses"][status_code] = integration_response return integration_response - def get_integration_response(self, status_code): + def get_integration_response(self, status_code: str) -> IntegrationResponse: result = self.get("integrationResponses", {}).get(status_code) if not result: raise NoIntegrationResponseDefined() return result - def delete_integration_response(self, status_code): + def delete_integration_response(self, status_code: str) -> IntegrationResponse: return self.get("integrationResponses", {}).pop(status_code, None) -class MethodResponse(BaseModel, dict): - def __init__(self, status_code, response_models=None, response_parameters=None): +class MethodResponse(BaseModel, dict): # type: ignore[type-arg] + def __init__( + self, + status_code: str, + response_models: Dict[str, str], + response_parameters: Dict[str, Dict[str, str]], + ): super().__init__() self["statusCode"] = status_code self["responseModels"] = response_models self["responseParameters"] = response_parameters -class Method(CloudFormationModel, dict): - def __init__(self, method_type, authorization_type, **kwargs): +class Method(CloudFormationModel, dict): # type: ignore[type-arg] + def __init__( + self, method_type: str, authorization_type: Optional[str], **kwargs: Any + ): super().__init__() self.update( dict( @@ -204,17 +215,22 @@ class Method(CloudFormationModel, dict): self["methodResponses"] = {} @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "Method" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::ApiGateway::Method" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any + ) -> "Method": properties = cloudformation_json["Properties"] rest_api_id = properties["RestApiId"] resource_id = properties["ResourceId"] @@ -242,23 +258,34 @@ class Method(CloudFormationModel, dict): ) return m - def create_response(self, response_code, response_models, response_parameters): + def create_response( + self, + response_code: str, + response_models: Dict[str, str], + response_parameters: Dict[str, Dict[str, str]], + ) -> MethodResponse: method_response = MethodResponse( response_code, response_models, response_parameters ) self["methodResponses"][response_code] = method_response return method_response - def get_response(self, response_code): + def get_response(self, response_code: str) -> Optional[MethodResponse]: return self["methodResponses"].get(response_code) - def delete_response(self, response_code): + def delete_response(self, response_code: str) -> Optional[MethodResponse]: return self["methodResponses"].pop(response_code, None) class Resource(CloudFormationModel): def __init__( - self, resource_id, account_id, region_name, api_id, path_part, parent_id + self, + resource_id: str, + account_id: str, + region_name: str, + api_id: str, + path_part: str, + parent_id: Optional[str], ): super().__init__() self.id = resource_id @@ -267,13 +294,20 @@ class Resource(CloudFormationModel): self.api_id = api_id self.path_part = path_part self.parent_id = parent_id - self.resource_methods = {} - self.integration_parsers = defaultdict(TypeUnknownParser) + self.resource_methods: Dict[str, Method] = {} + from .integration_parsers import IntegrationParser + from .integration_parsers.aws_parser import TypeAwsParser + from .integration_parsers.http_parser import TypeHttpParser + from .integration_parsers.unknown_parser import TypeUnknownParser + + self.integration_parsers: Dict[str, IntegrationParser] = defaultdict( + TypeUnknownParser + ) self.integration_parsers["HTTP"] = TypeHttpParser() self.integration_parsers["AWS"] = TypeAwsParser() - def to_dict(self): - response = { + def to_dict(self) -> Dict[str, Any]: + response: Dict[str, Any] = { "path": self.get_path(), "id": self.id, } @@ -285,21 +319,26 @@ class Resource(CloudFormationModel): return response @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.id @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "Resource" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::ApiGateway::Resource" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any + ) -> "Resource": properties = cloudformation_json["Properties"] api_id = properties["RestApiId"] parent = properties["ParentId"] @@ -317,10 +356,10 @@ class Resource(CloudFormationModel): function_id=api_id, parent_resource_id=parent, path_part=path ) - def get_path(self): + def get_path(self) -> str: return self.get_parent_path() + self.path_part - def get_parent_path(self): + def get_parent_path(self) -> str: if self.parent_id: backend = apigateway_backends[self.account_id][self.region_name] parent = backend.get_resource(self.api_id, self.parent_id) @@ -331,8 +370,10 @@ class Resource(CloudFormationModel): else: return "" - def get_response(self, request): - integration = self.get_integration(request.method) + def get_response( + self, request: requests.PreparedRequest + ) -> Tuple[int, Union[str, bytes]]: + integration = self.get_integration(str(request.method)) integration_type = integration["type"] status, result = self.integration_parsers[integration_type].invoke( @@ -343,16 +384,16 @@ class Resource(CloudFormationModel): def add_method( self, - method_type, - authorization_type, - api_key_required, - request_parameters=None, - request_models=None, - operation_name=None, - authorizer_id=None, - authorization_scopes=None, - request_validator_id=None, - ): + method_type: str, + authorization_type: Optional[str], + api_key_required: Optional[bool], + request_parameters: Any = None, + request_models: Any = None, + operation_name: Optional[str] = None, + authorizer_id: Optional[str] = None, + authorization_scopes: Any = None, + request_validator_id: Any = None, + ) -> Method: if authorization_scopes and not isinstance(authorization_scopes, list): authorization_scopes = [authorization_scopes] method = Method( @@ -369,28 +410,28 @@ class Resource(CloudFormationModel): self.resource_methods[method_type] = method return method - def get_method(self, method_type): + def get_method(self, method_type: str) -> Method: method = self.resource_methods.get(method_type) if not method: raise MethodNotFoundException() return method - def delete_method(self, method_type): + def delete_method(self, method_type: str) -> None: self.resource_methods.pop(method_type, None) def add_integration( self, - method_type, - integration_type, - uri, - request_templates=None, - passthrough_behavior=None, - integration_method=None, - tls_config=None, - cache_namespace=None, - timeout_in_millis=None, - request_parameters=None, - ): + method_type: str, + integration_type: str, + uri: str, + request_templates: Optional[Dict[str, Any]] = None, + passthrough_behavior: Optional[str] = None, + integration_method: Optional[str] = None, + tls_config: Optional[str] = None, + cache_namespace: Optional[str] = None, + timeout_in_millis: Optional[str] = None, + request_parameters: Optional[Dict[str, Any]] = None, + ) -> Integration: integration_method = integration_method or method_type integration = Integration( integration_type, @@ -406,15 +447,24 @@ class Resource(CloudFormationModel): self.resource_methods[method_type]["methodIntegration"] = integration return integration - def get_integration(self, method_type): - return self.resource_methods.get(method_type, {}).get("methodIntegration", {}) + def get_integration(self, method_type: str) -> Integration: + method: Dict[str, Integration] = dict( + self.resource_methods.get(method_type, {}) + ) + return method.get("methodIntegration") or {} # type: ignore[return-value] - def delete_integration(self, method_type): + def delete_integration(self, method_type: str) -> Integration: return self.resource_methods[method_type].pop("methodIntegration") -class Authorizer(BaseModel, dict): - def __init__(self, authorizer_id, name, authorizer_type, **kwargs): +class Authorizer(BaseModel, dict): # type: ignore[type-arg] + def __init__( + self, + authorizer_id: Optional[str], + name: Optional[str], + authorizer_type: Optional[str], + **kwargs: Any + ): super().__init__() self["id"] = authorizer_id self["name"] = name @@ -435,7 +485,7 @@ class Authorizer(BaseModel, dict): ) self["authorizerResultTtlInSeconds"] = kwargs.get("authorizer_result_ttl") - def apply_operations(self, patch_operations): + def apply_operations(self, patch_operations: List[Dict[str, Any]]) -> "Authorizer": for op in patch_operations: if "/authorizerUri" in op["path"]: self["authorizerUri"] = op["value"] @@ -461,25 +511,23 @@ class Authorizer(BaseModel, dict): return self -class Stage(BaseModel, dict): +class Stage(BaseModel, dict): # type: ignore[type-arg] def __init__( self, - name=None, - deployment_id=None, - variables=None, - description="", - cacheClusterEnabled=False, - cacheClusterSize=None, - tags=None, - tracing_enabled=None, + name: Optional[str] = None, + deployment_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + description: str = "", + cacheClusterEnabled: Optional[bool] = False, + cacheClusterSize: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + tracing_enabled: Optional[bool] = None, ): super().__init__() - if variables is None: - variables = {} self["stageName"] = name self["deploymentId"] = deployment_id self["methodSettings"] = {} - self["variables"] = variables + self["variables"] = variables or {} self["description"] = description self["cacheClusterEnabled"] = cacheClusterEnabled if self["cacheClusterEnabled"]: @@ -492,7 +540,7 @@ class Stage(BaseModel, dict): if tracing_enabled is not None: self["tracingEnabled"] = tracing_enabled - def apply_operations(self, patch_operations): + def apply_operations(self, patch_operations: List[Dict[str, Any]]) -> "Stage": for op in patch_operations: if "variables/" in op["path"]: self._apply_operation_to_variables(op) @@ -533,7 +581,9 @@ class Stage(BaseModel, dict): ) return self - def _patch_method_setting(self, resource_path_and_method, key, value): + def _patch_method_setting( + self, resource_path_and_method: str, key: str, value: str + ) -> None: updated_key = self._method_settings_translations(key) if updated_key is not None: if resource_path_and_method not in self["methodSettings"]: @@ -544,7 +594,7 @@ class Stage(BaseModel, dict): updated_key ] = self._convert_to_type(updated_key, value) - def _get_default_method_settings(self): + def _get_default_method_settings(self) -> Dict[str, Any]: return { "throttlingRateLimit": 1000.0, "dataTraceEnabled": False, @@ -557,7 +607,7 @@ class Stage(BaseModel, dict): "requireAuthorizationForCacheControl": True, } - def _method_settings_translations(self, key): + def _method_settings_translations(self, key: str) -> Optional[str]: mappings = { "metrics/enabled": "metricsEnabled", "logging/loglevel": "loggingLevel", @@ -573,10 +623,10 @@ class Stage(BaseModel, dict): return mappings.get(key) - def _str2bool(self, v): + def _str2bool(self, v: str) -> bool: return v.lower() == "true" - def _convert_to_type(self, key, val): + def _convert_to_type(self, key: str, val: str) -> Union[str, int, float]: type_mappings = { "metricsEnabled": "bool", "loggingLevel": "str", @@ -604,7 +654,7 @@ class Stage(BaseModel, dict): else: return str(val) - def _apply_operation_to_variables(self, op): + def _apply_operation_to_variables(self, op: Dict[str, Any]) -> None: key = op["path"][op["path"].rindex("variables/") + 10 :] if op["op"] == "remove": self["variables"].pop(key, None) @@ -614,17 +664,17 @@ class Stage(BaseModel, dict): raise Exception('Patch operation "%s" not implemented' % op["op"]) -class ApiKey(BaseModel, dict): +class ApiKey(BaseModel, dict): # type: ignore[type-arg] def __init__( self, - name=None, - description=None, - enabled=False, - generateDistinctId=False, # pylint: disable=unused-argument - value=None, - stageKeys=None, - tags=None, - customerId=None, + name: Optional[str] = None, + description: Optional[str] = None, + enabled: bool = False, + generateDistinctId: bool = False, # pylint: disable=unused-argument + value: Optional[str] = None, + stageKeys: Optional[Any] = None, + tags: Optional[List[Dict[str, str]]] = None, + customerId: Optional[str] = None, ): super().__init__() self["id"] = create_id() @@ -639,7 +689,7 @@ class ApiKey(BaseModel, dict): self["stageKeys"] = stageKeys or [] self["tags"] = tags - def update_operations(self, patch_operations): + def update_operations(self, patch_operations: List[Dict[str, Any]]) -> "ApiKey": for op in patch_operations: if op["op"] == "replace": if "/name" in op["path"]: @@ -654,20 +704,20 @@ class ApiKey(BaseModel, dict): raise Exception('Patch operation "%s" not implemented' % op["op"]) return self - def _str2bool(self, v): + def _str2bool(self, v: str) -> bool: return v.lower() == "true" -class UsagePlan(BaseModel, dict): +class UsagePlan(BaseModel, dict): # type: ignore[type-arg] def __init__( self, - name=None, - description=None, - apiStages=None, - throttle=None, - quota=None, - productCode=None, - tags=None, + name: Optional[str] = None, + description: Optional[str] = None, + apiStages: Any = None, + throttle: Optional[str] = None, + quota: Optional[str] = None, + productCode: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, ): super().__init__() self["id"] = create_id() @@ -679,7 +729,7 @@ class UsagePlan(BaseModel, dict): self["productCode"] = productCode self["tags"] = tags - def apply_patch_operations(self, patch_operations): + def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None: for op in patch_operations: path = op["path"] value = op["value"] @@ -700,7 +750,7 @@ class UsagePlan(BaseModel, dict): self["throttle"]["burstLimit"] = value -class RequestValidator(BaseModel, dict): +class RequestValidator(BaseModel, dict): # type: ignore[type-arg] PROP_ID = "id" PROP_NAME = "name" PROP_VALIDATE_REQUEST_BODY = "validateRequestBody" @@ -712,7 +762,13 @@ class RequestValidator(BaseModel, dict): OP_REPLACE = "replace" OP_OP = "op" - def __init__(self, _id, name, validateRequestBody, validateRequestParameters): + def __init__( + self, + _id: str, + name: str, + validateRequestBody: Optional[bool], + validateRequestParameters: Any, + ): super().__init__() self[RequestValidator.PROP_ID] = _id self[RequestValidator.PROP_NAME] = name @@ -721,7 +777,7 @@ class RequestValidator(BaseModel, dict): RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS ] = validateRequestParameters - def apply_patch_operations(self, operations): + def apply_patch_operations(self, operations: List[Dict[str, Any]]) -> None: for operation in operations: path = operation[RequestValidator.OP_PATH] value = operation[RequestValidator.OP_VALUE] @@ -737,7 +793,7 @@ class RequestValidator(BaseModel, dict): RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS ] = value.lower() in ("true") - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "id": self["id"], "name": self["name"], @@ -746,8 +802,8 @@ class RequestValidator(BaseModel, dict): } -class UsagePlanKey(BaseModel, dict): - def __init__(self, plan_id, plan_type, name, value): +class UsagePlanKey(BaseModel, dict): # type: ignore[type-arg] + def __init__(self, plan_id: Dict[str, Any], plan_type: str, name: str, value: str): super().__init__() self["id"] = plan_id self["name"] = name @@ -755,8 +811,14 @@ class UsagePlanKey(BaseModel, dict): self["value"] = value -class VpcLink(BaseModel, dict): - def __init__(self, name, description, target_arns, tags): +class VpcLink(BaseModel, dict): # type: ignore[type-arg] + def __init__( + self, + name: str, + description: str, + target_arns: List[str], + tags: List[Dict[str, str]], + ): super().__init__() self["id"] = create_id() self["name"] = name @@ -789,7 +851,15 @@ class RestAPI(CloudFormationModel): OPERATION_VALUE = "value" OPERATION_OP = "op" - def __init__(self, api_id, account_id, region_name, name, description, **kwargs): + def __init__( + self, + api_id: str, + account_id: str, + region_name: str, + name: str, + description: str, + **kwargs: Any + ): super().__init__() self.id = api_id self.account_id = account_id @@ -809,19 +879,19 @@ class RestAPI(CloudFormationModel): kwargs.get(RestAPI.PROP_DISABLE_EXECUTE_API_ENDPOINT) or False ) self.minimum_compression_size = kwargs.get("minimum_compression_size") - self.deployments = {} - self.authorizers = {} - self.gateway_responses = {} - self.stages = {} - self.resources = {} - self.models = {} - self.request_validators = {} + self.deployments: Dict[str, Deployment] = {} + self.authorizers: Dict[str, Authorizer] = {} + self.gateway_responses: Dict[str, GatewayResponse] = {} + self.stages: Dict[str, Stage] = {} + self.resources: Dict[str, Resource] = {} + self.models: Dict[str, Model] = {} + self.request_validators: Dict[str, RequestValidator] = {} self.default = self.add_child("/") # Add default child - def __repr__(self): + def __repr__(self) -> str: return str(self.id) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { self.PROP_ID: self.id, self.PROP_NAME: self.name, @@ -837,8 +907,7 @@ class RestAPI(CloudFormationModel): self.PROP_MINIMUM_COMPRESSION_SIZE: self.minimum_compression_size, } - def apply_patch_operations(self, patch_operations): - + def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None: for op in patch_operations: path = op[self.OPERATION_PATH] value = "" @@ -866,10 +935,10 @@ class RestAPI(CloudFormationModel): self.description = "" @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["RootResourceId"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> Any: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "RootResourceId": @@ -880,21 +949,26 @@ class RestAPI(CloudFormationModel): raise UnformattedGetAttTemplateException() @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.id @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "RestApi" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: return "AWS::ApiGateway::RestApi" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Dict[str, Any], + account_id: str, + region_name: str, + **kwargs: Any + ) -> "RestAPI": properties = cloudformation_json["Properties"] name = properties["Name"] desc = properties.get("Description", "") @@ -904,7 +978,7 @@ class RestAPI(CloudFormationModel): name=name, description=desc, endpoint_configuration=config ) - def add_child(self, path, parent_id=None): + def add_child(self, path: str, parent_id: Optional[str] = None) -> Resource: child_id = create_id() child = Resource( resource_id=child_id, @@ -919,13 +993,11 @@ class RestAPI(CloudFormationModel): def add_model( self, - name, - description=None, - schema=None, - content_type=None, - cli_input_json=None, - generate_cli_skeleton=None, - ): + name: str, + description: str, + schema: str, + content_type: str, + ) -> "Model": model_id = create_id() new_model = Model( model_id=model_id, @@ -933,20 +1005,20 @@ class RestAPI(CloudFormationModel): description=description, schema=schema, content_type=content_type, - cli_input_json=cli_input_json, - generate_cli_skeleton=generate_cli_skeleton, ) self.models[name] = new_model return new_model - def get_resource_for_path(self, path_after_stage_name): + def get_resource_for_path(self, path_after_stage_name: str) -> Resource: # type: ignore[return] for resource in self.resources.values(): if resource.get_path() == path_after_stage_name: return resource # TODO deal with no matching resource - def resource_callback(self, request): + def resource_callback( + self, request: Any + ) -> Tuple[int, Dict[str, str], Union[str, bytes]]: path = path_url(request.url) path_after_stage_name = "/" + "/".join(path.split("/")[2:]) @@ -954,7 +1026,7 @@ class RestAPI(CloudFormationModel): status_code, response = resource.get_response(request) return status_code, {}, response - def update_integration_mocks(self, stage_name): + def update_integration_mocks(self, stage_name: str) -> None: stage_url_lower = STAGE_URL.format( api_id=self.id.lower(), region_name=self.region_name, stage_name=stage_name ) @@ -978,17 +1050,17 @@ class RestAPI(CloudFormationModel): def create_authorizer( self, - authorizer_id, - name, - authorizer_type, - provider_arns=None, - auth_type=None, - authorizer_uri=None, - authorizer_credentials=None, - identity_source=None, - identiy_validation_expression=None, - authorizer_result_ttl=None, - ): + authorizer_id: str, + name: str, + authorizer_type: str, + provider_arns: Optional[List[str]], + auth_type: Optional[str], + authorizer_uri: Optional[str], + authorizer_credentials: Optional[str], + identity_source: Optional[str], + identiy_validation_expression: Optional[str], + authorizer_result_ttl: Optional[int], + ) -> Authorizer: authorizer = Authorizer( authorizer_id=authorizer_id, name=name, @@ -1006,15 +1078,15 @@ class RestAPI(CloudFormationModel): def create_stage( self, - name, - deployment_id, - variables=None, - description="", - cacheClusterEnabled=None, - cacheClusterSize=None, - tags=None, - tracing_enabled=None, - ): + name: str, + deployment_id: str, + variables: Any, + description: str, + cacheClusterEnabled: Optional[bool], + cacheClusterSize: Optional[str], + tags: Optional[List[Dict[str, str]]], + tracing_enabled: Optional[bool], + ) -> Stage: if name in self.stages: raise ConflictException("Stage already exists") if variables is None: @@ -1033,7 +1105,9 @@ class RestAPI(CloudFormationModel): self.update_integration_mocks(name) return stage - def create_deployment(self, name, description="", stage_variables=None): + def create_deployment( + self, name: str, description: str, stage_variables: Any = None + ) -> Deployment: if stage_variables is None: stage_variables = {} deployment_id = create_id() @@ -1047,19 +1121,19 @@ class RestAPI(CloudFormationModel): return deployment - def get_deployment(self, deployment_id): + def get_deployment(self, deployment_id: str) -> Deployment: return self.deployments[deployment_id] - def get_authorizers(self): + def get_authorizers(self) -> List[Authorizer]: return list(self.authorizers.values()) - def get_stages(self): + def get_stages(self) -> List[Stage]: return list(self.stages.values()) - def get_deployments(self): + def get_deployments(self) -> List[Deployment]: return list(self.deployments.values()) - def delete_deployment(self, deployment_id): + def delete_deployment(self, deployment_id: str) -> Deployment: if deployment_id not in self.deployments: raise DeploymentNotFoundException() deployment = self.deployments[deployment_id] @@ -1070,8 +1144,11 @@ class RestAPI(CloudFormationModel): return self.deployments.pop(deployment_id) def create_request_validator( - self, name, validateRequestBody, validateRequestParameters - ): + self, + name: str, + validateRequestBody: Optional[bool], + validateRequestParameters: Any, + ) -> RequestValidator: validator_id = create_id() request_validator = RequestValidator( _id=validator_id, @@ -1082,26 +1159,31 @@ class RestAPI(CloudFormationModel): self.request_validators[validator_id] = request_validator return request_validator - def get_request_validators(self): + def get_request_validators(self) -> List[RequestValidator]: return list(self.request_validators.values()) - def get_request_validator(self, validator_id): + def get_request_validator(self, validator_id: str) -> RequestValidator: reqeust_validator = self.request_validators.get(validator_id) if reqeust_validator is None: raise RequestValidatorNotFound() return reqeust_validator - def delete_request_validator(self, validator_id): - reqeust_validator = self.request_validators.pop(validator_id) - return reqeust_validator + def delete_request_validator(self, validator_id: str) -> RequestValidator: + return self.request_validators.pop(validator_id) - def update_request_validator(self, validator_id, patch_operations): + def update_request_validator( + self, validator_id: str, patch_operations: List[Dict[str, Any]] + ) -> RequestValidator: self.request_validators[validator_id].apply_patch_operations(patch_operations) return self.request_validators[validator_id] def put_gateway_response( - self, response_type, status_code, response_parameters, response_templates - ): + self, + response_type: str, + status_code: int, + response_parameters: Dict[str, Any], + response_templates: Dict[str, str], + ) -> "GatewayResponse": response = GatewayResponse( response_type=response_type, status_code=status_code, @@ -1111,20 +1193,20 @@ class RestAPI(CloudFormationModel): self.gateway_responses[response_type] = response return response - def get_gateway_response(self, response_type): + def get_gateway_response(self, response_type: str) -> "GatewayResponse": if response_type not in self.gateway_responses: raise GatewayResponseNotFound() return self.gateway_responses[response_type] - def get_gateway_responses(self): + def get_gateway_responses(self) -> List["GatewayResponse"]: return list(self.gateway_responses.values()) - def delete_gateway_response(self, response_type): + def delete_gateway_response(self, response_type: str) -> None: self.gateway_responses.pop(response_type, None) -class DomainName(BaseModel, dict): - def __init__(self, domain_name, **kwargs): +class DomainName(BaseModel, dict): # type: ignore[type-arg] + def __init__(self, domain_name: str, **kwargs: Any): super().__init__() self["domainName"] = domain_name self["regionalDomainName"] = "d-%s.execute-api.%s.amazonaws.com" % ( @@ -1157,12 +1239,10 @@ class DomainName(BaseModel, dict): self["regionalCertificateArn"] = kwargs.get("regional_certificate_arn") if kwargs.get("endpoint_configuration"): self["endpointConfiguration"] = kwargs.get("endpoint_configuration") - if kwargs.get("generate_cli_skeleton"): - self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton") -class Model(BaseModel, dict): - def __init__(self, model_id, name, **kwargs): +class Model(BaseModel, dict): # type: ignore[type-arg] + def __init__(self, model_id: str, name: str, **kwargs: Any): super().__init__() self["id"] = model_id self["name"] = name @@ -1172,13 +1252,9 @@ class Model(BaseModel, dict): self["schema"] = kwargs.get("schema") if kwargs.get("content_type"): self["contentType"] = kwargs.get("content_type") - if kwargs.get("cli_input_json"): - self["cliInputJson"] = kwargs.get("cli_input_json") - if kwargs.get("generate_cli_skeleton"): - self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton") -class BasePathMapping(BaseModel, dict): +class BasePathMapping(BaseModel, dict): # type: ignore[type-arg] # operations OPERATION_REPLACE = "replace" @@ -1186,7 +1262,7 @@ class BasePathMapping(BaseModel, dict): OPERATION_VALUE = "value" OPERATION_OP = "op" - def __init__(self, domain_name, rest_api_id, **kwargs): + def __init__(self, domain_name: str, rest_api_id: str, **kwargs: Any): super().__init__() self["domain_name"] = domain_name self["restApiId"] = rest_api_id @@ -1197,8 +1273,7 @@ class BasePathMapping(BaseModel, dict): if kwargs.get("stage"): self["stage"] = kwargs.get("stage") - def apply_patch_operations(self, patch_operations): - + def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None: for op in patch_operations: path = op["path"] value = op["value"] @@ -1212,9 +1287,13 @@ class BasePathMapping(BaseModel, dict): self["stage"] = value -class GatewayResponse(BaseModel, dict): +class GatewayResponse(BaseModel, dict): # type: ignore[type-arg] def __init__( - self, response_type, status_code, response_parameters, response_templates + self, + response_type: str, + status_code: int, + response_parameters: Dict[str, Any], + response_templates: Dict[str, str], ): super().__init__() self["responseType"] = response_type @@ -1254,27 +1333,27 @@ class APIGatewayBackend(BaseBackend): - This only works when using the decorators, not in ServerMode """ - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.apis = {} - self.keys = {} - self.usage_plans = {} - self.usage_plan_keys = {} - self.domain_names = {} - self.models = {} - self.base_path_mappings = {} - self.vpc_links = {} + self.apis: Dict[str, RestAPI] = {} + self.keys: Dict[str, ApiKey] = {} + self.usage_plans: Dict[str, UsagePlan] = {} + self.usage_plan_keys: Dict[str, Dict[str, UsagePlanKey]] = {} + self.domain_names: Dict[str, DomainName] = {} + self.models: Dict[str, Model] = {} + self.base_path_mappings: Dict[str, Dict[str, BasePathMapping]] = {} + self.vpc_links: Dict[str, VpcLink] = {} def create_rest_api( self, - name, - description, - api_key_source=None, - endpoint_configuration=None, - tags=None, - policy=None, - minimum_compression_size=None, - ): + name: str, + description: str, + api_key_source: Optional[str] = None, + endpoint_configuration: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + policy: Optional[str] = None, + minimum_compression_size: Optional[int] = None, + ) -> RestAPI: api_id = create_id() rest_api = RestAPI( api_id, @@ -1291,7 +1370,9 @@ class APIGatewayBackend(BaseBackend): self.apis[api_id] = rest_api return rest_api - def import_rest_api(self, api_doc, fail_on_warnings): + def import_rest_api( + self, api_doc: Dict[str, Any], fail_on_warnings: bool + ) -> RestAPI: """ Only a subset of the OpenAPI spec 3.x is currently implemented. """ @@ -1306,13 +1387,19 @@ class APIGatewayBackend(BaseBackend): self.put_rest_api(api.id, api_doc, fail_on_warnings=fail_on_warnings) return api - def get_rest_api(self, function_id): + def get_rest_api(self, function_id: str) -> RestAPI: rest_api = self.apis.get(function_id) if rest_api is None: raise RestAPINotFound() return rest_api - def put_rest_api(self, function_id, api_doc, mode="merge", fail_on_warnings=False): + def put_rest_api( + self, + function_id: str, + api_doc: Dict[str, Any], + fail_on_warnings: bool, + mode: str = "merge", + ) -> RestAPI: """ Only a subset of the OpenAPI spec 3.x is currently implemented. """ @@ -1365,63 +1452,67 @@ class APIGatewayBackend(BaseBackend): return self.get_rest_api(function_id) - def update_rest_api(self, function_id, patch_operations): + def update_rest_api( + self, function_id: str, patch_operations: List[Dict[str, Any]] + ) -> RestAPI: rest_api = self.apis.get(function_id) if rest_api is None: raise RestAPINotFound() self.apis[function_id].apply_patch_operations(patch_operations) return self.apis[function_id] - def list_apis(self): - return self.apis.values() + def list_apis(self) -> List[RestAPI]: + return list(self.apis.values()) - def delete_rest_api(self, function_id): + def delete_rest_api(self, function_id: str) -> RestAPI: rest_api = self.apis.pop(function_id) return rest_api - def get_resources(self, function_id): + def get_resources(self, function_id: str) -> List[Resource]: api = self.get_rest_api(function_id) - return api.resources.values() + return list(api.resources.values()) - def get_resource(self, function_id, resource_id): + def get_resource(self, function_id: str, resource_id: str) -> Resource: api = self.get_rest_api(function_id) if resource_id not in api.resources: raise ResourceIdNotFoundException return api.resources[resource_id] - def create_resource(self, function_id, parent_resource_id, path_part): + def create_resource( + self, function_id: str, parent_resource_id: str, path_part: str + ) -> Resource: api = self.get_rest_api(function_id) if not path_part: # We're attempting to create the default resource, which already exists. return api.default if not re.match("^\\{?[a-zA-Z0-9._-]+\\+?\\}?$", path_part): raise InvalidResourcePathException() - child = api.add_child(path=path_part, parent_id=parent_resource_id) - return child + return api.add_child(path=path_part, parent_id=parent_resource_id) - def delete_resource(self, function_id, resource_id): + def delete_resource(self, function_id: str, resource_id: str) -> Resource: api = self.get_rest_api(function_id) - resource = api.resources.pop(resource_id) - return resource + return api.resources.pop(resource_id) - def get_method(self, function_id, resource_id, method_type): + def get_method( + self, function_id: str, resource_id: str, method_type: str + ) -> Method: resource = self.get_resource(function_id, resource_id) return resource.get_method(method_type) def put_method( self, - function_id, - resource_id, - method_type, - authorization_type, - api_key_required=None, - request_parameters=None, - request_models=None, - operation_name=None, - authorizer_id=None, - authorization_scopes=None, - request_validator_id=None, - ): + function_id: str, + resource_id: str, + method_type: str, + authorization_type: Optional[str], + api_key_required: Optional[bool] = None, + request_parameters: Optional[Dict[str, Any]] = None, + request_models: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + authorizer_id: Optional[str] = None, + authorization_scopes: Optional[str] = None, + request_validator_id: Optional[str] = None, + ) -> Method: resource = self.get_resource(function_id, resource_id) method = resource.add_method( method_type, @@ -1436,16 +1527,13 @@ class APIGatewayBackend(BaseBackend): ) return method - def update_method(self, function_id, resource_id, method_type, patch_operations): - resource = self.get_resource(function_id, resource_id) - method = resource.get_method(method_type) - return method.apply_operations(patch_operations) - - def delete_method(self, function_id, resource_id, method_type): + def delete_method( + self, function_id: str, resource_id: str, method_type: str + ) -> None: resource = self.get_resource(function_id, resource_id) resource.delete_method(method_type) - def get_authorizer(self, restapi_id, authorizer_id): + def get_authorizer(self, restapi_id: str, authorizer_id: str) -> Authorizer: api = self.get_rest_api(restapi_id) authorizer = api.authorizers.get(authorizer_id) if authorizer is None: @@ -1453,14 +1541,16 @@ class APIGatewayBackend(BaseBackend): else: return authorizer - def get_authorizers(self, restapi_id): + def get_authorizers(self, restapi_id: str) -> List[Authorizer]: api = self.get_rest_api(restapi_id) return api.get_authorizers() - def create_authorizer(self, restapi_id, name, authorizer_type, **kwargs): + def create_authorizer( + self, restapi_id: str, name: str, authorizer_type: str, **kwargs: Any + ) -> Authorizer: api = self.get_rest_api(restapi_id) authorizer_id = create_id() - authorizer = api.create_authorizer( + return api.create_authorizer( authorizer_id, name, authorizer_type, @@ -1472,46 +1562,44 @@ class APIGatewayBackend(BaseBackend): identiy_validation_expression=kwargs.get("identiy_validation_expression"), authorizer_result_ttl=kwargs.get("authorizer_result_ttl"), ) - return api.authorizers.get(authorizer["id"]) - def update_authorizer(self, restapi_id, authorizer_id, patch_operations): + def update_authorizer( + self, restapi_id: str, authorizer_id: str, patch_operations: Any + ) -> Authorizer: authorizer = self.get_authorizer(restapi_id, authorizer_id) - if not authorizer: - api = self.get_rest_api(restapi_id) - authorizer = api.authorizers[authorizer_id] = Authorizer() return authorizer.apply_operations(patch_operations) - def delete_authorizer(self, restapi_id, authorizer_id): + def delete_authorizer(self, restapi_id: str, authorizer_id: str) -> None: api = self.get_rest_api(restapi_id) del api.authorizers[authorizer_id] - def get_stage(self, function_id, stage_name) -> Stage: + def get_stage(self, function_id: str, stage_name: str) -> Stage: api = self.get_rest_api(function_id) stage = api.stages.get(stage_name) if stage is None: raise StageNotFoundException() return stage - def get_stages(self, function_id): + def get_stages(self, function_id: str) -> List[Stage]: api = self.get_rest_api(function_id) return api.get_stages() def create_stage( self, - function_id, - stage_name, - deploymentId, - variables=None, - description="", - cacheClusterEnabled=None, - cacheClusterSize=None, - tags=None, - tracing_enabled=None, - ): + function_id: str, + stage_name: str, + deploymentId: str, + variables: Optional[Any] = None, + description: str = "", + cacheClusterEnabled: Optional[bool] = None, + cacheClusterSize: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + tracing_enabled: Optional[bool] = None, + ) -> Stage: if variables is None: variables = {} api = self.get_rest_api(function_id) - api.create_stage( + return api.create_stage( stage_name, deploymentId, variables=variables, @@ -1521,72 +1609,64 @@ class APIGatewayBackend(BaseBackend): tags=tags, tracing_enabled=tracing_enabled, ) - return api.stages.get(stage_name) - def update_stage(self, function_id, stage_name, patch_operations): + def update_stage( + self, function_id: str, stage_name: str, patch_operations: Any + ) -> Stage: stage = self.get_stage(function_id, stage_name) if not stage: api = self.get_rest_api(function_id) stage = api.stages[stage_name] = Stage() return stage.apply_operations(patch_operations) - def delete_stage(self, function_id, stage_name): + def delete_stage(self, function_id: str, stage_name: str) -> None: api = self.get_rest_api(function_id) deleted = api.stages.pop(stage_name, None) if not deleted: raise StageNotFoundException() - def get_method_response(self, function_id, resource_id, method_type, response_code): + def get_method_response( + self, function_id: str, resource_id: str, method_type: str, response_code: str + ) -> Optional[MethodResponse]: method = self.get_method(function_id, resource_id, method_type) - method_response = method.get_response(response_code) - return method_response + return method.get_response(response_code) def put_method_response( self, - function_id, - resource_id, - method_type, - response_code, - response_models, - response_parameters, - ): + function_id: str, + resource_id: str, + method_type: str, + response_code: str, + response_models: Any, + response_parameters: Any, + ) -> MethodResponse: method = self.get_method(function_id, resource_id, method_type) - method_response = method.create_response( + return method.create_response( response_code, response_models, response_parameters ) - return method_response - - def update_method_response( - self, function_id, resource_id, method_type, response_code, patch_operations - ): - method = self.get_method(function_id, resource_id, method_type) - method_response = method.get_response(response_code) - method_response.apply_operations(patch_operations) - return method_response def delete_method_response( - self, function_id, resource_id, method_type, response_code - ): + self, function_id: str, resource_id: str, method_type: str, response_code: str + ) -> Optional[MethodResponse]: method = self.get_method(function_id, resource_id, method_type) - method_response = method.delete_response(response_code) - return method_response + return method.delete_response(response_code) def put_integration( self, - function_id, - resource_id, - method_type, - integration_type, - uri, - integration_method=None, - credentials=None, - request_templates=None, - passthrough_behavior=None, - tls_config=None, - cache_namespace=None, - timeout_in_millis=None, - request_parameters=None, - ): + function_id: str, + resource_id: str, + method_type: str, + integration_type: str, + uri: str, + integration_method: str, + credentials: Optional[str] = None, + request_templates: Optional[Dict[str, Any]] = None, + passthrough_behavior: Optional[str] = None, + tls_config: Optional[str] = None, + cache_namespace: Optional[str] = None, + timeout_in_millis: Optional[str] = None, + request_parameters: Optional[Dict[str, Any]] = None, + ) -> Integration: resource = self.get_resource(function_id, resource_id) if credentials and not re.match( "^arn:aws:iam::" + str(self.account_id), credentials @@ -1631,24 +1711,28 @@ class APIGatewayBackend(BaseBackend): ) return integration - def get_integration(self, function_id, resource_id, method_type): + def get_integration( + self, function_id: str, resource_id: str, method_type: str + ) -> Integration: resource = self.get_resource(function_id, resource_id) return resource.get_integration(method_type) - def delete_integration(self, function_id, resource_id, method_type): + def delete_integration( + self, function_id: str, resource_id: str, method_type: str + ) -> Integration: resource = self.get_resource(function_id, resource_id) return resource.delete_integration(method_type) def put_integration_response( self, - function_id, - resource_id, - method_type, - status_code, - selection_pattern, - response_templates, - content_handling, - ): + function_id: str, + resource_id: str, + method_type: str, + status_code: str, + selection_pattern: str, + response_templates: Dict[str, str], + content_handling: str, + ) -> IntegrationResponse: integration = self.get_integration(function_id, resource_id, method_type) if integration: return integration.create_integration_response( @@ -1657,30 +1741,34 @@ class APIGatewayBackend(BaseBackend): raise NoIntegrationResponseDefined() def get_integration_response( - self, function_id, resource_id, method_type, status_code - ): + self, function_id: str, resource_id: str, method_type: str, status_code: str + ) -> IntegrationResponse: integration = self.get_integration(function_id, resource_id, method_type) integration_response = integration.get_integration_response(status_code) return integration_response def delete_integration_response( - self, function_id, resource_id, method_type, status_code - ): + self, function_id: str, resource_id: str, method_type: str, status_code: str + ) -> IntegrationResponse: integration = self.get_integration(function_id, resource_id, method_type) integration_response = integration.delete_integration_response(status_code) return integration_response def create_deployment( - self, function_id, name, description="", stage_variables=None - ): + self, + function_id: str, + name: str, + description: str = "", + stage_variables: Any = None, + ) -> Deployment: if stage_variables is None: stage_variables = {} api = self.get_rest_api(function_id) - methods = [ + nested_methods = [ list(res.resource_methods.values()) for res in self.get_resources(function_id) ] - methods = [m for sublist in methods for m in sublist] + methods = [m for sublist in nested_methods for m in sublist] if not any(methods): raise NoMethodDefined() method_integrations = [ @@ -1691,19 +1779,19 @@ class APIGatewayBackend(BaseBackend): deployment = api.create_deployment(name, description, stage_variables) return deployment - def get_deployment(self, function_id, deployment_id): + def get_deployment(self, function_id: str, deployment_id: str) -> Deployment: api = self.get_rest_api(function_id) return api.get_deployment(deployment_id) - def get_deployments(self, function_id): + def get_deployments(self, function_id: str) -> List[Deployment]: api = self.get_rest_api(function_id) return api.get_deployments() - def delete_deployment(self, function_id, deployment_id): + def delete_deployment(self, function_id: str, deployment_id: str) -> Deployment: api = self.get_rest_api(function_id) return api.delete_deployment(deployment_id) - def create_api_key(self, payload): + def create_api_key(self, payload: Dict[str, Any]) -> ApiKey: if payload.get("value"): if len(payload.get("value", [])) < 20: raise ApiKeyValueMinLength() @@ -1714,7 +1802,7 @@ class APIGatewayBackend(BaseBackend): self.keys[key["id"]] = key return key - def get_api_keys(self, include_values=False): + def get_api_keys(self, include_values: bool) -> List[ApiKey]: api_keys = list(self.keys.values()) if not include_values: @@ -1727,7 +1815,7 @@ class APIGatewayBackend(BaseBackend): return api_keys - def get_api_key(self, api_key_id, include_value=False): + def get_api_key(self, api_key_id: str, include_value: bool = False) -> ApiKey: api_key = self.keys.get(api_key_id) if not api_key: raise ApiKeyNotFoundException() @@ -1739,45 +1827,45 @@ class APIGatewayBackend(BaseBackend): return api_key - def update_api_key(self, api_key_id, patch_operations): + def update_api_key(self, api_key_id: str, patch_operations: Any) -> ApiKey: key = self.keys[api_key_id] return key.update_operations(patch_operations) - def delete_api_key(self, api_key_id): + def delete_api_key(self, api_key_id: str) -> None: self.keys.pop(api_key_id) - return {} - def create_usage_plan(self, payload): + def create_usage_plan(self, payload: Any) -> UsagePlan: plan = UsagePlan(**payload) self.usage_plans[plan["id"]] = plan return plan - def get_usage_plans(self, api_key_id=None): + def get_usage_plans(self, api_key_id: Optional[str] = None) -> List[UsagePlan]: plans = list(self.usage_plans.values()) if api_key_id is not None: plans = [ plan for plan in plans - if self.usage_plan_keys.get(plan["id"], {}).get(api_key_id, False) + if dict(self.usage_plan_keys.get(plan["id"], {})).get(api_key_id) ] return plans - def get_usage_plan(self, usage_plan_id): + def get_usage_plan(self, usage_plan_id: str) -> UsagePlan: if usage_plan_id not in self.usage_plans: raise UsagePlanNotFoundException() return self.usage_plans[usage_plan_id] - def update_usage_plan(self, usage_plan_id, patch_operations): + def update_usage_plan(self, usage_plan_id: str, patch_operations: Any) -> UsagePlan: if usage_plan_id not in self.usage_plans: raise UsagePlanNotFoundException() self.usage_plans[usage_plan_id].apply_patch_operations(patch_operations) return self.usage_plans[usage_plan_id] - def delete_usage_plan(self, usage_plan_id): + def delete_usage_plan(self, usage_plan_id: str) -> None: self.usage_plans.pop(usage_plan_id) - return {} - def create_usage_plan_key(self, usage_plan_id, payload): + def create_usage_plan_key( + self, usage_plan_id: str, payload: Dict[str, Any] + ) -> UsagePlanKey: if usage_plan_id not in self.usage_plan_keys: self.usage_plan_keys[usage_plan_id] = {} @@ -1796,13 +1884,13 @@ class APIGatewayBackend(BaseBackend): self.usage_plan_keys[usage_plan_id][usage_plan_key["id"]] = usage_plan_key return usage_plan_key - def get_usage_plan_keys(self, usage_plan_id): + def get_usage_plan_keys(self, usage_plan_id: str) -> List[UsagePlanKey]: if usage_plan_id not in self.usage_plan_keys: return [] return list(self.usage_plan_keys[usage_plan_id].values()) - def get_usage_plan_key(self, usage_plan_id, key_id): + def get_usage_plan_key(self, usage_plan_id: str, key_id: str) -> UsagePlanKey: # first check if is a valid api key if key_id not in self.keys: raise ApiKeyNotFoundException() @@ -1816,11 +1904,10 @@ class APIGatewayBackend(BaseBackend): return self.usage_plan_keys[usage_plan_id][key_id] - def delete_usage_plan_key(self, usage_plan_id, key_id): + def delete_usage_plan_key(self, usage_plan_id: str, key_id: str) -> None: self.usage_plan_keys[usage_plan_id].pop(key_id) - return {} - def _uri_validator(self, uri): + def _uri_validator(self, uri: str) -> bool: try: result = urlparse(uri) return all([result.scheme, result.netloc, result.path or "/"]) @@ -1829,20 +1916,18 @@ class APIGatewayBackend(BaseBackend): def create_domain_name( self, - domain_name, - certificate_name=None, - tags=None, - certificate_arn=None, - certificate_body=None, - certificate_private_key=None, - certificate_chain=None, - regional_certificate_name=None, - regional_certificate_arn=None, - endpoint_configuration=None, - security_policy=None, - generate_cli_skeleton=None, - ): - + domain_name: str, + certificate_name: str, + tags: List[Dict[str, str]], + certificate_arn: str, + certificate_body: str, + certificate_private_key: str, + certificate_chain: str, + regional_certificate_name: str, + regional_certificate_arn: str, + endpoint_configuration: Any, + security_policy: str, + ) -> DomainName: if not domain_name: raise InvalidDomainName() @@ -1858,46 +1943,35 @@ class APIGatewayBackend(BaseBackend): endpoint_configuration=endpoint_configuration, tags=tags, security_policy=security_policy, - generate_cli_skeleton=generate_cli_skeleton, region_name=self.region_name, ) self.domain_names[domain_name] = new_domain_name return new_domain_name - def get_domain_names(self): + def get_domain_names(self) -> List[DomainName]: return list(self.domain_names.values()) - def get_domain_name(self, domain_name): + def get_domain_name(self, domain_name: str) -> DomainName: domain_info = self.domain_names.get(domain_name) if domain_info is None: raise DomainNameNotFound() else: - return self.domain_names[domain_name] + return domain_info - def delete_domain_name(self, domain_name): + def delete_domain_name(self, domain_name: str) -> None: domain_info = self.domain_names.pop(domain_name, None) if domain_info is None: raise DomainNameNotFound() - def update_domain_name(self, domain_name, patch_operations): - domain_info = self.domain_names.get(domain_name) - if not domain_info: - raise DomainNameNotFound() - domain_info.apply_patch_operations(patch_operations) - return domain_info - def create_model( self, - rest_api_id, - name, - content_type, - description=None, - schema=None, - cli_input_json=None, - generate_cli_skeleton=None, - ): - + rest_api_id: str, + name: str, + content_type: str, + description: str, + schema: str, + ) -> Model: if not rest_api_id: raise InvalidRestApiId if not name: @@ -1909,20 +1983,18 @@ class APIGatewayBackend(BaseBackend): description=description, schema=schema, content_type=content_type, - cli_input_json=cli_input_json, - generate_cli_skeleton=generate_cli_skeleton, ) return new_model - def get_models(self, rest_api_id): + def get_models(self, rest_api_id: str) -> List[Model]: if not rest_api_id: raise InvalidRestApiId api = self.get_rest_api(rest_api_id) models = api.models.values() return list(models) - def get_model(self, rest_api_id, model_name): + def get_model(self, rest_api_id: str, model_name: str) -> Model: if not rest_api_id: raise InvalidRestApiId api = self.get_rest_api(rest_api_id) @@ -1932,31 +2004,37 @@ class APIGatewayBackend(BaseBackend): else: return model - def get_request_validators(self, restapi_id): + def get_request_validators(self, restapi_id: str) -> List[RequestValidator]: restApi = self.get_rest_api(restapi_id) return restApi.get_request_validators() - def create_request_validator(self, restapi_id, name, body, params): + def create_request_validator( + self, restapi_id: str, name: str, body: Optional[bool], params: Any + ) -> RequestValidator: restApi = self.get_rest_api(restapi_id) return restApi.create_request_validator( name=name, validateRequestBody=body, validateRequestParameters=params ) - def get_request_validator(self, restapi_id, validator_id): + def get_request_validator( + self, restapi_id: str, validator_id: str + ) -> RequestValidator: restApi = self.get_rest_api(restapi_id) return restApi.get_request_validator(validator_id) - def delete_request_validator(self, restapi_id, validator_id): + def delete_request_validator(self, restapi_id: str, validator_id: str) -> None: restApi = self.get_rest_api(restapi_id) restApi.delete_request_validator(validator_id) - def update_request_validator(self, restapi_id, validator_id, patch_operations): + def update_request_validator( + self, restapi_id: str, validator_id: str, patch_operations: Any + ) -> RequestValidator: restApi = self.get_rest_api(restapi_id) return restApi.update_request_validator(validator_id, patch_operations) def create_base_path_mapping( - self, domain_name, rest_api_id, base_path=None, stage=None - ): + self, domain_name: str, rest_api_id: str, base_path: str, stage: str + ) -> BasePathMapping: if domain_name not in self.domain_names: raise DomainNameNotFound() @@ -1981,22 +2059,22 @@ class APIGatewayBackend(BaseBackend): self.base_path_mappings[domain_name] = {} else: if ( - self.base_path_mappings[domain_name].get(new_base_path) + self.base_path_mappings[domain_name].get(new_base_path) # type: ignore[arg-type] and new_base_path != "(none)" ): raise BasePathConflictException() - self.base_path_mappings[domain_name][new_base_path] = new_base_path_mapping + self.base_path_mappings[domain_name][new_base_path] = new_base_path_mapping # type: ignore[index] return new_base_path_mapping - def get_base_path_mappings(self, domain_name): - + def get_base_path_mappings(self, domain_name: str) -> List[BasePathMapping]: if domain_name not in self.domain_names: raise DomainNameNotFound() return list(self.base_path_mappings[domain_name].values()) - def get_base_path_mapping(self, domain_name, base_path): - + def get_base_path_mapping( + self, domain_name: str, base_path: str + ) -> BasePathMapping: if domain_name not in self.domain_names: raise DomainNameNotFound() @@ -2005,8 +2083,7 @@ class APIGatewayBackend(BaseBackend): return self.base_path_mappings[domain_name][base_path] - def delete_base_path_mapping(self, domain_name, base_path): - + def delete_base_path_mapping(self, domain_name: str, base_path: str) -> None: if domain_name not in self.domain_names: raise DomainNameNotFound() @@ -2015,7 +2092,9 @@ class APIGatewayBackend(BaseBackend): self.base_path_mappings[domain_name].pop(base_path) - def update_base_path_mapping(self, domain_name, base_path, patch_operations): + def update_base_path_mapping( + self, domain_name: str, base_path: str, patch_operations: Any + ) -> BasePathMapping: if domain_name not in self.domain_names: raise DomainNameNotFound() @@ -2061,22 +2140,28 @@ class APIGatewayBackend(BaseBackend): return base_path_mapping - def create_vpc_link(self, name, description, target_arns, tags): + def create_vpc_link( + self, + name: str, + description: str, + target_arns: List[str], + tags: List[Dict[str, str]], + ) -> VpcLink: vpc_link = VpcLink( name, description=description, target_arns=target_arns, tags=tags ) self.vpc_links[vpc_link["id"]] = vpc_link return vpc_link - def delete_vpc_link(self, vpc_link_id): + def delete_vpc_link(self, vpc_link_id: str) -> None: self.vpc_links.pop(vpc_link_id, None) - def get_vpc_link(self, vpc_link_id): + def get_vpc_link(self, vpc_link_id: str) -> VpcLink: if vpc_link_id not in self.vpc_links: raise VpcLinkNotFound return self.vpc_links[vpc_link_id] - def get_vpc_links(self): + def get_vpc_links(self) -> List[VpcLink]: """ Pagination has not yet been implemented """ @@ -2084,12 +2169,12 @@ class APIGatewayBackend(BaseBackend): def put_gateway_response( self, - rest_api_id, - response_type, - status_code, - response_parameters, - response_templates, - ): + rest_api_id: str, + response_type: str, + status_code: int, + response_parameters: Any, + response_templates: Any, + ) -> GatewayResponse: api = self.get_rest_api(rest_api_id) response = api.put_gateway_response( response_type, @@ -2099,18 +2184,20 @@ class APIGatewayBackend(BaseBackend): ) return response - def get_gateway_response(self, rest_api_id, response_type): + def get_gateway_response( + self, rest_api_id: str, response_type: str + ) -> GatewayResponse: api = self.get_rest_api(rest_api_id) return api.get_gateway_response(response_type) - def get_gateway_responses(self, rest_api_id): + def get_gateway_responses(self, rest_api_id: str) -> List[GatewayResponse]: """ Pagination is not yet implemented """ api = self.get_rest_api(rest_api_id) return api.get_gateway_responses() - def delete_gateway_response(self, rest_api_id, response_type): + def delete_gateway_response(self, rest_api_id: str, response_type: str) -> None: api = self.get_rest_api(rest_api_id) api.delete_gateway_response(response_type) diff --git a/moto/apigateway/responses.py b/moto/apigateway/responses.py index 5ddd7b840..4595ef1f7 100644 --- a/moto/apigateway/responses.py +++ b/moto/apigateway/responses.py @@ -1,9 +1,10 @@ import json +from typing import Any, Dict, List, Tuple from urllib.parse import unquote from moto.utilities.utils import merge_multiple_dicts from moto.core.responses import BaseResponse -from .models import apigateway_backends +from .models import apigateway_backends, APIGatewayBackend from .utils import deserialize_body from .exceptions import InvalidRequestInput @@ -11,21 +12,23 @@ API_KEY_SOURCES = ["AUTHORIZER", "HEADER"] AUTHORIZER_TYPES = ["TOKEN", "REQUEST", "COGNITO_USER_POOLS"] ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"] +RESPONSE_TYPE = Tuple[int, Dict[str, str], str] + class APIGatewayResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="apigateway") - def error(self, type_, message, status=400): + def error(self, type_: str, message: str, status: int = 400) -> RESPONSE_TYPE: headers = self.response_headers or {} headers["X-Amzn-Errortype"] = type_ return (status, headers, json.dumps({"__type": type_, "message": message})) @property - def backend(self): + def backend(self) -> APIGatewayBackend: return apigateway_backends[self.current_account][self.region] - def __validate_api_key_source(self, api_key_source): + def __validate_api_key_source(self, api_key_source: str) -> RESPONSE_TYPE: # type: ignore[return] if api_key_source and api_key_source not in API_KEY_SOURCES: return self.error( "ValidationException", @@ -37,7 +40,7 @@ class APIGatewayResponse(BaseResponse): ).format(api_key_source=api_key_source), ) - def __validate_endpoint_configuration(self, endpoint_configuration): + def __validate_endpoint_configuration(self, endpoint_configuration: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] if endpoint_configuration and "types" in endpoint_configuration: invalid_types = list( set(endpoint_configuration["types"]) - set(ENDPOINT_CONFIGURATION_TYPES) @@ -53,7 +56,7 @@ class APIGatewayResponse(BaseResponse): ).format(endpoint_type=invalid_types[0]), ) - def restapis(self, request, full_url, headers): + def restapis(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -62,7 +65,7 @@ class APIGatewayResponse(BaseResponse): elif self.method == "POST": api_doc = deserialize_body(self.body) if api_doc: - fail_on_warnings = self._get_bool_param("failonwarnings") + fail_on_warnings = self._get_bool_param("failonwarnings") or False rest_api = self.backend.import_rest_api(api_doc, fail_on_warnings) return 200, {}, json.dumps(rest_api.to_dict()) @@ -97,14 +100,16 @@ class APIGatewayResponse(BaseResponse): return 200, {}, json.dumps(rest_api.to_dict()) - def __validte_rest_patch_operations(self, patch_operations): + def __validte_rest_patch_operations(self, patch_operations: List[Dict[str, str]]) -> RESPONSE_TYPE: # type: ignore[return] for op in patch_operations: path = op["path"] if "apiKeySource" in path: value = op["value"] return self.__validate_api_key_source(value) - def restapis_individual(self, request, full_url, headers): + def restapis_individual( + self, request: Any, full_url: str, headers: Dict[str, str] + ) -> RESPONSE_TYPE: self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] @@ -114,12 +119,12 @@ class APIGatewayResponse(BaseResponse): rest_api = self.backend.delete_rest_api(function_id) elif self.method == "PUT": mode = self._get_param("mode", "merge") - fail_on_warnings = self._get_bool_param("failonwarnings", False) + fail_on_warnings = self._get_bool_param("failonwarnings") or False api_doc = deserialize_body(self.body) rest_api = self.backend.put_rest_api( - function_id, api_doc, mode, fail_on_warnings + function_id, api_doc, mode=mode, fail_on_warnings=fail_on_warnings ) elif self.method == "PATCH": patch_operations = self._get_param("patchOperations") @@ -130,7 +135,7 @@ class APIGatewayResponse(BaseResponse): return 200, {}, json.dumps(rest_api.to_dict()) - def resources(self, request, full_url, headers): + def resources(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] @@ -142,7 +147,7 @@ class APIGatewayResponse(BaseResponse): json.dumps({"item": [resource.to_dict() for resource in resources]}), ) - def gateway_response(self, request, full_url, headers): + def gateway_response(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "PUT": return self.put_gateway_response() @@ -151,12 +156,12 @@ class APIGatewayResponse(BaseResponse): elif request.method == "DELETE": return self.delete_gateway_response() - def gateway_responses(self, request, full_url, headers): + def gateway_responses(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) if request.method == "GET": return self.get_gateway_responses() - def resource_individual(self, request, full_url, headers): + def resource_individual(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] resource_id = self.path.split("/")[-1] @@ -172,7 +177,9 @@ class APIGatewayResponse(BaseResponse): resource = self.backend.delete_resource(function_id, resource_id) return 202, {}, json.dumps(resource.to_dict()) - def resource_methods(self, request, full_url, headers): + def resource_methods( + self, request: Any, full_url: str, headers: Dict[str, str] + ) -> RESPONSE_TYPE: self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") function_id = url_path_parts[2] @@ -210,15 +217,11 @@ class APIGatewayResponse(BaseResponse): self.backend.delete_method(function_id, resource_id, method_type) return 204, {}, "" - elif self.method == "PATCH": - patch_operations = self._get_param("patchOperations") - self.backend.update_method( - function_id, resource_id, method_type, patch_operations - ) - return 200, {}, "" - def resource_method_responses(self, request, full_url, headers): + def resource_method_responses( + self, request: Any, full_url: str, headers: Dict[str, str] + ) -> RESPONSE_TYPE: self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") function_id = url_path_parts[2] @@ -248,15 +251,9 @@ class APIGatewayResponse(BaseResponse): function_id, resource_id, method_type, response_code ) return 204, {}, json.dumps(method_response) - elif self.method == "PATCH": - patch_operations = self._get_param("patchOperations") - method_response = self.backend.update_method_response( - function_id, resource_id, method_type, response_code, patch_operations - ) - return 201, {}, json.dumps(method_response) raise Exception('Unexpected HTTP method "%s"' % self.method) - def restapis_authorizers(self, request, full_url, headers): + def restapis_authorizers(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") restapi_id = url_path_parts[2] @@ -306,7 +303,7 @@ class APIGatewayResponse(BaseResponse): authorizers = self.backend.get_authorizers(restapi_id) return 200, {}, json.dumps({"item": authorizers}) - def request_validators(self, request, full_url, headers): + def request_validators(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") restapi_id = url_path_parts[2] @@ -326,7 +323,7 @@ class APIGatewayResponse(BaseResponse): ) return 201, {}, json.dumps(validator) - def request_validator_individual(self, request, full_url, headers): + def request_validator_individual(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") restapi_id = url_path_parts[2] @@ -345,7 +342,7 @@ class APIGatewayResponse(BaseResponse): ) return 200, {}, json.dumps(validator) - def authorizers(self, request, full_url, headers): + def authorizers(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") restapi_id = url_path_parts[2] @@ -364,7 +361,7 @@ class APIGatewayResponse(BaseResponse): self.backend.delete_authorizer(restapi_id, authorizer_id) return 202, {}, "{}" - def restapis_stages(self, request, full_url, headers): + def restapis_stages(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") function_id = url_path_parts[2] @@ -395,7 +392,7 @@ class APIGatewayResponse(BaseResponse): stages = self.backend.get_stages(function_id) return 200, {}, json.dumps({"item": stages}) - def restapis_stages_tags(self, request, full_url, headers): + def restapis_stages_tags(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") function_id = url_path_parts[4] @@ -408,12 +405,12 @@ class APIGatewayResponse(BaseResponse): return 200, {}, json.dumps({"item": tags}) if self.method == "DELETE": stage = self.backend.get_stage(function_id, stage_name) - for tag in stage.get("tags").copy(): - if tag in self.querystring.get("tagKeys"): + for tag in stage.get("tags", {}).copy(): + if tag in (self.querystring.get("tagKeys") or {}): stage["tags"].pop(tag, None) return 200, {}, json.dumps({"item": ""}) - def stages(self, request, full_url, headers): + def stages(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") function_id = url_path_parts[2] @@ -432,15 +429,13 @@ class APIGatewayResponse(BaseResponse): self.backend.delete_stage(function_id, stage_name) return 202, {}, "{}" - def integrations(self, request, full_url, headers): + def integrations(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") function_id = url_path_parts[2] resource_id = url_path_parts[4] method_type = url_path_parts[6] - integration_response = {} - if self.method == "GET": integration_response = self.backend.get_integration( function_id, resource_id, method_type @@ -484,7 +479,7 @@ class APIGatewayResponse(BaseResponse): ) return 204, {}, json.dumps(integration_response) - def integration_responses(self, request, full_url, headers): + def integration_responses(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") function_id = url_path_parts[2] @@ -520,7 +515,7 @@ class APIGatewayResponse(BaseResponse): ) return 204, {}, json.dumps(integration_response) - def deployments(self, request, full_url, headers): + def deployments(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] @@ -536,7 +531,7 @@ class APIGatewayResponse(BaseResponse): ) return 201, {}, json.dumps(deployment) - def individual_deployment(self, request, full_url, headers): + def individual_deployment(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") function_id = url_path_parts[2] @@ -549,7 +544,7 @@ class APIGatewayResponse(BaseResponse): deployment = self.backend.delete_deployment(function_id, deployment_id) return 202, {}, json.dumps(deployment) - def apikeys(self, request, full_url, headers): + def apikeys(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": @@ -557,19 +552,20 @@ class APIGatewayResponse(BaseResponse): return 201, {}, json.dumps(apikey_response) elif self.method == "GET": - include_values = self._get_bool_param("includeValues") + include_values = self._get_bool_param("includeValues") or False apikeys_response = self.backend.get_api_keys(include_values=include_values) return 200, {}, json.dumps({"item": apikeys_response}) - def apikey_individual(self, request, full_url, headers): + def apikey_individual( + self, request: Any, full_url: str, headers: Dict[str, str] + ) -> RESPONSE_TYPE: self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") apikey = url_path_parts[2] - status_code = 200 if self.method == "GET": - include_value = self._get_bool_param("includeValue") + include_value = self._get_bool_param("includeValue") or False apikey_response = self.backend.get_api_key( apikey, include_value=include_value ) @@ -577,12 +573,12 @@ class APIGatewayResponse(BaseResponse): patch_operations = self._get_param("patchOperations") apikey_response = self.backend.update_api_key(apikey, patch_operations) elif self.method == "DELETE": - apikey_response = self.backend.delete_api_key(apikey) - status_code = 202 + self.backend.delete_api_key(apikey) + return 202, {}, "{}" - return status_code, {}, json.dumps(apikey_response) + return 200, {}, json.dumps(apikey_response) - def usage_plans(self, request, full_url, headers): + def usage_plans(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": usage_plan_response = self.backend.create_usage_plan(json.loads(self.body)) @@ -592,7 +588,7 @@ class APIGatewayResponse(BaseResponse): usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id) return 200, {}, json.dumps({"item": usage_plans_response}) - def usage_plan_individual(self, request, full_url, headers): + def usage_plan_individual(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") @@ -602,8 +598,8 @@ class APIGatewayResponse(BaseResponse): usage_plan_response = self.backend.get_usage_plan(usage_plan) return 200, {}, json.dumps(usage_plan_response) elif self.method == "DELETE": - usage_plan_response = self.backend.delete_usage_plan(usage_plan) - return 202, {}, json.dumps(usage_plan_response) + self.backend.delete_usage_plan(usage_plan) + return 202, {}, "{}" elif self.method == "PATCH": patch_operations = self._get_param("patchOperations") usage_plan_response = self.backend.update_usage_plan( @@ -611,7 +607,7 @@ class APIGatewayResponse(BaseResponse): ) return 200, {}, json.dumps(usage_plan_response) - def usage_plan_keys(self, request, full_url, headers): + def usage_plan_keys(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") @@ -626,7 +622,7 @@ class APIGatewayResponse(BaseResponse): usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id) return 200, {}, json.dumps({"item": usage_plans_response}) - def usage_plan_key_individual(self, request, full_url, headers): + def usage_plan_key_individual(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") @@ -637,12 +633,10 @@ class APIGatewayResponse(BaseResponse): usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id) return 200, {}, json.dumps(usage_plan_response) elif self.method == "DELETE": - usage_plan_response = self.backend.delete_usage_plan_key( - usage_plan_id, key_id - ) - return 202, {}, json.dumps(usage_plan_response) + self.backend.delete_usage_plan_key(usage_plan_id, key_id) + return 202, {}, "{}" - def domain_names(self, request, full_url, headers): + def domain_names(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -661,7 +655,6 @@ class APIGatewayResponse(BaseResponse): regional_certificate_arn = self._get_param("regionalCertificateArn") endpoint_configuration = self._get_param("endpointConfiguration") security_policy = self._get_param("securityPolicy") - generate_cli_skeleton = self._get_param("generateCliSkeleton") domain_name_resp = self.backend.create_domain_name( domain_name, certificate_name, @@ -674,35 +667,31 @@ class APIGatewayResponse(BaseResponse): regional_certificate_arn, endpoint_configuration, security_policy, - generate_cli_skeleton, ) return 201, {}, json.dumps(domain_name_resp) - def domain_name_induvidual(self, request, full_url, headers): + def domain_name_induvidual( + self, request: Any, full_url: str, headers: Dict[str, str] + ) -> RESPONSE_TYPE: self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") domain_name = url_path_parts[2] - domain_names = {} if self.method == "GET": if domain_name is not None: domain_names = self.backend.get_domain_name(domain_name) - return 200, {}, json.dumps(domain_names) + return 200, {}, json.dumps(domain_names) + return 200, {}, "{}" elif self.method == "DELETE": if domain_name is not None: self.backend.delete_domain_name(domain_name) return 202, {}, json.dumps({}) - elif self.method == "PATCH": - if domain_name is not None: - patch_operations = self._get_param("patchOperations") - self.backend.update_domain_name(domain_name, patch_operations) - return 200, {}, json.dumps(domain_name) else: msg = 'Method "%s" for API GW domain names not implemented' % self.method return 404, {}, json.dumps({"error": msg}) - def models(self, request, full_url, headers): + def models(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0] @@ -715,30 +704,29 @@ class APIGatewayResponse(BaseResponse): description = self._get_param("description") schema = self._get_param("schema") content_type = self._get_param("contentType") - cli_input_json = self._get_param("cliInputJson") - generate_cli_skeleton = self._get_param("generateCliSkeleton") model = self.backend.create_model( rest_api_id, name, content_type, description, schema, - cli_input_json, - generate_cli_skeleton, ) return 201, {}, json.dumps(model) - def model_induvidual(self, request, full_url, headers): + def model_induvidual( + self, request: Any, full_url: str, headers: Dict[str, str] + ) -> RESPONSE_TYPE: self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") rest_api_id = url_path_parts[2] model_name = url_path_parts[4] - model_info = {} + if self.method == "GET": model_info = self.backend.get_model(rest_api_id, model_name) - return 200, {}, json.dumps(model_info) + return 200, {}, json.dumps(model_info) + return 200, {}, "{}" - def base_path_mappings(self, request, full_url, headers): + def base_path_mappings(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") @@ -757,7 +745,7 @@ class APIGatewayResponse(BaseResponse): ) return 201, {}, json.dumps(base_path_mapping_resp) - def base_path_mapping_individual(self, request, full_url, headers): + def base_path_mapping_individual(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -780,7 +768,7 @@ class APIGatewayResponse(BaseResponse): ) return 200, {}, json.dumps(base_path_mapping) - def vpc_link(self, request, full_url, headers): + def vpc_link(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") vpc_link_id = url_path_parts[-1] @@ -792,7 +780,7 @@ class APIGatewayResponse(BaseResponse): vpc_link = self.backend.get_vpc_link(vpc_link_id=vpc_link_id) return 200, {}, json.dumps(vpc_link) - def vpc_links(self, request, full_url, headers): + def vpc_links(self, request: Any, full_url: str, headers: Dict[str, str]) -> RESPONSE_TYPE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": @@ -808,7 +796,7 @@ class APIGatewayResponse(BaseResponse): ) return 202, {}, json.dumps(vpc_link) - def put_gateway_response(self): + def put_gateway_response(self) -> RESPONSE_TYPE: rest_api_id = self.path.split("/")[-3] response_type = self.path.split("/")[-1] params = json.loads(self.body) @@ -824,7 +812,7 @@ class APIGatewayResponse(BaseResponse): ) return 201, {}, json.dumps(response) - def get_gateway_response(self): + def get_gateway_response(self) -> RESPONSE_TYPE: rest_api_id = self.path.split("/")[-3] response_type = self.path.split("/")[-1] response = self.backend.get_gateway_response( @@ -832,12 +820,12 @@ class APIGatewayResponse(BaseResponse): ) return 200, {}, json.dumps(response) - def get_gateway_responses(self): + def get_gateway_responses(self) -> RESPONSE_TYPE: rest_api_id = self.path.split("/")[-2] responses = self.backend.get_gateway_responses(rest_api_id=rest_api_id) return 200, {}, json.dumps(dict(item=responses)) - def delete_gateway_response(self): + def delete_gateway_response(self) -> RESPONSE_TYPE: rest_api_id = self.path.split("/")[-3] response_type = self.path.split("/")[-1] self.backend.delete_gateway_response( diff --git a/moto/apigateway/utils.py b/moto/apigateway/utils.py index facecca82..86e1e424d 100644 --- a/moto/apigateway/utils.py +++ b/moto/apigateway/utils.py @@ -2,15 +2,16 @@ import string import json import yaml from moto.moto_api._internal import mock_random as random +from typing import Any, Dict -def create_id(): +def create_id() -> str: size = 10 chars = list(range(10)) + list(string.ascii_lowercase) return "".join(str(random.choice(chars)) for x in range(size)) -def deserialize_body(body): +def deserialize_body(body: str) -> Dict[str, Any]: try: api_doc = json.loads(body) except json.JSONDecodeError: @@ -19,8 +20,8 @@ def deserialize_body(body): if "openapi" in api_doc or "swagger" in api_doc: return api_doc - return None + return {} -def to_path(prop): +def to_path(prop: str) -> str: return "/" + prop diff --git a/moto/core/common_models.py b/moto/core/common_models.py index 2785b6703..0a0bb6189 100644 --- a/moto/core/common_models.py +++ b/moto/core/common_models.py @@ -29,7 +29,7 @@ class CloudFormationModel(BaseModel): @classmethod @abstractmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr): # pylint: disable=unused-argument # Used for validation # If a template creates an Output for an attribute that does not exist, an error should be thrown return True diff --git a/moto/core/utils.py b/moto/core/utils.py index 1411b6aa1..115310ff5 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -186,7 +186,7 @@ def unix_time_millis(dt=None): return unix_time(dt) * 1000.0 -def path_url(url): +def path_url(url: str) -> str: parsed_url = urlparse(url) path = parsed_url.path if not path: diff --git a/moto/utilities/utils.py b/moto/utilities/utils.py index 878642711..be6bca8ca 100644 --- a/moto/utilities/utils.py +++ b/moto/utilities/utils.py @@ -2,8 +2,8 @@ import json import hashlib import pkgutil - from collections.abc import MutableMapping +from typing import Any, Dict def str2bool(v): @@ -23,7 +23,7 @@ def load_resource(package, resource, as_json=True): return json.loads(resource) if as_json else resource.decode("utf-8") -def merge_multiple_dicts(*args): +def merge_multiple_dicts(*args: Any) -> Dict[str, any]: result = {} for d in args: result.update(d) diff --git a/setup.cfg b/setup.cfg index f7fc6b329..2ae771c60 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,9 +18,10 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -exclude = tests +files= moto/acm,moto/amp,moto/apigateway,moto/applicationautoscaling/ show_column_numbers=True show_error_codes = True +disable_error_code=abstract disallow_any_unimported=False disallow_any_expr=False diff --git a/tests/test_apigateway/test_apigateway.py b/tests/test_apigateway/test_apigateway.py index f0ca4e6a3..ae8ac4052 100644 --- a/tests/test_apigateway/test_apigateway.py +++ b/tests/test_apigateway/test_apigateway.py @@ -842,9 +842,21 @@ def test_non_existent_authorizer(): response = client.create_rest_api(name="my_api", description="this is my api") api_id = response["id"] - client.get_authorizer.when.called_with( - restApiId=api_id, authorizerId="xxx" - ).should.throw(ClientError) + with pytest.raises(ClientError) as exc: + client.get_authorizer(restApiId=api_id, authorizerId="xxx") + err = exc.value.response["Error"] + err["Code"].should.equal("NotFoundException") + err["Message"].should.equal("Invalid Authorizer identifier specified") + + with pytest.raises(ClientError) as exc: + client.update_authorizer( + restApiId=api_id, + authorizerId="xxx", + patchOperations=[{"op": "add", "path": "/type", "value": "sth"}], + ) + err = exc.value.response["Error"] + err["Code"].should.equal("NotFoundException") + err["Message"].should.equal("Invalid Authorizer identifier specified") @mock_apigateway @@ -1878,18 +1890,6 @@ def test_get_domain_name_unknown_domainname(): ex.value.response["Error"]["Code"].should.equal("NotFoundException") -@mock_apigateway -def test_update_domain_name_unknown_domainname(): - client = boto3.client("apigateway", region_name="us-east-1") - with pytest.raises(ClientError) as ex: - client.update_domain_name(domainName="www.google.fr", patchOperations=[]) - - ex.value.response["Error"]["Message"].should.equal( - "Invalid domain name identifier specified" - ) - ex.value.response["Error"]["Code"].should.equal("NotFoundException") - - @mock_apigateway def test_delete_domain_name_unknown_domainname(): client = boto3.client("apigateway", region_name="us-east-1")