From eb53df2a1f33b4007eb52f46e8ea9e3831fef722 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Wed, 9 Nov 2022 16:39:07 -0100 Subject: [PATCH] TechDebt - Fix MyPy issues that popped up in 0.990 release (#5646) --- .../integration_parsers/aws_parser.py | 2 +- .../integration_parsers/http_parser.py | 4 +- .../integration_parsers/unknown_parser.py | 2 +- moto/apigateway/models.py | 769 +++++++++++------- moto/apigateway/responses.py | 140 ++-- moto/wafv2/models.py | 8 +- 6 files changed, 546 insertions(+), 379 deletions(-) diff --git a/moto/apigateway/integration_parsers/aws_parser.py b/moto/apigateway/integration_parsers/aws_parser.py index 50404cd3a..ca1d46258 100644 --- a/moto/apigateway/integration_parsers/aws_parser.py +++ b/moto/apigateway/integration_parsers/aws_parser.py @@ -14,7 +14,7 @@ class TypeAwsParser(IntegrationParser): try: # We need a better way to support services automatically # This is how AWS does it though - sending a new HTTP request to the target service - arn, action = integration["uri"].split("/") + arn, action = integration.uri.split("/") _, _, _, region, service, path_or_action = arn.split(":") if service == "dynamodb" and path_or_action == "action": target_url = f"https://dynamodb.{region}.amazonaws.com/" diff --git a/moto/apigateway/integration_parsers/http_parser.py b/moto/apigateway/integration_parsers/http_parser.py index d43bc70fb..ef2dd689d 100644 --- a/moto/apigateway/integration_parsers/http_parser.py +++ b/moto/apigateway/integration_parsers/http_parser.py @@ -13,7 +13,7 @@ class TypeHttpParser(IntegrationParser): def invoke( self, request: requests.PreparedRequest, integration: Integration ) -> Tuple[int, Union[str, bytes]]: - uri = integration["uri"] - requests_func = getattr(requests, integration["httpMethod"].lower()) + uri = integration.uri + requests_func = getattr(requests, integration.http_method.lower()) # type: ignore[union-attr] response = requests_func(uri) return response.status_code, response.text diff --git a/moto/apigateway/integration_parsers/unknown_parser.py b/moto/apigateway/integration_parsers/unknown_parser.py index 8de33ff26..876c31e43 100644 --- a/moto/apigateway/integration_parsers/unknown_parser.py +++ b/moto/apigateway/integration_parsers/unknown_parser.py @@ -12,5 +12,5 @@ class TypeUnknownParser(IntegrationParser): def invoke( self, request: requests.PreparedRequest, integration: Integration ) -> Tuple[int, Union[str, bytes]]: - _type = integration["type"] + _type = integration.integration_type raise NotImplementedError("The {0} type has not been implemented".format(_type)) diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 259bcdc83..bed0c7b46 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -6,7 +6,6 @@ import responses import requests import time from collections import defaultdict -from copy import copy from openapi_spec_validator import validate_spec from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse @@ -15,7 +14,7 @@ try: from openapi_spec_validator.validation.exceptions import OpenAPIValidationError except ImportError: # OpenAPI Spec Validator < 0.5.0 - from openapi_spec_validator.exceptions import OpenAPIValidationError + from openapi_spec_validator.exceptions import OpenAPIValidationError # type: ignore from moto.core import BaseBackend, BaseModel, CloudFormationModel from .utils import create_id, to_path from moto.core.utils import path_url, BackendDict @@ -67,13 +66,20 @@ from moto.moto_api._internal import mock_random as random STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}" -class Deployment(CloudFormationModel, dict): # type: ignore[type-arg] +class Deployment(CloudFormationModel): def __init__(self, deployment_id: str, name: str, description: str = ""): - super().__init__() - self["id"] = deployment_id - self["stageName"] = name - self["description"] = description - self["createdDate"] = int(time.time()) + self.id = deployment_id + self.stage_name = name + self.description = description + self.created_date = int(time.time()) + + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "stageName": self.stage_name, + "description": self.description, + "createdDate": self.created_date, + } @staticmethod def cloudformation_name_type() -> str: @@ -102,7 +108,7 @@ class Deployment(CloudFormationModel, dict): # type: ignore[type-arg] ) -class IntegrationResponse(BaseModel, dict): # type: ignore[type-arg] +class IntegrationResponse(BaseModel): def __init__( self, status_code: Union[str, int], @@ -117,15 +123,24 @@ class IntegrationResponse(BaseModel, dict): # type: ignore[type-arg] response_templates[key] = ( response_templates[key] or None ) # required for compatibility with TF - self["responseTemplates"] = response_templates - self["statusCode"] = status_code - if selection_pattern: - self["selectionPattern"] = selection_pattern - if content_handling: - self["contentHandling"] = content_handling + self.response_templates = response_templates + self.status_code = status_code + self.selection_pattern = selection_pattern + self.content_handling = content_handling + + def to_json(self) -> Dict[str, Any]: + resp = { + "responseTemplates": self.response_templates, + "statusCode": self.status_code, + } + if self.selection_pattern: + resp["selectionPattern"] = self.selection_pattern + if self.content_handling: + resp["contentHandling"] = self.content_handling + return resp -class Integration(BaseModel, dict): # type: ignore[type-arg] +class Integration(BaseModel): def __init__( self, integration_type: str, @@ -133,27 +148,43 @@ class Integration(BaseModel, dict): # type: ignore[type-arg] http_method: str, request_templates: Optional[Dict[str, Any]] = None, passthrough_behavior: Optional[str] = "WHEN_NO_MATCH", - cache_key_parameters: Optional[str] = None, + cache_key_parameters: Optional[List[str]] = None, tls_config: Optional[str] = None, cache_namespace: Optional[str] = None, timeout_in_millis: Optional[str] = None, request_parameters: Optional[Dict[str, Any]] = None, ): - super().__init__() - self["type"] = integration_type - self["uri"] = uri - self["httpMethod"] = http_method if integration_type != "MOCK" else None - self["passthroughBehavior"] = passthrough_behavior - self["cacheKeyParameters"] = cache_key_parameters or [] - self["requestTemplates"] = request_templates - # self["integrationResponses"] = {"200": IntegrationResponse(200)} # commented out (tf-compat) - self[ - "integrationResponses" - ] = None # prevent json serialization from including them if none provided - self["tlsConfig"] = tls_config - self["cacheNamespace"] = cache_namespace - self["timeoutInMillis"] = timeout_in_millis - self["requestParameters"] = request_parameters + self.integration_type = integration_type + self.uri = uri + self.http_method = http_method if integration_type != "MOCK" else None + self.passthrough_behaviour = passthrough_behavior + self.cache_key_parameters: List[str] = cache_key_parameters or [] + self.request_templates = request_templates + self.tls_config = tls_config + self.cache_namespace = cache_namespace + self.timeout_in_millis = timeout_in_millis + self.request_parameters = request_parameters + self.integration_responses: Optional[Dict[str, IntegrationResponse]] = None + + def to_json(self) -> Dict[str, Any]: + int_responses: Optional[Dict[str, Any]] = None + if self.integration_responses is not None: + int_responses = { + k: v.to_json() for k, v in self.integration_responses.items() + } + return { + "type": self.integration_type, + "uri": self.uri, + "httpMethod": self.http_method, + "passthroughBehavior": self.passthrough_behaviour, + "cacheKeyParameters": self.cache_key_parameters, + "requestTemplates": self.request_templates, + "integrationResponses": int_responses, + "tlsConfig": self.tls_config, + "cacheNamespace": self.cache_namespace, + "timeoutInMillis": self.timeout_in_millis, + "requestParameters": self.request_parameters, + } def create_integration_response( self, @@ -165,54 +196,74 @@ class Integration(BaseModel, dict): # type: ignore[type-arg] integration_response = IntegrationResponse( status_code, selection_pattern, response_templates or None, content_handling ) - if self.get("integrationResponses") is None: - self["integrationResponses"] = {} - self["integrationResponses"][status_code] = integration_response + if self.integration_responses is None: + self.integration_responses = {} + self.integration_responses[status_code] = integration_response return integration_response def get_integration_response(self, status_code: str) -> IntegrationResponse: - result = self.get("integrationResponses", {}).get(status_code) + result = (self.integration_responses or {}).get(status_code) if not result: raise NoIntegrationResponseDefined() return result def delete_integration_response(self, status_code: str) -> IntegrationResponse: - return self.get("integrationResponses", {}).pop(status_code, None) + return (self.integration_responses or {}).pop(status_code, None) # type: ignore[arg-type] -class MethodResponse(BaseModel, dict): # type: ignore[type-arg] +class MethodResponse(BaseModel): def __init__( self, status_code: str, response_models: Dict[str, str], response_parameters: Dict[str, Dict[str, str]], ): - super().__init__() - self["statusCode"] = status_code - self["responseModels"] = response_models - self["responseParameters"] = response_parameters + self.status_code = status_code + self.response_models = response_models + self.response_parameters = response_parameters + + def to_json(self) -> Dict[str, Any]: + return { + "statusCode": self.status_code, + "responseModels": self.response_models, + "responseParameters": self.response_parameters, + } -class Method(CloudFormationModel, dict): # type: ignore[type-arg] +class Method(CloudFormationModel): def __init__( self, method_type: str, authorization_type: Optional[str], **kwargs: Any ): - super().__init__() - self.update( - dict( - httpMethod=method_type, - authorizationType=authorization_type, - authorizerId=kwargs.get("authorizer_id"), - authorizationScopes=kwargs.get("authorization_scopes"), - apiKeyRequired=kwargs.get("api_key_required") or False, - requestParameters=kwargs.get("request_parameters"), - requestModels=kwargs.get("request_models"), - methodIntegration=None, - operationName=kwargs.get("operation_name"), - requestValidatorId=kwargs.get("request_validator_id"), - ) - ) - self["methodResponses"] = {} + self.http_method = method_type + self.authorization_type = authorization_type + self.authorizer_id = kwargs.get("authorizer_id") + self.authorization_scopes = kwargs.get("authorization_scopes") + self.api_key_required = kwargs.get("api_key_required") or False + self.request_parameters = kwargs.get("request_parameters") + self.request_models = kwargs.get("request_models") + self.method_integration: Optional[Integration] = None + self.operation_name = kwargs.get("operation_name") + self.request_validator_id = kwargs.get("request_validator_id") + self.method_responses: Dict[str, MethodResponse] = {} + + def to_json(self) -> Dict[str, Any]: + return { + "httpMethod": self.http_method, + "authorizationType": self.authorization_type, + "authorizerId": self.authorizer_id, + "authorizationScopes": self.authorization_scopes, + "apiKeyRequired": self.api_key_required, + "requestParameters": self.request_parameters, + "requestModels": self.request_models, + "methodIntegration": self.method_integration.to_json() + if self.method_integration + else None, + "operationName": self.operation_name, + "requestValidatorId": self.request_validator_id, + "methodResponses": { + k: v.to_json() for k, v in self.method_responses.items() + }, + } @staticmethod def cloudformation_name_type() -> str: @@ -267,14 +318,14 @@ class Method(CloudFormationModel, dict): # type: ignore[type-arg] method_response = MethodResponse( response_code, response_models, response_parameters ) - self["methodResponses"][response_code] = method_response + self.method_responses[response_code] = method_response return method_response def get_response(self, response_code: str) -> Optional[MethodResponse]: - return self["methodResponses"].get(response_code) + return self.method_responses.get(response_code) def delete_response(self, response_code: str) -> Optional[MethodResponse]: - return self["methodResponses"].pop(response_code, None) + return self.method_responses.pop(response_code, None) class Resource(CloudFormationModel): @@ -287,7 +338,6 @@ class Resource(CloudFormationModel): path_part: str, parent_id: Optional[str], ): - super().__init__() self.id = resource_id self.account_id = account_id self.region_name = region_name @@ -312,7 +362,9 @@ class Resource(CloudFormationModel): "id": self.id, } if self.resource_methods: - response["resourceMethods"] = self.resource_methods + response["resourceMethods"] = { + k: v.to_json() for k, v in self.resource_methods.items() + } if self.parent_id: response["parentId"] = self.parent_id response["pathPart"] = self.path_part @@ -374,10 +426,10 @@ class Resource(CloudFormationModel): self, request: requests.PreparedRequest ) -> Tuple[int, Union[str, bytes]]: integration = self.get_integration(str(request.method)) - integration_type = integration["type"] + integration_type = integration.integration_type # type: ignore[union-attr] status, result = self.integration_parsers[integration_type].invoke( - request, integration + request, integration # type: ignore[arg-type] ) return status, result @@ -444,20 +496,20 @@ class Resource(CloudFormationModel): timeout_in_millis=timeout_in_millis, request_parameters=request_parameters, ) - self.resource_methods[method_type]["methodIntegration"] = integration + self.resource_methods[method_type].method_integration = integration return integration - def get_integration(self, method_type: str) -> Integration: - method: Dict[str, Integration] = dict( - self.resource_methods.get(method_type, {}) - ) - return method.get("methodIntegration") or {} # type: ignore[return-value] + def get_integration(self, method_type: str) -> Optional[Integration]: + method = self.resource_methods.get(method_type) + return method.method_integration if method else None def delete_integration(self, method_type: str) -> Integration: - return self.resource_methods[method_type].pop("methodIntegration") + integration = self.resource_methods[method_type].method_integration + self.resource_methods[method_type].method_integration = None + return integration # type: ignore[return-value] -class Authorizer(BaseModel, dict): # type: ignore[type-arg] +class Authorizer(BaseModel): def __init__( self, authorizer_id: Optional[str], @@ -465,53 +517,67 @@ class Authorizer(BaseModel, dict): # type: ignore[type-arg] authorizer_type: Optional[str], **kwargs: Any ): - super().__init__() - self["id"] = authorizer_id - self["name"] = name - self["type"] = authorizer_type - if kwargs.get("provider_arns"): - self["providerARNs"] = kwargs.get("provider_arns") - if kwargs.get("auth_type"): - self["authType"] = kwargs.get("auth_type") - if kwargs.get("authorizer_uri"): - self["authorizerUri"] = kwargs.get("authorizer_uri") - if kwargs.get("authorizer_credentials"): - self["authorizerCredentials"] = kwargs.get("authorizer_credentials") - if kwargs.get("identity_source"): - self["identitySource"] = kwargs.get("identity_source") - if kwargs.get("identity_validation_expression"): - self["identityValidationExpression"] = kwargs.get( - "identity_validation_expression" - ) - self["authorizerResultTtlInSeconds"] = kwargs.get("authorizer_result_ttl") + self.id = authorizer_id + self.name = name + self.type = authorizer_type + self.provider_arns = kwargs.get("provider_arns") + self.auth_type = kwargs.get("auth_type") + self.authorizer_uri = kwargs.get("authorizer_uri") + self.authorizer_credentials = kwargs.get("authorizer_credentials") + self.identity_source = kwargs.get("identity_source") + self.identity_validation_expression = kwargs.get( + "identity_validation_expression" + ) + self.authorizer_result_ttl = kwargs.get("authorizer_result_ttl") + + def to_json(self) -> Dict[str, Any]: + dct = { + "id": self.id, + "name": self.name, + "type": self.type, + "authorizerResultTtlInSeconds": self.authorizer_result_ttl, + } + if self.provider_arns: + dct["providerARNs"] = self.provider_arns + if self.auth_type: + dct["authType"] = self.auth_type + if self.authorizer_uri: + dct["authorizerUri"] = self.authorizer_uri + if self.authorizer_credentials: + dct["authorizerCredentials"] = self.authorizer_credentials + if self.identity_source: + dct["identitySource"] = self.identity_source + if self.identity_validation_expression: + dct["identityValidationExpression"] = self.identity_validation_expression + return dct def apply_operations(self, patch_operations: List[Dict[str, Any]]) -> "Authorizer": for op in patch_operations: if "/authorizerUri" in op["path"]: - self["authorizerUri"] = op["value"] + self.authorizer_uri = op["value"] elif "/authorizerCredentials" in op["path"]: - self["authorizerCredentials"] = op["value"] + self.authorizer_credentials = op["value"] elif "/authorizerResultTtlInSeconds" in op["path"]: - self["authorizerResultTtlInSeconds"] = int(op["value"]) + self.authorizer_result_ttl = int(op["value"]) elif "/authType" in op["path"]: - self["authType"] = op["value"] + self.auth_type = op["value"] elif "/identitySource" in op["path"]: - self["identitySource"] = op["value"] + self.identity_source = op["value"] elif "/identityValidationExpression" in op["path"]: - self["identityValidationExpression"] = op["value"] + self.identity_validation_expression = op["value"] elif "/name" in op["path"]: - self["name"] = op["value"] + self.name = op["value"] elif "/providerARNs" in op["path"]: # TODO: add and remove raise Exception('Patch operation for "%s" not implemented' % op["path"]) elif "/type" in op["path"]: - self["type"] = op["value"] + self.type = op["value"] else: raise Exception('Patch operation "%s" not implemented' % op["op"]) return self -class Stage(BaseModel, dict): # type: ignore[type-arg] +class Stage(BaseModel): def __init__( self, name: Optional[str] = None, @@ -520,50 +586,71 @@ class Stage(BaseModel, dict): # type: ignore[type-arg] description: str = "", cacheClusterEnabled: Optional[bool] = False, cacheClusterSize: Optional[str] = None, - tags: Optional[List[Dict[str, str]]] = None, + tags: Optional[Dict[str, str]] = None, tracing_enabled: Optional[bool] = None, ): - super().__init__() - self["stageName"] = name - self["deploymentId"] = deployment_id - self["methodSettings"] = {} - self["variables"] = variables or {} - self["description"] = description - self["cacheClusterEnabled"] = cacheClusterEnabled - if self["cacheClusterEnabled"]: - self["cacheClusterStatus"] = "AVAILABLE" - self["cacheClusterSize"] = str(0.5) - if cacheClusterSize is not None: - self["cacheClusterSize"] = str(cacheClusterSize) - if tags is not None: - self["tags"] = tags - if tracing_enabled is not None: - self["tracingEnabled"] = tracing_enabled + self.name = name + self.deployment_id = deployment_id + self.method_settings: Dict[str, Any] = {} + self.variables = variables or {} + self.description = description + self.cache_cluster_enabled = cacheClusterEnabled + self.cache_cluster_status = "AVAILABLE" if cacheClusterEnabled else None + self.cache_cluster_size = ( + str(cacheClusterSize) if cacheClusterSize is not None else None + ) + self.tags = tags + self.tracing_enabled = tracing_enabled + self.access_log_settings: Optional[Dict[str, Any]] = None + self.web_acl_arn = None + + def to_json(self) -> Dict[str, Any]: + dct: Dict[str, Any] = { + "stageName": self.name, + "deploymentId": self.deployment_id, + "methodSettings": self.method_settings, + "variables": self.variables, + "description": self.description, + "cacheClusterEnabled": self.cache_cluster_enabled, + "accessLogSettings": self.access_log_settings, + } + if self.cache_cluster_status is not None: + dct["cacheClusterStatus"] = self.cache_cluster_status + if self.cache_cluster_enabled: + if self.cache_cluster_size is not None: + dct["cacheClusterSize"] = self.cache_cluster_size + else: + dct["cacheClusterSize"] = "0.5" + if self.tags: + dct["tags"] = self.tags + if self.tracing_enabled is not None: + dct["tracingEnabled"] = self.tracing_enabled + if self.web_acl_arn is not None: + dct["webAclArn"] = self.web_acl_arn + return dct def apply_operations(self, patch_operations: List[Dict[str, Any]]) -> "Stage": for op in patch_operations: if "variables/" in op["path"]: self._apply_operation_to_variables(op) elif "/cacheClusterEnabled" in op["path"]: - self["cacheClusterEnabled"] = self._str2bool(op["value"]) - if self["cacheClusterEnabled"]: - self["cacheClusterStatus"] = "AVAILABLE" - if "cacheClusterSize" not in self: - self["cacheClusterSize"] = str(0.5) + self.cache_cluster_enabled = self._str2bool(op["value"]) + if self.cache_cluster_enabled: + self.cache_cluster_status = "AVAILABLE" else: - self["cacheClusterStatus"] = "NOT_AVAILABLE" + self.cache_cluster_status = "NOT_AVAILABLE" elif "/cacheClusterSize" in op["path"]: - self["cacheClusterSize"] = str(op["value"]) + self.cache_cluster_size = str(op["value"]) elif "/description" in op["path"]: - self["description"] = op["value"] + self.description = op["value"] elif "/deploymentId" in op["path"]: - self["deploymentId"] = op["value"] + self.deployment_id = op["value"] elif op["op"] == "replace": if op["path"] == "/tracingEnabled": - self["tracingEnabled"] = self._str2bool(op["value"]) + self.tracing_enabled = self._str2bool(op["value"]) elif op["path"].startswith("/accessLogSettings/"): - self["accessLogSettings"] = self.get("accessLogSettings", {}) - self["accessLogSettings"][op["path"].split("/")[-1]] = op["value"] + self.access_log_settings = self.access_log_settings or {} + self.access_log_settings[op["path"].split("/")[-1]] = op["value"] # type: ignore[index] else: # (e.g., path could be '/*/*/logging/loglevel') split_path = op["path"].split("/", 3) @@ -574,7 +661,7 @@ class Stage(BaseModel, dict): # type: ignore[type-arg] ) elif op["op"] == "remove": if op["path"] == "/accessLogSettings": - self["accessLogSettings"] = None + self.access_log_settings = None else: raise ValidationException( "Member must satisfy enum value set: [add, remove, move, test, replace, copy]" @@ -586,11 +673,11 @@ class Stage(BaseModel, dict): # type: ignore[type-arg] ) -> None: updated_key = self._method_settings_translations(key) if updated_key is not None: - if resource_path_and_method not in self["methodSettings"]: - self["methodSettings"][ + if resource_path_and_method not in self.method_settings: + self.method_settings[ resource_path_and_method ] = self._get_default_method_settings() - self["methodSettings"][resource_path_and_method][ + self.method_settings[resource_path_and_method][ updated_key ] = self._convert_to_type(updated_key, value) @@ -657,14 +744,14 @@ class Stage(BaseModel, dict): # type: ignore[type-arg] def _apply_operation_to_variables(self, op: Dict[str, Any]) -> None: key = op["path"][op["path"].rindex("variables/") + 10 :] if op["op"] == "remove": - self["variables"].pop(key, None) + self.variables.pop(key, None) elif op["op"] == "replace": - self["variables"][key] = op["value"] + self.variables[key] = op["value"] else: raise Exception('Patch operation "%s" not implemented' % op["op"]) -class ApiKey(BaseModel, dict): # type: ignore[type-arg] +class ApiKey(BaseModel): def __init__( self, name: Optional[str] = None, @@ -676,30 +763,43 @@ class ApiKey(BaseModel, dict): # type: ignore[type-arg] tags: Optional[List[Dict[str, str]]] = None, customerId: Optional[str] = None, ): - super().__init__() - self["id"] = create_id() - self["value"] = value or "".join( + self.id = create_id() + self.value = value or "".join( random.sample(string.ascii_letters + string.digits, 40) ) - self["name"] = name - self["customerId"] = customerId - self["description"] = description - self["enabled"] = enabled - self["createdDate"] = self["lastUpdatedDate"] = int(time.time()) - self["stageKeys"] = stageKeys or [] - self["tags"] = tags + self.name = name + self.customer_id = customerId + self.description = description + self.enabled = enabled + self.created_date = self.last_updated_date = int(time.time()) + self.stage_keys = stageKeys or [] + self.tags = tags + + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "value": self.value, + "name": self.name, + "customerId": self.customer_id, + "description": self.description, + "enabled": self.enabled, + "createdDate": self.created_date, + "lastUpdatedDate": self.last_updated_date, + "stageKeys": self.stage_keys, + "tags": self.tags, + } def update_operations(self, patch_operations: List[Dict[str, Any]]) -> "ApiKey": for op in patch_operations: if op["op"] == "replace": if "/name" in op["path"]: - self["name"] = op["value"] + self.name = op["value"] elif "/customerId" in op["path"]: - self["customerId"] = op["value"] + self.customer_id = op["value"] elif "/description" in op["path"]: - self["description"] = op["value"] + self.description = op["value"] elif "/enabled" in op["path"]: - self["enabled"] = self._str2bool(op["value"]) + self.enabled = self._str2bool(op["value"]) else: raise Exception('Patch operation "%s" not implemented' % op["op"]) return self @@ -708,26 +808,37 @@ class ApiKey(BaseModel, dict): # type: ignore[type-arg] return v.lower() == "true" -class UsagePlan(BaseModel, dict): # type: ignore[type-arg] +class UsagePlan(BaseModel): def __init__( self, name: Optional[str] = None, description: Optional[str] = None, apiStages: Any = None, - throttle: Optional[str] = None, - quota: Optional[str] = None, + throttle: Optional[Dict[str, Any]] = None, + quota: Optional[Dict[str, Any]] = None, productCode: Optional[str] = None, tags: Optional[List[Dict[str, str]]] = None, ): - super().__init__() - self["id"] = create_id() - self["name"] = name - self["description"] = description - self["apiStages"] = apiStages if apiStages else [] - self["throttle"] = throttle - self["quota"] = quota - self["productCode"] = productCode - self["tags"] = tags + self.id = create_id() + self.name = name + self.description = description + self.api_stages = apiStages or [] + self.throttle = throttle or {} + self.quota = quota or {} + self.product_code = productCode + self.tags = tags + + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "apiStages": self.api_stages, + "throttle": self.throttle, + "quota": self.quota, + "productCode": self.product_code, + "tags": self.tags, + } def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None: for op in patch_operations: @@ -735,22 +846,22 @@ class UsagePlan(BaseModel, dict): # type: ignore[type-arg] value = op["value"] if op["op"] == "replace": if "/name" in path: - self["name"] = value + self.name = value if "/productCode" in path: - self["productCode"] = value + self.product_code = value if "/description" in path: - self["description"] = value + self.description = value if "/quota/limit" in path: - self["quota"]["limit"] = value + self.quota["limit"] = value if "/quota/period" in path: - self["quota"]["period"] = value + self.quota["period"] = value if "/throttle/rateLimit" in path: - self["throttle"]["rateLimit"] = value + self.throttle["rateLimit"] = value if "/throttle/burstLimit" in path: - self["throttle"]["burstLimit"] = value + self.throttle["burstLimit"] = value -class RequestValidator(BaseModel, dict): # type: ignore[type-arg] +class RequestValidator(BaseModel): PROP_ID = "id" PROP_NAME = "name" PROP_VALIDATE_REQUEST_BODY = "validateRequestBody" @@ -769,13 +880,10 @@ class RequestValidator(BaseModel, dict): # type: ignore[type-arg] validateRequestBody: Optional[bool], validateRequestParameters: Any, ): - super().__init__() - self[RequestValidator.PROP_ID] = _id - self[RequestValidator.PROP_NAME] = name - self[RequestValidator.PROP_VALIDATE_REQUEST_BODY] = validateRequestBody - self[ - RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS - ] = validateRequestParameters + self.id = _id + self.name = name + self.validate_request_body = validateRequestBody + self.validate_request_parameters = validateRequestParameters def apply_patch_operations(self, operations: List[Dict[str, Any]]) -> None: for operation in operations: @@ -783,35 +891,38 @@ class RequestValidator(BaseModel, dict): # type: ignore[type-arg] value = operation[RequestValidator.OP_VALUE] if operation[RequestValidator.OP_OP] == RequestValidator.OP_REPLACE: if to_path(RequestValidator.PROP_NAME) in path: - self[RequestValidator.PROP_NAME] = value + self.name = value if to_path(RequestValidator.PROP_VALIDATE_REQUEST_BODY) in path: - self[ - RequestValidator.PROP_VALIDATE_REQUEST_BODY - ] = value.lower() in ("true") + self.validate_request_body = value.lower() in ("true") if to_path(RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS) in path: - self[ - RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS - ] = value.lower() in ("true") + self.validate_request_parameters = value.lower() in ("true") def to_dict(self) -> Dict[str, Any]: return { - "id": self["id"], - "name": self["name"], - "validateRequestBody": self["validateRequestBody"], - "validateRequestParameters": self["validateRequestParameters"], + RequestValidator.PROP_ID: self.id, + RequestValidator.PROP_NAME: self.name, + RequestValidator.PROP_VALIDATE_REQUEST_BODY: self.validate_request_body, + RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS: self.validate_request_parameters, } -class UsagePlanKey(BaseModel, dict): # type: ignore[type-arg] - def __init__(self, plan_id: Dict[str, Any], plan_type: str, name: str, value: str): - super().__init__() - self["id"] = plan_id - self["name"] = name - self["type"] = plan_type - self["value"] = value +class UsagePlanKey(BaseModel): + def __init__(self, plan_id: str, plan_type: str, name: Optional[str], value: str): + self.id = plan_id + self.name = name + self.type = plan_type + self.value = value + + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "name": self.name, + "type": self.type, + "value": self.value, + } -class VpcLink(BaseModel, dict): # type: ignore[type-arg] +class VpcLink(BaseModel): def __init__( self, name: str, @@ -819,13 +930,22 @@ class VpcLink(BaseModel, dict): # type: ignore[type-arg] target_arns: List[str], tags: List[Dict[str, str]], ): - super().__init__() - self["id"] = create_id() - self["name"] = name - self["description"] = description - self["targetArns"] = target_arns - self["tags"] = tags - self["status"] = "AVAILABLE" + self.id = create_id() + self.name = name + self.description = description + self.target_arns = target_arns + self.tags = tags + self.status = "AVAILABLE" + + def to_json(self) -> Dict[str, Any]: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "targetArns": self.target_arns, + "tags": self.tags, + "status": self.status, + } class RestAPI(CloudFormationModel): @@ -860,7 +980,6 @@ class RestAPI(CloudFormationModel): description: str, **kwargs: Any ): - super().__init__() self.id = api_id self.account_id = account_id self.region_name = region_name @@ -1084,7 +1203,7 @@ class RestAPI(CloudFormationModel): description: str, cacheClusterEnabled: Optional[bool], cacheClusterSize: Optional[str], - tags: Optional[List[Dict[str, str]]], + tags: Optional[Dict[str, str]], tracing_enabled: Optional[bool], ) -> Stage: if name in self.stages: @@ -1137,7 +1256,7 @@ class RestAPI(CloudFormationModel): if deployment_id not in self.deployments: raise DeploymentNotFoundException() deployment = self.deployments[deployment_id] - if deployment["stageName"] and deployment["stageName"] in self.stages: + if deployment.stage_name and deployment.stage_name in self.stages: # Stage is still active raise StageStillActive() @@ -1205,56 +1324,87 @@ class RestAPI(CloudFormationModel): self.gateway_responses.pop(response_type, None) -class DomainName(BaseModel, dict): # type: ignore[type-arg] +class DomainName(BaseModel): def __init__(self, domain_name: str, **kwargs: Any): - super().__init__() - self["domainName"] = domain_name - self["regionalDomainName"] = "d-%s.execute-api.%s.amazonaws.com" % ( + self.domain_name = domain_name + self.regional_domain_name = "d-%s.execute-api.%s.amazonaws.com" % ( create_id(), kwargs.get("region_name") or "us-east-1", ) - self["distributionDomainName"] = "d%s.cloudfront.net" % create_id() - self["domainNameStatus"] = "AVAILABLE" - self["domainNameStatusMessage"] = "Domain Name Available" - self["regionalHostedZoneId"] = "Z2FDTNDATAQYW2" - self["distributionHostedZoneId"] = "Z2FDTNDATAQYW2" - self["certificateUploadDate"] = int(time.time()) - if kwargs.get("certificate_name"): - self["certificateName"] = kwargs.get("certificate_name") - if kwargs.get("certificate_arn"): - self["certificateArn"] = kwargs.get("certificate_arn") - if kwargs.get("certificate_body"): - self["certificateBody"] = kwargs.get("certificate_body") - if kwargs.get("tags"): - self["tags"] = kwargs.get("tags") - if kwargs.get("security_policy"): - self["securityPolicy"] = kwargs.get("security_policy") - if kwargs.get("certificate_chain"): - self["certificateChain"] = kwargs.get("certificate_chain") - if kwargs.get("regional_certificate_name"): - self["regionalCertificateName"] = kwargs.get("regional_certificate_name") - if kwargs.get("certificate_private_key"): - self["certificatePrivateKey"] = kwargs.get("certificate_private_key") - if kwargs.get("regional_certificate_arn"): - self["regionalCertificateArn"] = kwargs.get("regional_certificate_arn") - if kwargs.get("endpoint_configuration"): - self["endpointConfiguration"] = kwargs.get("endpoint_configuration") + self.distribution_domain_name = "d%s.cloudfront.net" % create_id() + self.domain_name_status = "AVAILABLE" + self.status_message = "Domain Name Available" + self.regional_hosted_zone_id = "Z2FDTNDATAQYW2" + self.distribution_hosted_zone_id = "Z2FDTNDATAQYW2" + self.certificate_upload_date = int(time.time()) + self.certificate_name = kwargs.get("certificate_name") + self.certificate_arn = kwargs.get("certificate_arn") + self.certificate_body = kwargs.get("certificate_body") + self.tags = kwargs.get("tags") + self.security_policy = kwargs.get("security_policy") + self.certificate_chain = kwargs.get("certificate_chain") + self.regional_certificate_name = kwargs.get("regional_certificate_name") + self.certificate_private_key = kwargs.get("certificate_private_key") + self.regional_certificate_arn = kwargs.get("regional_certificate_arn") + self.endpoint_configuration = kwargs.get("endpoint_configuration") + + def to_json(self) -> Dict[str, Any]: + dct = { + "domainName": self.domain_name, + "regionalDomainName": self.regional_domain_name, + "distributionDomainName": self.distribution_domain_name, + "domainNameStatus": self.domain_name_status, + "domainNameStatusMessage": self.status_message, + "regionalHostedZoneId": self.regional_hosted_zone_id, + "distributionHostedZoneId": self.distribution_hosted_zone_id, + "certificateUploadDate": self.certificate_upload_date, + } + if self.certificate_name: + dct["certificateName"] = self.certificate_name + if self.certificate_arn: + dct["certificateArn"] = self.certificate_arn + if self.certificate_body: + dct["certificateBody"] = self.certificate_body + if self.tags: + dct["tags"] = self.tags + if self.security_policy: + dct["securityPolicy"] = self.security_policy + if self.certificate_chain: + dct["certificateChain"] = self.certificate_chain + if self.regional_certificate_name: + dct["regionalCertificateName"] = self.regional_certificate_name + if self.certificate_private_key: + dct["certificatePrivateKey"] = self.certificate_private_key + if self.regional_certificate_arn: + dct["regionalCertificateArn"] = self.regional_certificate_arn + if self.endpoint_configuration: + dct["endpointConfiguration"] = self.endpoint_configuration + return dct -class Model(BaseModel, dict): # type: ignore[type-arg] +class Model(BaseModel): def __init__(self, model_id: str, name: str, **kwargs: Any): - super().__init__() - self["id"] = model_id - self["name"] = name - if kwargs.get("description"): - self["description"] = kwargs.get("description") - if kwargs.get("schema"): - self["schema"] = kwargs.get("schema") - if kwargs.get("content_type"): - self["contentType"] = kwargs.get("content_type") + self.id = model_id + self.name = name + self.description = kwargs.get("description") + self.schema = kwargs.get("schema") + self.content_type = kwargs.get("content_type") + + def to_json(self) -> Dict[str, Any]: + dct = { + "id": self.id, + "name": self.name, + } + if self.description: + dct["description"] = self.description + if self.schema: + dct["schema"] = self.schema + if self.content_type: + dct["contentType"] = self.content_type + return dct -class BasePathMapping(BaseModel, dict): # type: ignore[type-arg] +class BasePathMapping(BaseModel): # operations OPERATION_REPLACE = "replace" @@ -1263,15 +1413,20 @@ class BasePathMapping(BaseModel, dict): # type: ignore[type-arg] OPERATION_OP = "op" def __init__(self, domain_name: str, rest_api_id: str, **kwargs: Any): - super().__init__() - self["domain_name"] = domain_name - self["restApiId"] = rest_api_id - if kwargs.get("basePath"): - self["basePath"] = kwargs.get("basePath") - else: - self["basePath"] = "(none)" - if kwargs.get("stage"): - self["stage"] = kwargs.get("stage") + self.domain_name = domain_name + self.rest_api_id = rest_api_id + self.base_path = kwargs.get("basePath") or "(none)" + self.stage = kwargs.get("stage") + + def to_json(self) -> Dict[str, Any]: + dct = { + "domain_name": self.domain_name, + "restApiId": self.rest_api_id, + "basePath": self.base_path, + } + if self.stage is not None: + dct["stage"] = self.stage + return dct def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None: for op in patch_operations: @@ -1280,14 +1435,14 @@ class BasePathMapping(BaseModel, dict): # type: ignore[type-arg] operation = op["op"] if operation == self.OPERATION_REPLACE: if "/basePath" in path: - self["basePath"] = value + self.base_path = value if "/restapiId" in path: - self["restApiId"] = value + self.rest_api_id = value if "/stage" in path: - self["stage"] = value + self.stage = value -class GatewayResponse(BaseModel, dict): # type: ignore[type-arg] +class GatewayResponse(BaseModel): def __init__( self, response_type: str, @@ -1295,15 +1450,24 @@ class GatewayResponse(BaseModel, dict): # type: ignore[type-arg] response_parameters: Dict[str, Any], response_templates: Dict[str, str], ): - super().__init__() - self["responseType"] = response_type - if status_code is not None: - self["statusCode"] = status_code - if response_parameters is not None: - self["responseParameters"] = response_parameters - if response_templates is not None: - self["responseTemplates"] = response_templates - self["defaultResponse"] = False + self.response_type = response_type + self.default_response = False + self.status_code = status_code + self.response_parameters = response_parameters + self.response_templates = response_templates + + def to_json(self) -> Dict[str, Any]: + dct = { + "responseType": self.response_type, + "defaultResponse": self.default_response, + } + if self.status_code is not None: + dct["statusCode"] = self.status_code + if self.response_parameters is not None: + dct["responseParameters"] = self.response_parameters + if self.response_templates is not None: + dct["responseTemplates"] = self.response_templates + return dct class APIGatewayBackend(BaseBackend): @@ -1378,7 +1542,7 @@ class APIGatewayBackend(BaseBackend): """ if fail_on_warnings: try: - validate_spec(api_doc) + validate_spec(api_doc) # type: ignore[arg-type] except OpenAPIValidationError as e: raise InvalidOpenAPIDocumentException(e) name = api_doc["info"]["title"] @@ -1413,7 +1577,7 @@ class APIGatewayBackend(BaseBackend): if fail_on_warnings: try: - validate_spec(api_doc) + validate_spec(api_doc) # type: ignore[arg-type] except OpenAPIValidationError as e: raise InvalidOpenAPIDocumentException(e) @@ -1593,7 +1757,7 @@ class APIGatewayBackend(BaseBackend): description: str = "", cacheClusterEnabled: Optional[bool] = None, cacheClusterSize: Optional[str] = None, - tags: Optional[List[Dict[str, str]]] = None, + tags: Optional[Dict[str, str]] = None, tracing_enabled: Optional[bool] = None, ) -> Stage: if variables is None: @@ -1715,7 +1879,7 @@ class APIGatewayBackend(BaseBackend): self, function_id: str, resource_id: str, method_type: str ) -> Integration: resource = self.get_resource(function_id, resource_id) - return resource.get_integration(method_type) + return resource.get_integration(method_type) # type: ignore[return-value] def delete_integration( self, function_id: str, resource_id: str, method_type: str @@ -1772,7 +1936,7 @@ class APIGatewayBackend(BaseBackend): if not any(methods): raise NoMethodDefined() method_integrations = [ - method.get("methodIntegration", None) for method in methods + method.method_integration for method in methods if method.method_integration ] if not any(method_integrations): raise NoIntegrationDefined() @@ -1795,37 +1959,20 @@ class APIGatewayBackend(BaseBackend): if payload.get("value"): if len(payload.get("value", [])) < 20: raise ApiKeyValueMinLength() - for api_key in self.get_api_keys(include_values=True): - if api_key.get("value") == payload["value"]: + for api_key in self.get_api_keys(): + if api_key.value == payload["value"]: raise ApiKeyAlreadyExists() key = ApiKey(**payload) - self.keys[key["id"]] = key + self.keys[key.id] = key return key - def get_api_keys(self, include_values: bool) -> List[ApiKey]: - api_keys = list(self.keys.values()) + def get_api_keys(self) -> List[ApiKey]: + return list(self.keys.values()) - if not include_values: - keys = [] - for api_key in list(self.keys.values()): - new_key = copy(api_key) - del new_key["value"] - keys.append(new_key) - api_keys = keys - - return api_keys - - def get_api_key(self, api_key_id: str, include_value: bool = False) -> ApiKey: - api_key = self.keys.get(api_key_id) - if not api_key: + def get_api_key(self, api_key_id: str) -> ApiKey: + if api_key_id not in self.keys: raise ApiKeyNotFoundException() - - if not include_value: - new_key = copy(api_key) - del new_key["value"] - api_key = new_key - - return api_key + return self.keys[api_key_id] def update_api_key(self, api_key_id: str, patch_operations: Any) -> ApiKey: key = self.keys[api_key_id] @@ -1836,7 +1983,7 @@ class APIGatewayBackend(BaseBackend): def create_usage_plan(self, payload: Any) -> UsagePlan: plan = UsagePlan(**payload) - self.usage_plans[plan["id"]] = plan + self.usage_plans[plan.id] = plan return plan def get_usage_plans(self, api_key_id: Optional[str] = None) -> List[UsagePlan]: @@ -1845,7 +1992,7 @@ class APIGatewayBackend(BaseBackend): plans = [ plan for plan in plans - if dict(self.usage_plan_keys.get(plan["id"], {})).get(api_key_id) + if dict(self.usage_plan_keys.get(plan.id, {})).get(api_key_id) ] return plans @@ -1878,10 +2025,10 @@ class APIGatewayBackend(BaseBackend): usage_plan_key = UsagePlanKey( plan_id=key_id, plan_type=payload["keyType"], - name=api_key["name"], - value=api_key["value"], + name=api_key.name, + value=api_key.value, ) - self.usage_plan_keys[usage_plan_id][usage_plan_key["id"]] = usage_plan_key + self.usage_plan_keys[usage_plan_id][usage_plan_key.id] = usage_plan_key return usage_plan_key def get_usage_plan_keys(self, usage_plan_id: str) -> List[UsagePlanKey]: @@ -2054,7 +2201,7 @@ class APIGatewayBackend(BaseBackend): stage=stage, ) - new_base_path = new_base_path_mapping.get("basePath") + new_base_path = new_base_path_mapping.base_path if self.base_path_mappings.get(domain_name) is None: self.base_path_mappings[domain_name] = {} else: @@ -2108,13 +2255,13 @@ class APIGatewayBackend(BaseBackend): op["value"] for op in patch_operations if op["path"] == "/restapiId" ] if len(rest_api_ids) == 0: - modified_rest_api_id = base_path_mapping["restApiId"] + modified_rest_api_id = base_path_mapping.rest_api_id else: modified_rest_api_id = rest_api_ids[-1] stages = [op["value"] for op in patch_operations if op["path"] == "/stage"] if len(stages) == 0: - modified_stage = base_path_mapping.get("stage") + modified_stage = base_path_mapping.stage else: modified_stage = stages[-1] @@ -2122,7 +2269,7 @@ class APIGatewayBackend(BaseBackend): op["value"] for op in patch_operations if op["path"] == "/basePath" ] if len(base_paths) == 0: - modified_base_path = base_path_mapping["basePath"] + modified_base_path = base_path_mapping.base_path else: modified_base_path = base_paths[-1] @@ -2150,7 +2297,7 @@ class APIGatewayBackend(BaseBackend): vpc_link = VpcLink( name, description=description, target_arns=target_arns, tags=tags ) - self.vpc_links[vpc_link["id"]] = vpc_link + self.vpc_links[vpc_link.id] = vpc_link return vpc_link def delete_vpc_link(self, vpc_link_id: str) -> None: diff --git a/moto/apigateway/responses.py b/moto/apigateway/responses.py index 48e019cd9..2691abc0b 100644 --- a/moto/apigateway/responses.py +++ b/moto/apigateway/responses.py @@ -186,7 +186,7 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": method = self.backend.get_method(function_id, resource_id, method_type) - return 200, {}, json.dumps(method) + return 200, {}, json.dumps(method.to_json()) elif self.method == "PUT": authorization_type = self._get_param("authorizationType") api_key_required = self._get_param("apiKeyRequired") @@ -209,7 +209,7 @@ class APIGatewayResponse(BaseResponse): authorization_scopes=authorization_scopes, request_validator_id=request_validator_id, ) - return 201, {}, json.dumps(method) + return 201, {}, json.dumps(method.to_json()) elif self.method == "DELETE": self.backend.delete_method(function_id, resource_id, method_type) @@ -231,7 +231,7 @@ class APIGatewayResponse(BaseResponse): method_response = self.backend.get_method_response( function_id, resource_id, method_type, response_code ) - return 200, {}, json.dumps(method_response) + return 200, {}, json.dumps(method_response.to_json()) # type: ignore[union-attr] elif self.method == "PUT": response_models = self._get_param("responseModels") response_parameters = self._get_param("responseParameters") @@ -243,12 +243,12 @@ class APIGatewayResponse(BaseResponse): response_models, response_parameters, ) - return 201, {}, json.dumps(method_response) + return 201, {}, json.dumps(method_response.to_json()) elif self.method == "DELETE": method_response = self.backend.delete_method_response( function_id, resource_id, method_type, response_code ) - return 204, {}, json.dumps(method_response) + return 204, {}, json.dumps(method_response.to_json()) # type: ignore[union-attr] raise Exception('Unexpected HTTP method "%s"' % self.method) def restapis_authorizers(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] @@ -296,10 +296,10 @@ class APIGatewayResponse(BaseResponse): identiy_validation_expression=identiy_validation_expression, authorizer_result_ttl=authorizer_result_ttl, ) - return 201, {}, json.dumps(authorizer_response) + return 201, {}, json.dumps(authorizer_response.to_json()) elif self.method == "GET": authorizers = self.backend.get_authorizers(restapi_id) - return 200, {}, json.dumps({"item": authorizers}) + return 200, {}, json.dumps({"item": [a.to_json() for a in authorizers]}) def request_validators(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -319,7 +319,7 @@ class APIGatewayResponse(BaseResponse): validator = self.backend.create_request_validator( restapi_id, name, body, params ) - return 201, {}, json.dumps(validator) + return 201, {}, json.dumps(validator.to_dict()) def request_validator_individual(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -329,7 +329,7 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": validator = self.backend.get_request_validator(restapi_id, validator_id) - return 200, {}, json.dumps(validator) + return 200, {}, json.dumps(validator.to_dict()) if self.method == "DELETE": self.backend.delete_request_validator(restapi_id, validator_id) return 202, {}, "" @@ -338,7 +338,7 @@ class APIGatewayResponse(BaseResponse): validator = self.backend.update_request_validator( restapi_id, validator_id, patch_operations ) - return 200, {}, json.dumps(validator) + return 200, {}, json.dumps(validator.to_dict()) def authorizers(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -348,13 +348,13 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": authorizer_response = self.backend.get_authorizer(restapi_id, authorizer_id) - return 200, {}, json.dumps(authorizer_response) + return 200, {}, json.dumps(authorizer_response.to_json()) elif self.method == "PATCH": patch_operations = self._get_param("patchOperations") authorizer_response = self.backend.update_authorizer( restapi_id, authorizer_id, patch_operations ) - return 200, {}, json.dumps(authorizer_response) + return 200, {}, json.dumps(authorizer_response.to_json()) elif self.method == "DELETE": self.backend.delete_authorizer(restapi_id, authorizer_id) return 202, {}, "{}" @@ -385,10 +385,10 @@ class APIGatewayResponse(BaseResponse): tags=tags, tracing_enabled=tracing_enabled, ) - return 201, {}, json.dumps(stage_response) + return 201, {}, json.dumps(stage_response.to_json()) elif self.method == "GET": stages = self.backend.get_stages(function_id) - return 200, {}, json.dumps({"item": stages}) + return 200, {}, json.dumps({"item": [s.to_json() for s in stages]}) def restapis_stages_tags(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -399,13 +399,13 @@ class APIGatewayResponse(BaseResponse): tags = self._get_param("tags") if tags: stage = self.backend.get_stage(function_id, stage_name) - stage["tags"] = merge_multiple_dicts(stage.get("tags"), tags) + stage.tags = merge_multiple_dicts(stage.tags or {}, tags) return 200, {}, json.dumps({"item": tags}) if self.method == "DELETE": stage = self.backend.get_stage(function_id, stage_name) - for tag in stage.get("tags", {}).copy(): + for tag in (stage.tags or {}).copy(): if tag in (self.querystring.get("tagKeys") or {}): - stage["tags"].pop(tag, None) + stage.tags.pop(tag, None) # type: ignore[union-attr] return 200, {}, json.dumps({"item": ""}) def stages(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] @@ -416,13 +416,13 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": stage_response = self.backend.get_stage(function_id, stage_name) - return 200, {}, json.dumps(stage_response) + return 200, {}, json.dumps(stage_response.to_json()) elif self.method == "PATCH": patch_operations = self._get_param("patchOperations") stage_response = self.backend.update_stage( function_id, stage_name, patch_operations ) - return 200, {}, json.dumps(stage_response) + return 200, {}, json.dumps(stage_response.to_json()) elif self.method == "DELETE": self.backend.delete_stage(function_id, stage_name) return 202, {}, "{}" @@ -438,7 +438,9 @@ class APIGatewayResponse(BaseResponse): integration_response = self.backend.get_integration( function_id, resource_id, method_type ) - return 200, {}, json.dumps(integration_response) + if integration_response: + return 200, {}, json.dumps(integration_response.to_json()) + return 200, {}, "{}" elif self.method == "PUT": integration_type = self._get_param("type") uri = self._get_param("uri") @@ -470,12 +472,12 @@ class APIGatewayResponse(BaseResponse): timeout_in_millis=timeout_in_millis, request_parameters=request_parameters, ) - return 201, {}, json.dumps(integration_response) + return 201, {}, json.dumps(integration_response.to_json()) elif self.method == "DELETE": integration_response = self.backend.delete_integration( function_id, resource_id, method_type ) - return 204, {}, json.dumps(integration_response) + return 204, {}, json.dumps(integration_response.to_json()) def integration_responses(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -489,7 +491,7 @@ class APIGatewayResponse(BaseResponse): integration_response = self.backend.get_integration_response( function_id, resource_id, method_type, status_code ) - return 200, {}, json.dumps(integration_response) + return 200, {}, json.dumps(integration_response.to_json()) elif self.method == "PUT": if not self.body: raise InvalidRequestInput() @@ -506,12 +508,12 @@ class APIGatewayResponse(BaseResponse): response_templates, content_handling, ) - return 201, {}, json.dumps(integration_response) + return 201, {}, json.dumps(integration_response.to_json()) elif self.method == "DELETE": integration_response = self.backend.delete_integration_response( function_id, resource_id, method_type, status_code ) - return 204, {}, json.dumps(integration_response) + return 204, {}, json.dumps(integration_response.to_json()) def deployments(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -519,7 +521,7 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": deployments = self.backend.get_deployments(function_id) - return 200, {}, json.dumps({"item": deployments}) + return 200, {}, json.dumps({"item": [d.to_json() for d in deployments]}) elif self.method == "POST": name = self._get_param("stageName") description = self._get_param("description") @@ -527,7 +529,7 @@ class APIGatewayResponse(BaseResponse): deployment = self.backend.create_deployment( function_id, name, description, stage_variables ) - return 201, {}, json.dumps(deployment) + return 201, {}, json.dumps(deployment.to_json()) def individual_deployment(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -537,22 +539,26 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": deployment = self.backend.get_deployment(function_id, deployment_id) - return 200, {}, json.dumps(deployment) + return 200, {}, json.dumps(deployment.to_json()) elif self.method == "DELETE": deployment = self.backend.delete_deployment(function_id, deployment_id) - return 202, {}, json.dumps(deployment) + return 202, {}, json.dumps(deployment.to_json()) def apikeys(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": apikey_response = self.backend.create_api_key(json.loads(self.body)) - return 201, {}, json.dumps(apikey_response) + return 201, {}, json.dumps(apikey_response.to_json()) elif self.method == "GET": include_values = self._get_bool_param("includeValues") or False - apikeys_response = self.backend.get_api_keys(include_values=include_values) - return 200, {}, json.dumps({"item": apikeys_response}) + apikeys_response = self.backend.get_api_keys() + resp = [a.to_json() for a in apikeys_response] + if not include_values: + for key in resp: + key.pop("value") + return 200, {}, json.dumps({"item": resp}) def apikey_individual( self, request: Any, full_url: str, headers: Dict[str, str] @@ -564,27 +570,33 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": include_value = self._get_bool_param("includeValue") or False - apikey_response = self.backend.get_api_key( - apikey, include_value=include_value - ) + apikey_resp = self.backend.get_api_key(apikey).to_json() + if not include_value: + apikey_resp.pop("value") elif self.method == "PATCH": patch_operations = self._get_param("patchOperations") - apikey_response = self.backend.update_api_key(apikey, patch_operations) + apikey_resp = self.backend.update_api_key( + apikey, patch_operations + ).to_json() elif self.method == "DELETE": self.backend.delete_api_key(apikey) return 202, {}, "{}" - return 200, {}, json.dumps(apikey_response) + return 200, {}, json.dumps(apikey_resp) def usage_plans(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "POST": usage_plan_response = self.backend.create_usage_plan(json.loads(self.body)) - return 201, {}, json.dumps(usage_plan_response) + return 201, {}, json.dumps(usage_plan_response.to_json()) elif self.method == "GET": api_key_id = self.querystring.get("keyId", [None])[0] usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id) - return 200, {}, json.dumps({"item": usage_plans_response}) + return ( + 200, + {}, + json.dumps({"item": [u.to_json() for u in usage_plans_response]}), + ) def usage_plan_individual(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -594,7 +606,7 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": usage_plan_response = self.backend.get_usage_plan(usage_plan) - return 200, {}, json.dumps(usage_plan_response) + return 200, {}, json.dumps(usage_plan_response.to_json()) elif self.method == "DELETE": self.backend.delete_usage_plan(usage_plan) return 202, {}, "{}" @@ -603,7 +615,7 @@ class APIGatewayResponse(BaseResponse): usage_plan_response = self.backend.update_usage_plan( usage_plan, patch_operations ) - return 200, {}, json.dumps(usage_plan_response) + return 200, {}, json.dumps(usage_plan_response.to_json()) def usage_plan_keys(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -615,10 +627,14 @@ class APIGatewayResponse(BaseResponse): usage_plan_response = self.backend.create_usage_plan_key( usage_plan_id, json.loads(self.body) ) - return 201, {}, json.dumps(usage_plan_response) + return 201, {}, json.dumps(usage_plan_response.to_json()) elif self.method == "GET": usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id) - return 200, {}, json.dumps({"item": usage_plans_response}) + return ( + 200, + {}, + json.dumps({"item": [u.to_json() for u in usage_plans_response]}), + ) def usage_plan_key_individual(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -629,7 +645,7 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id) - return 200, {}, json.dumps(usage_plan_response) + return 200, {}, json.dumps(usage_plan_response.to_json()) elif self.method == "DELETE": self.backend.delete_usage_plan_key(usage_plan_id, key_id) return 202, {}, "{}" @@ -639,7 +655,7 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": domain_names = self.backend.get_domain_names() - return 200, {}, json.dumps({"item": domain_names}) + return 200, {}, json.dumps({"item": [d.to_json() for d in domain_names]}) elif self.method == "POST": domain_name = self._get_param("domainName") @@ -666,7 +682,7 @@ class APIGatewayResponse(BaseResponse): endpoint_configuration, security_policy, ) - return 201, {}, json.dumps(domain_name_resp) + return 201, {}, json.dumps(domain_name_resp.to_json()) def domain_name_induvidual( self, request: Any, full_url: str, headers: Dict[str, str] @@ -679,7 +695,7 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": if domain_name is not None: domain_names = self.backend.get_domain_name(domain_name) - return 200, {}, json.dumps(domain_names) + return 200, {}, json.dumps(domain_names.to_json()) return 200, {}, "{}" elif self.method == "DELETE": if domain_name is not None: @@ -695,7 +711,7 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": models = self.backend.get_models(rest_api_id) - return 200, {}, json.dumps({"item": models}) + return 200, {}, json.dumps({"item": [m.to_json() for m in models]}) elif self.method == "POST": name = self._get_param("name") @@ -709,7 +725,7 @@ class APIGatewayResponse(BaseResponse): description, schema, ) - return 201, {}, json.dumps(model) + return 201, {}, json.dumps(model.to_json()) def model_induvidual( self, request: Any, full_url: str, headers: Dict[str, str] @@ -721,7 +737,7 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": model_info = self.backend.get_model(rest_api_id, model_name) - return 200, {}, json.dumps(model_info) + return 200, {}, json.dumps(model_info.to_json()) return 200, {}, "{}" def base_path_mappings(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] @@ -732,7 +748,11 @@ class APIGatewayResponse(BaseResponse): if self.method == "GET": base_path_mappings = self.backend.get_base_path_mappings(domain_name) - return 200, {}, json.dumps({"item": base_path_mappings}) + return ( + 200, + {}, + json.dumps({"item": [m.to_json() for m in base_path_mappings]}), + ) elif self.method == "POST": base_path = self._get_param("basePath") rest_api_id = self._get_param("restApiId") @@ -741,7 +761,7 @@ class APIGatewayResponse(BaseResponse): base_path_mapping_resp = self.backend.create_base_path_mapping( domain_name, rest_api_id, base_path, stage ) - return 201, {}, json.dumps(base_path_mapping_resp) + return 201, {}, json.dumps(base_path_mapping_resp.to_json()) def base_path_mapping_individual(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] @@ -755,7 +775,7 @@ class APIGatewayResponse(BaseResponse): base_path_mapping = self.backend.get_base_path_mapping( domain_name, base_path ) - return 200, {}, json.dumps(base_path_mapping) + return 200, {}, json.dumps(base_path_mapping.to_json()) elif self.method == "DELETE": self.backend.delete_base_path_mapping(domain_name, base_path) return 202, {}, "" @@ -764,7 +784,7 @@ class APIGatewayResponse(BaseResponse): base_path_mapping = self.backend.update_base_path_mapping( domain_name, base_path, patch_operations ) - return 200, {}, json.dumps(base_path_mapping) + return 200, {}, json.dumps(base_path_mapping.to_json()) def vpc_link(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) @@ -776,14 +796,14 @@ class APIGatewayResponse(BaseResponse): return 202, {}, "{}" if self.method == "GET": vpc_link = self.backend.get_vpc_link(vpc_link_id=vpc_link_id) - return 200, {}, json.dumps(vpc_link) + return 200, {}, json.dumps(vpc_link.to_json()) def vpc_links(self, request: Any, full_url: str, headers: Dict[str, str]) -> TYPE_RESPONSE: # type: ignore[return] self.setup_class(request, full_url, headers) if self.method == "GET": vpc_links = self.backend.get_vpc_links() - return 200, {}, json.dumps({"item": vpc_links}) + return 200, {}, json.dumps({"item": [v.to_json() for v in vpc_links]}) if self.method == "POST": name = self._get_param("name") description = self._get_param("description") @@ -792,7 +812,7 @@ class APIGatewayResponse(BaseResponse): vpc_link = self.backend.create_vpc_link( name=name, description=description, target_arns=target_arns, tags=tags ) - return 202, {}, json.dumps(vpc_link) + return 202, {}, json.dumps(vpc_link.to_json()) def put_gateway_response(self) -> TYPE_RESPONSE: rest_api_id = self.path.split("/")[-3] @@ -808,7 +828,7 @@ class APIGatewayResponse(BaseResponse): response_parameters=response_parameters, response_templates=response_templates, ) - return 201, {}, json.dumps(response) + return 201, {}, json.dumps(response.to_json()) def get_gateway_response(self) -> TYPE_RESPONSE: rest_api_id = self.path.split("/")[-3] @@ -816,12 +836,12 @@ class APIGatewayResponse(BaseResponse): response = self.backend.get_gateway_response( rest_api_id=rest_api_id, response_type=response_type ) - return 200, {}, json.dumps(response) + return 200, {}, json.dumps(response.to_json()) def get_gateway_responses(self) -> TYPE_RESPONSE: rest_api_id = self.path.split("/")[-2] responses = self.backend.get_gateway_responses(rest_api_id=rest_api_id) - return 200, {}, json.dumps(dict(item=responses)) + return 200, {}, json.dumps(dict(item=[gw.to_json() for gw in responses])) def delete_gateway_response(self) -> TYPE_RESPONSE: rest_api_id = self.path.split("/")[-3] diff --git a/moto/wafv2/models.py b/moto/wafv2/models.py index 2432ac731..652621057 100644 --- a/moto/wafv2/models.py +++ b/moto/wafv2/models.py @@ -81,17 +81,17 @@ class WAFV2Backend(BaseBackend): raise WAFNonexistentItemException stage = self._find_apigw_stage(resource_arn) if stage: - stage["webAclArn"] = web_acl_arn + stage.web_acl_arn = web_acl_arn def disassociate_web_acl(self, resource_arn): stage = self._find_apigw_stage(resource_arn) if stage: - stage.pop("webAclArn", None) + stage.web_acl_arn = None def get_web_acl_for_resource(self, resource_arn): stage = self._find_apigw_stage(resource_arn) - if stage and stage.get("webAclArn"): - wacl_arn = stage.get("webAclArn") + if stage and stage.web_acl_arn is not None: + wacl_arn = stage.web_acl_arn return self.wacls.get(wacl_arn) return None