2461 lines
90 KiB
Python
2461 lines
90 KiB
Python
import re
|
|
import string
|
|
import time
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
from urllib.parse import urlparse
|
|
|
|
import requests
|
|
import responses
|
|
|
|
try:
|
|
# Recommended as of 0.7.x
|
|
from openapi_spec_validator import validate # type: ignore
|
|
except ImportError:
|
|
# Only used in < 0.7.x
|
|
# (Also exists in 0.7.0, but throws a warning)
|
|
from openapi_spec_validator import validate_spec as validate # type: ignore
|
|
from openapi_spec_validator.validation.exceptions import OpenAPIValidationError
|
|
|
|
from moto.apigateway.exceptions import MethodNotFoundException
|
|
from moto.core import BackendDict, BaseBackend, BaseModel, CloudFormationModel
|
|
from moto.core.utils import path_url
|
|
from moto.moto_api._internal import mock_random as random
|
|
|
|
from ..core.models import responses_mock
|
|
from .exceptions import (
|
|
ApiKeyAlreadyExists,
|
|
ApiKeyNotFoundException,
|
|
ApiKeyValueMinLength,
|
|
AuthorizerNotFoundException,
|
|
AwsProxyNotAllowed,
|
|
BadRequestException,
|
|
BasePathConflictException,
|
|
BasePathNotFoundException,
|
|
ConflictException,
|
|
CrossAccountNotAllowed,
|
|
DeploymentNotFoundException,
|
|
DomainNameNotFound,
|
|
GatewayResponseNotFound,
|
|
IntegrationMethodNotDefined,
|
|
InvalidArn,
|
|
InvalidBasePathException,
|
|
InvalidDomainName,
|
|
InvalidHttpEndpoint,
|
|
InvalidIntegrationArn,
|
|
InvalidModelName,
|
|
InvalidOpenAPIDocumentException,
|
|
InvalidOpenApiDocVersionException,
|
|
InvalidOpenApiModeException,
|
|
InvalidResourcePathException,
|
|
InvalidRestApiId,
|
|
InvalidRestApiIdForBasePathMappingException,
|
|
InvalidStageException,
|
|
ModelNotFound,
|
|
NoIntegrationDefined,
|
|
NoIntegrationResponseDefined,
|
|
NoMethodDefined,
|
|
RequestValidatorNotFound,
|
|
ResourceIdNotFoundException,
|
|
RestAPINotFound,
|
|
RoleNotSpecified,
|
|
StageNotFoundException,
|
|
StageStillActive,
|
|
UsagePlanNotFoundException,
|
|
ValidationException,
|
|
VpcLinkNotFound,
|
|
)
|
|
from .utils import create_id, to_path
|
|
|
|
# Invoke-URL template for a stage; formatted with the REST API id, the region
# name and the stage name.
STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
|
|
|
|
|
|
class Deployment(CloudFormationModel):
    """A deployment of a REST API, optionally associated with a stage name."""

    def __init__(self, deployment_id: str, name: str, description: str = ""):
        self.id = deployment_id
        self.stage_name = name
        self.description = description
        # Creation timestamp in whole epoch seconds.
        self.created_date = int(time.time())

    def to_json(self) -> Dict[str, Any]:
        """Serialize this deployment to its API Gateway JSON representation."""
        payload: Dict[str, Any] = dict(
            id=self.id,
            stageName=self.stage_name,
            description=self.description,
            createdDate=self.created_date,
        )
        return payload

    @staticmethod
    def cloudformation_name_type() -> str:
        return "Deployment"

    @staticmethod
    def cloudformation_type() -> str:
        return "AWS::ApiGateway::Deployment"

    @classmethod
    def create_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Dict[str, Any],
        account_id: str,
        region_name: str,
        **kwargs: Any,
    ) -> "Deployment":
        """Create a deployment from an ``AWS::ApiGateway::Deployment`` template resource.

        ``RestApiId`` is required; ``StageName`` and ``Description`` are optional
        (``Description`` defaults to an empty string).
        """
        props = cloudformation_json["Properties"]
        api_id = props["RestApiId"]
        stage = props.get("StageName")
        description = props.get("Description", "")
        backend: "APIGatewayBackend" = apigateway_backends[account_id][region_name]
        return backend.create_deployment(
            function_id=api_id, name=stage, description=description
        )
|
|
|
|
|
|
class IntegrationResponse(BaseModel):
    """Response mapping configured for one status code of an integration."""

    def __init__(
        self,
        status_code: Union[str, int],
        selection_pattern: Optional[str] = None,
        response_templates: Optional[Dict[str, Any]] = None,
        response_parameters: Optional[Dict[str, str]] = None,
        content_handling: Optional[Any] = None,
    ):
        """Store the integration-response settings.

        Args:
            status_code: HTTP status code this mapping applies to.
            selection_pattern: Optional regex used to select this response.
            response_templates: Content-type -> template mapping. Falsy template
                bodies (e.g. "") are normalised to None, which is required for
                compatibility with Terraform. A missing mapping becomes {} (a
                {"application/json": None} default was removed for TF
                compatibility).
            response_parameters: Optional response parameter mapping.
            content_handling: Optional content-handling strategy.
        """
        # Build a fresh dict rather than rewriting the caller's mapping in place,
        # so the normalisation has no side effect on the argument.
        self.response_templates: Dict[str, Any] = {
            content_type: template or None
            for content_type, template in (response_templates or {}).items()
        }
        self.status_code = status_code
        self.selection_pattern = selection_pattern
        self.response_parameters = response_parameters
        self.content_handling = content_handling

    def to_json(self) -> Dict[str, Any]:
        """Serialize to JSON; optional fields are emitted only when truthy."""
        resp: Dict[str, Any] = {
            "responseTemplates": self.response_templates,
            "statusCode": self.status_code,
        }
        if self.selection_pattern:
            resp["selectionPattern"] = self.selection_pattern
        if self.content_handling:
            resp["contentHandling"] = self.content_handling
        if self.response_parameters:
            resp["responseParameters"] = self.response_parameters
        return resp
|
|
|
|
|
|
class Integration(BaseModel):
    """Backend integration configured on a single API Gateway method."""

    def __init__(
        self,
        integration_type: str,
        uri: str,
        http_method: str,
        request_templates: Optional[Dict[str, Any]] = None,
        passthrough_behavior: Optional[str] = "WHEN_NO_MATCH",
        cache_key_parameters: Optional[List[str]] = None,
        tls_config: Optional[Dict[str, Any]] = None,
        cache_namespace: Optional[str] = None,
        timeout_in_millis: Optional[str] = None,
        request_parameters: Optional[Dict[str, Any]] = None,
        content_handling: Optional[str] = None,
        credentials: Optional[str] = None,
        connection_type: Optional[str] = None,
    ):
        self.integration_type = integration_type
        self.uri = uri
        # MOCK integrations do not record an HTTP method.
        self.http_method: Optional[str] = (
            None if integration_type == "MOCK" else http_method
        )
        self.passthrough_behaviour = passthrough_behavior
        self.cache_key_parameters: List[str] = (
            cache_key_parameters if cache_key_parameters else []
        )
        self.request_templates = request_templates
        self.tls_config = tls_config
        self.cache_namespace = cache_namespace
        self.timeout_in_millis = timeout_in_millis
        self.request_parameters = request_parameters
        self.content_handling = content_handling
        self.credentials = credentials
        self.connection_type = connection_type
        # Stays None until the first integration response is created.
        self.integration_responses: Optional[Dict[str, IntegrationResponse]] = None

    def to_json(self) -> Dict[str, Any]:
        """Serialize the integration; ``integrationResponses`` is None if none exist yet."""
        responses = self.integration_responses
        serialized_responses: Optional[Dict[str, Any]] = (
            None
            if responses is None
            else {code: item.to_json() for code, item in responses.items()}
        )
        return dict(
            type=self.integration_type,
            uri=self.uri,
            httpMethod=self.http_method,
            passthroughBehavior=self.passthrough_behaviour,
            cacheKeyParameters=self.cache_key_parameters,
            requestTemplates=self.request_templates,
            integrationResponses=serialized_responses,
            tlsConfig=self.tls_config,
            cacheNamespace=self.cache_namespace,
            timeoutInMillis=self.timeout_in_millis,
            requestParameters=self.request_parameters,
            contentHandling=self.content_handling,
            credentials=self.credentials,
            connectionType=self.connection_type,
        )

    def create_integration_response(
        self,
        status_code: str,
        selection_pattern: str,
        response_templates: Dict[str, str],
        response_parameters: Dict[str, str],
        content_handling: str,
    ) -> IntegrationResponse:
        """Create (or overwrite) the integration response for ``status_code``."""
        new_response = IntegrationResponse(
            status_code,
            selection_pattern,
            response_templates or None,
            response_parameters,
            content_handling,
        )
        registry = (
            self.integration_responses
            if self.integration_responses is not None
            else {}
        )
        registry[status_code] = new_response
        self.integration_responses = registry
        return new_response

    def get_integration_response(self, status_code: str) -> IntegrationResponse:
        """Return the response for ``status_code``.

        Raises:
            NoIntegrationResponseDefined: if no such response exists.
        """
        found = (
            self.integration_responses.get(status_code)
            if self.integration_responses
            else None
        )
        if found:
            return found
        raise NoIntegrationResponseDefined()

    def delete_integration_response(self, status_code: str) -> IntegrationResponse:
        """Remove and return the response for ``status_code`` (None when absent)."""
        if not self.integration_responses:
            return None  # type: ignore[return-value]
        return self.integration_responses.pop(status_code, None)  # type: ignore[return-value]
|
|
|
|
|
|
class MethodResponse(BaseModel):
    """A declared response (status code, models, parameters) of a method."""

    def __init__(
        self,
        status_code: str,
        response_models: Dict[str, str],
        response_parameters: Dict[str, Dict[str, str]],
    ):
        self.status_code, self.response_models, self.response_parameters = (
            status_code,
            response_models,
            response_parameters,
        )

    def to_json(self) -> Dict[str, Any]:
        """Serialize to the API Gateway ``MethodResponse`` JSON shape."""
        return dict(
            statusCode=self.status_code,
            responseModels=self.response_models,
            responseParameters=self.response_parameters,
        )
|
|
|
|
|
|
class Method(CloudFormationModel):
    """An HTTP method on a resource, with its integration and declared responses."""

    def __init__(
        self, method_type: str, authorization_type: Optional[str], **kwargs: Any
    ):
        """Create a method.

        Recognised keyword arguments: ``authorizer_id``, ``authorization_scopes``,
        ``api_key_required`` (defaults to False), ``request_parameters``,
        ``request_models``, ``operation_name``, ``request_validator_id``.
        """
        self.http_method = method_type
        self.authorization_type = authorization_type
        self.authorizer_id = kwargs.get("authorizer_id")
        self.authorization_scopes = kwargs.get("authorization_scopes")
        self.api_key_required = kwargs.get("api_key_required") or False
        self.request_parameters = kwargs.get("request_parameters")
        self.request_models = kwargs.get("request_models")
        self.method_integration: Optional[Integration] = None
        self.operation_name = kwargs.get("operation_name")
        self.request_validator_id = kwargs.get("request_validator_id")
        self.method_responses: Dict[str, MethodResponse] = {}

    def to_json(self) -> Dict[str, Any]:
        """Serialize the method, including its integration and responses."""
        return {
            "httpMethod": self.http_method,
            "authorizationType": self.authorization_type,
            "authorizerId": self.authorizer_id,
            "authorizationScopes": self.authorization_scopes,
            "apiKeyRequired": self.api_key_required,
            "requestParameters": self.request_parameters,
            "requestModels": self.request_models,
            "methodIntegration": self.method_integration.to_json()
            if self.method_integration
            else None,
            "operationName": self.operation_name,
            "requestValidatorId": self.request_validator_id,
            "methodResponses": {
                k: v.to_json() for k, v in self.method_responses.items()
            },
        }

    @staticmethod
    def cloudformation_name_type() -> str:
        return "Method"

    @staticmethod
    def cloudformation_type() -> str:
        return "AWS::ApiGateway::Method"

    @classmethod
    def create_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Dict[str, Any],
        account_id: str,
        region_name: str,
        **kwargs: Any,
    ) -> "Method":
        """Create a method (and optionally its integration) from a template resource.

        The ``Integration`` property is optional. When present, ``Type`` is
        required, while ``IntegrationHttpMethod`` and ``Uri`` may be omitted
        (e.g. for MOCK integrations). Missing values are passed as None and the
        backend's ``put_integration`` applies its own validation.
        """
        properties = cloudformation_json["Properties"]
        rest_api_id = properties["RestApiId"]
        resource_id = properties["ResourceId"]
        method_type = properties["HttpMethod"]
        auth_type = properties["AuthorizationType"]
        key_req = properties.get("ApiKeyRequired")
        backend = apigateway_backends[account_id][region_name]
        m = backend.put_method(
            function_id=rest_api_id,
            resource_id=resource_id,
            method_type=method_type,
            authorization_type=auth_type,
            api_key_required=key_req,
        )
        integration = properties.get("Integration")
        if integration:
            backend.put_integration(
                function_id=rest_api_id,
                resource_id=resource_id,
                method_type=method_type,
                integration_type=integration["Type"],
                uri=integration.get("Uri"),
                integration_method=integration.get("IntegrationHttpMethod"),
            )
        return m

    def create_response(
        self,
        response_code: str,
        response_models: Dict[str, str],
        response_parameters: Dict[str, Dict[str, str]],
    ) -> MethodResponse:
        """Create (or overwrite) the method response for ``response_code``."""
        method_response = MethodResponse(
            response_code, response_models, response_parameters
        )
        self.method_responses[response_code] = method_response
        return method_response

    def get_response(self, response_code: str) -> Optional[MethodResponse]:
        """Return the method response for ``response_code``, or None."""
        return self.method_responses.get(response_code)

    def delete_response(self, response_code: str) -> Optional[MethodResponse]:
        """Remove and return the method response for ``response_code``, or None."""
        return self.method_responses.pop(response_code, None)
|
|
|
|
|
|
class Resource(CloudFormationModel):
    """A node in a REST API's resource tree, owning the methods defined on it."""

    def __init__(
        self,
        resource_id: str,
        account_id: str,
        region_name: str,
        api_id: str,
        path_part: str,
        parent_id: Optional[str],
    ):
        # The integration parser modules are imported when a Resource is built.
        from .integration_parsers import IntegrationParser
        from .integration_parsers.aws_parser import TypeAwsParser
        from .integration_parsers.http_parser import TypeHttpParser
        from .integration_parsers.unknown_parser import TypeUnknownParser

        self.id = resource_id
        self.account_id = account_id
        self.region_name = region_name
        self.api_id = api_id
        self.path_part = path_part
        self.parent_id = parent_id
        self.resource_methods: Dict[str, Method] = {}

        # Integration types without a dedicated parser fall back to
        # TypeUnknownParser via the defaultdict factory.
        parsers: Dict[str, IntegrationParser] = defaultdict(TypeUnknownParser)
        parsers.update(HTTP=TypeHttpParser(), AWS=TypeAwsParser())
        self.integration_parsers: Dict[str, IntegrationParser] = parsers

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the resource; the root (no parent) omits parentId/pathPart."""
        out: Dict[str, Any] = dict(path=self.get_path(), id=self.id)
        if self.resource_methods:
            out["resourceMethods"] = {
                verb: method.to_json() for verb, method in self.resource_methods.items()
            }
        if self.parent_id:
            out["parentId"] = self.parent_id
            out["pathPart"] = self.path_part
        return out

    @property
    def physical_resource_id(self) -> str:
        return self.id

    @staticmethod
    def cloudformation_name_type() -> str:
        return "Resource"

    @staticmethod
    def cloudformation_type() -> str:
        return "AWS::ApiGateway::Resource"

    @classmethod
    def create_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Dict[str, Any],
        account_id: str,
        region_name: str,
        **kwargs: Any,
    ) -> "Resource":
        """Create a resource from an ``AWS::ApiGateway::Resource`` template resource."""
        props = cloudformation_json["Properties"]
        rest_api_id = props["RestApiId"]
        parent_resource_id = props["ParentId"]
        segment = props["PathPart"]

        backend = apigateway_backends[account_id][region_name]
        if parent_resource_id == rest_api_id:
            # A root resource ("/") is created automatically with the API; when the
            # ParentId equals the API id, attach the new resource to that root.
            roots = [
                node
                for node in backend.get_resources(function_id=rest_api_id)
                if node.path_part == "/"
            ]
            parent_resource_id = roots[0].id
        return backend.create_resource(
            function_id=rest_api_id,
            parent_resource_id=parent_resource_id,
            path_part=segment,
        )

    def get_path(self) -> str:
        """Return the full path of this resource (parent path + own segment)."""
        return self.get_parent_path() + self.path_part

    def get_parent_path(self) -> str:
        """Return the parent's path with a trailing "/", or "" for the root."""
        if not self.parent_id:
            return ""
        backend = apigateway_backends[self.account_id][self.region_name]
        prefix = backend.get_resource(self.api_id, self.parent_id).get_path()
        # The root's path is already "/", so no extra separator is appended.
        return prefix if prefix == "/" else prefix + "/"

    def get_response(
        self, request: requests.PreparedRequest
    ) -> Tuple[int, Union[str, bytes]]:
        """Dispatch ``request`` to the parser registered for the method's integration type."""
        integration = self.get_integration(str(request.method))
        parser = self.integration_parsers[integration.integration_type]  # type: ignore[union-attr]
        code, body = parser.invoke(request, integration)  # type: ignore[arg-type]
        return code, body

    def add_method(
        self,
        method_type: str,
        authorization_type: Optional[str],
        api_key_required: Optional[bool],
        request_parameters: Any = None,
        request_models: Any = None,
        operation_name: Optional[str] = None,
        authorizer_id: Optional[str] = None,
        authorization_scopes: Any = None,
        request_validator_id: Any = None,
    ) -> Method:
        """Create (or overwrite) the method ``method_type`` on this resource.

        A single, non-list ``authorization_scopes`` value is wrapped in a list.
        """
        scopes = authorization_scopes
        if scopes and not isinstance(scopes, list):
            scopes = [scopes]
        created = Method(
            method_type=method_type,
            authorization_type=authorization_type,
            api_key_required=api_key_required,
            request_parameters=request_parameters,
            request_models=request_models,
            operation_name=operation_name,
            authorizer_id=authorizer_id,
            authorization_scopes=scopes,
            request_validator_id=request_validator_id,
        )
        self.resource_methods[method_type] = created
        return created

    def get_method(self, method_type: str) -> Method:
        """Return the method ``method_type``.

        Raises:
            MethodNotFoundException: if the method is not defined.
        """
        found = self.resource_methods.get(method_type)
        if found:
            return found
        raise MethodNotFoundException()

    def delete_method(self, method_type: str) -> None:
        """Delete the method ``method_type`` if it exists (no error otherwise)."""
        if method_type in self.resource_methods:
            del self.resource_methods[method_type]

    def add_integration(
        self,
        method_type: str,
        integration_type: str,
        uri: str,
        request_templates: Optional[Dict[str, Any]] = None,
        passthrough_behavior: Optional[str] = None,
        integration_method: Optional[str] = None,
        tls_config: Optional[Dict[str, Any]] = None,
        cache_namespace: Optional[str] = None,
        timeout_in_millis: Optional[str] = None,
        request_parameters: Optional[Dict[str, Any]] = None,
        content_handling: Optional[str] = None,
        credentials: Optional[str] = None,
        connection_type: Optional[str] = None,
    ) -> Integration:
        """Attach a new integration to the existing method ``method_type``.

        ``integration_method`` defaults to ``method_type``. A missing method
        raises ``KeyError``.
        """
        new_integration = Integration(
            integration_type=integration_type,
            uri=uri,
            http_method=integration_method or method_type,
            request_templates=request_templates,
            passthrough_behavior=passthrough_behavior,
            tls_config=tls_config,
            cache_namespace=cache_namespace,
            timeout_in_millis=timeout_in_millis,
            request_parameters=request_parameters,
            content_handling=content_handling,
            credentials=credentials,
            connection_type=connection_type,
        )
        self.resource_methods[method_type].method_integration = new_integration
        return new_integration

    def get_integration(self, method_type: str) -> Optional[Integration]:
        """Return the integration of ``method_type``, or None if the method is absent."""
        if method_type not in self.resource_methods:
            return None
        return self.resource_methods[method_type].method_integration

    def delete_integration(self, method_type: str) -> Integration:
        """Detach and return the integration of ``method_type`` (KeyError if no method)."""
        target = self.resource_methods[method_type]
        detached = target.method_integration
        target.method_integration = None
        return detached  # type: ignore[return-value]
|
|
|
|
|
|
class Authorizer(BaseModel):
    """A REST API authorizer (e.g. Lambda or Cognito based)."""

    def __init__(
        self,
        authorizer_id: Optional[str],
        name: Optional[str],
        authorizer_type: Optional[str],
        **kwargs: Any,
    ):
        self.id = authorizer_id
        self.name = name
        self.type = authorizer_type
        for attribute in (
            "provider_arns",
            "auth_type",
            "authorizer_uri",
            "authorizer_credentials",
            "identity_source",
            "identity_validation_expression",
            "authorizer_result_ttl",
        ):
            setattr(self, attribute, kwargs.get(attribute))

    def to_json(self) -> Dict[str, Any]:
        """Serialize the authorizer; optional fields are emitted only when truthy."""
        dct: Dict[str, Any] = {
            "id": self.id,
            "name": self.name,
            "type": self.type,
            "authorizerResultTtlInSeconds": self.authorizer_result_ttl,
        }
        optional_fields = (
            ("providerARNs", self.provider_arns),
            ("authType", self.auth_type),
            ("authorizerUri", self.authorizer_uri),
            ("authorizerCredentials", self.authorizer_credentials),
            ("identitySource", self.identity_source),
            ("identityValidationExpression", self.identity_validation_expression),
        )
        dct.update({key: value for key, value in optional_fields if value})
        return dct

    def apply_operations(self, patch_operations: List[Dict[str, Any]]) -> "Authorizer":
        """Apply PATCH operations by path and return ``self``.

        The operation type is not inspected: each op's ``path`` is matched, in
        the order below, against path fragments, and the first fragment found
        selects the attribute that receives ``op["value"]``.

        Raises:
            Exception: for ``/providerARNs`` (not implemented) or any path that
                matches no known fragment.
        """
        # (path fragment, attribute to set or None if unsupported, coerce to int)
        dispatch = (
            ("/authorizerUri", "authorizer_uri", False),
            ("/authorizerCredentials", "authorizer_credentials", False),
            ("/authorizerResultTtlInSeconds", "authorizer_result_ttl", True),
            ("/authType", "auth_type", False),
            ("/identitySource", "identity_source", False),
            ("/identityValidationExpression", "identity_validation_expression", False),
            ("/name", "name", False),
            ("/providerARNs", None, False),
            ("/type", "type", False),
        )
        for op in patch_operations:
            path = op["path"]
            hit = next(
                (
                    (attribute, as_int)
                    for fragment, attribute, as_int in dispatch
                    if fragment in path
                ),
                None,
            )
            if hit is None:
                raise Exception(f'Patch operation "{op["op"]}" not implemented')
            attribute, as_int = hit
            if attribute is None:
                # TODO: add and remove
                raise Exception(f'Patch operation for "{op["path"]}" not implemented')
            setattr(self, attribute, int(op["value"]) if as_int else op["value"])
        return self
|
|
|
|
|
|
class Stage(BaseModel):
    """A stage of a REST API deployment, with PATCH support for its settings."""

    def __init__(
        self,
        name: Optional[str] = None,
        deployment_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        description: str = "",
        cacheClusterEnabled: Optional[bool] = False,
        cacheClusterSize: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        tracing_enabled: Optional[bool] = None,
    ):
        self.name = name
        self.deployment_id = deployment_id
        # Keyed by "<resource path>/<http method>" as built in apply_operations
        # (e.g. "*/*"); values are settings dicts (see _get_default_method_settings).
        self.method_settings: Dict[str, Any] = {}
        self.variables = variables or {}
        self.description = description
        self.cache_cluster_enabled = cacheClusterEnabled
        # Only set at construction time when the cache cluster is enabled.
        self.cache_cluster_status = "AVAILABLE" if cacheClusterEnabled else None
        self.cache_cluster_size = (
            str(cacheClusterSize) if cacheClusterSize is not None else None
        )
        self.tags = tags
        self.tracing_enabled = tracing_enabled
        self.access_log_settings: Optional[Dict[str, Any]] = None
        self.web_acl_arn: Optional[str] = None

    def to_json(self) -> Dict[str, Any]:
        """Serialize the stage.

        ``cacheClusterSize`` is only emitted while the cache cluster is enabled,
        defaulting to "0.5" when no size was set. Status, tags, tracing and the
        web ACL ARN are emitted only when set.
        """
        dct: Dict[str, Any] = {
            "stageName": self.name,
            "deploymentId": self.deployment_id,
            "methodSettings": self.method_settings,
            "variables": self.variables,
            "description": self.description,
            "cacheClusterEnabled": self.cache_cluster_enabled,
            "accessLogSettings": self.access_log_settings,
        }
        if self.cache_cluster_status is not None:
            dct["cacheClusterStatus"] = self.cache_cluster_status
        if self.cache_cluster_enabled:
            if self.cache_cluster_size is not None:
                dct["cacheClusterSize"] = self.cache_cluster_size
            else:
                dct["cacheClusterSize"] = "0.5"
        if self.tags:
            dct["tags"] = self.tags
        if self.tracing_enabled is not None:
            dct["tracingEnabled"] = self.tracing_enabled
        if self.web_acl_arn is not None:
            dct["webAclArn"] = self.web_acl_arn
        return dct

    def apply_operations(self, patch_operations: List[Dict[str, Any]]) -> "Stage":
        """Apply PATCH operations to this stage and return ``self``.

        Paths are first matched by substring, regardless of the op type, in this
        order: stage variables, cacheClusterEnabled, cacheClusterSize,
        description, deploymentId. Otherwise the op type decides:

        * ``replace``: ``/tracingEnabled``, ``/accessLogSettings/<key>``, or a
          method-setting path such as ``/*/*/logging/loglevel``. Paths that do
          not split into four "/"-separated parts are skipped.
        * ``remove``: only ``/accessLogSettings`` is handled; other paths are
          silently ignored.
        * any other op raises ``ValidationException``.

        The order of the checks matters because they are substring tests.
        """
        for op in patch_operations:
            if "variables/" in op["path"]:
                self._apply_operation_to_variables(op)
            elif "/cacheClusterEnabled" in op["path"]:
                self.cache_cluster_enabled = self._str2bool(op["value"])
                if self.cache_cluster_enabled:
                    self.cache_cluster_status = "AVAILABLE"
                else:
                    self.cache_cluster_status = "NOT_AVAILABLE"
            elif "/cacheClusterSize" in op["path"]:
                self.cache_cluster_size = str(op["value"])
            elif "/description" in op["path"]:
                self.description = op["value"]
            elif "/deploymentId" in op["path"]:
                self.deployment_id = op["value"]
            elif op["op"] == "replace":
                if op["path"] == "/tracingEnabled":
                    self.tracing_enabled = self._str2bool(op["value"])
                elif op["path"].startswith("/accessLogSettings/"):
                    # The last path segment names the access-log setting.
                    self.access_log_settings = self.access_log_settings or {}
                    self.access_log_settings[op["path"].split("/")[-1]] = op["value"]
                else:
                    # (e.g., path could be '/*/*/logging/loglevel')
                    # -> ['', '*', '*', 'logging/loglevel']: segments 1-2 form the
                    # resource/method key, the remainder names the setting.
                    split_path = op["path"].split("/", 3)
                    if len(split_path) != 4:
                        continue
                    self._patch_method_setting(
                        "/".join(split_path[1:3]), split_path[3], op["value"]
                    )
            elif op["op"] == "remove":
                if op["path"] == "/accessLogSettings":
                    self.access_log_settings = None
            else:
                raise ValidationException(
                    "Member must satisfy enum value set: [add, remove, move, test, replace, copy]"
                )
        return self

    def _patch_method_setting(
        self, resource_path_and_method: str, key: str, value: str
    ) -> None:
        """Set one method setting, seeding the entry with defaults on first use.

        Keys with no translation in ``_method_settings_translations`` are ignored.
        """
        updated_key = self._method_settings_translations(key)
        if updated_key is not None:
            if resource_path_and_method not in self.method_settings:
                self.method_settings[
                    resource_path_and_method
                ] = self._get_default_method_settings()
            self.method_settings[resource_path_and_method][
                updated_key
            ] = self._convert_to_type(updated_key, value)

    def _get_default_method_settings(self) -> Dict[str, Any]:
        """Return a fresh dict of default method settings."""
        return {
            "throttlingRateLimit": 1000.0,
            "dataTraceEnabled": False,
            "metricsEnabled": False,
            "unauthorizedCacheControlHeaderStrategy": "SUCCEED_WITH_RESPONSE_HEADER",
            "cacheTtlInSeconds": 300,
            "cacheDataEncrypted": True,
            "cachingEnabled": False,
            "throttlingBurstLimit": 2000,
            "requireAuthorizationForCacheControl": True,
        }

    def _method_settings_translations(self, key: str) -> Optional[str]:
        """Map a PATCH path suffix (e.g. "logging/loglevel") to a settings key, or None."""
        mappings = {
            "metrics/enabled": "metricsEnabled",
            "logging/loglevel": "loggingLevel",
            "logging/dataTrace": "dataTraceEnabled",
            "throttling/burstLimit": "throttlingBurstLimit",
            "throttling/rateLimit": "throttlingRateLimit",
            "caching/enabled": "cachingEnabled",
            "caching/ttlInSeconds": "cacheTtlInSeconds",
            "caching/dataEncrypted": "cacheDataEncrypted",
            "caching/requireAuthorizationForCacheControl": "requireAuthorizationForCacheControl",
            "caching/unauthorizedCacheControlHeaderStrategy": "unauthorizedCacheControlHeaderStrategy",
        }

        return mappings.get(key)

    def _str2bool(self, v: str) -> bool:
        """Return True only for a case-insensitive "true"."""
        return v.lower() == "true"

    def _convert_to_type(self, key: str, val: str) -> Union[str, int, float]:
        """Convert a PATCH string value to the type of settings key ``key``.

        Unknown keys are returned as ``str``.
        """
        type_mappings = {
            "metricsEnabled": "bool",
            "loggingLevel": "str",
            "dataTraceEnabled": "bool",
            "throttlingBurstLimit": "int",
            "throttlingRateLimit": "float",
            "cachingEnabled": "bool",
            "cacheTtlInSeconds": "int",
            "cacheDataEncrypted": "bool",
            "requireAuthorizationForCacheControl": "bool",
            "unauthorizedCacheControlHeaderStrategy": "str",
        }

        if key in type_mappings:
            type_value = type_mappings[key]

            if type_value == "bool":
                return self._str2bool(val)
            elif type_value == "int":
                return int(val)
            elif type_value == "float":
                return float(val)
            else:
                return str(val)
        else:
            return str(val)

    def _apply_operation_to_variables(self, op: Dict[str, Any]) -> None:
        """Apply a remove/replace op to the stage variable named after the last "variables/".

        Raises:
            Exception: for any op other than ``remove`` or ``replace``.
        """
        # len("variables/") == 10: the variable name is everything after it.
        key = op["path"][op["path"].rindex("variables/") + 10 :]
        if op["op"] == "remove":
            self.variables.pop(key, None)
        elif op["op"] == "replace":
            self.variables[key] = op["value"]
        else:
            raise Exception(f'Patch operation "{op["op"]}" not implemented')
|
|
|
|
|
|
class ApiKey(BaseModel):
    """An API Gateway API key."""

    def __init__(
        self,
        name: Optional[str] = None,
        description: Optional[str] = None,
        enabled: bool = False,
        generateDistinctId: bool = False,  # pylint: disable=unused-argument
        value: Optional[str] = None,
        stageKeys: Optional[Any] = None,
        tags: Optional[List[Dict[str, str]]] = None,
        customerId: Optional[str] = None,
    ):
        self.id = create_id()
        # Without an explicit value, a 40-character alphanumeric key is drawn
        # with random.sample (letters + digits).
        self.value = value or "".join(
            random.sample(string.ascii_letters + string.digits, 40)
        )
        self.name = name
        self.customer_id = customerId
        self.description = description
        self.enabled = enabled
        self.created_date = self.last_updated_date = int(time.time())
        self.stage_keys = stageKeys or []
        self.tags = tags

    def to_json(self) -> Dict[str, Any]:
        """Serialize the API key, including its value."""
        return {
            "id": self.id,
            "value": self.value,
            "name": self.name,
            "customerId": self.customer_id,
            "description": self.description,
            "enabled": self.enabled,
            "createdDate": self.created_date,
            "lastUpdatedDate": self.last_updated_date,
            "stageKeys": self.stage_keys,
            "tags": self.tags,
        }

    def update_operations(self, patch_operations: List[Dict[str, Any]]) -> "ApiKey":
        """Apply ``replace`` PATCH operations and return ``self``.

        Supported paths: name, customerId, description, enabled. Replace ops on
        other paths are ignored. When at least one operation was given,
        ``last_updated_date`` is refreshed so ``lastUpdatedDate`` reflects the
        update.

        Raises:
            Exception: for any op type other than ``replace``.
        """
        for op in patch_operations:
            if op["op"] == "replace":
                if "/name" in op["path"]:
                    self.name = op["value"]
                elif "/customerId" in op["path"]:
                    self.customer_id = op["value"]
                elif "/description" in op["path"]:
                    self.description = op["value"]
                elif "/enabled" in op["path"]:
                    self.enabled = self._str2bool(op["value"])
            else:
                raise Exception(f'Patch operation "{op["op"]}" not implemented')
        if patch_operations:
            self.last_updated_date = int(time.time())
        return self

    def _str2bool(self, v: str) -> bool:
        """Return True only for a case-insensitive "true"."""
        return v.lower() == "true"
|
|
|
|
|
|
class UsagePlan(BaseModel):
    """An API Gateway usage plan (API stages, throttle and quota settings)."""

    def __init__(
        self,
        name: Optional[str] = None,
        description: Optional[str] = None,
        apiStages: Any = None,
        throttle: Optional[Dict[str, Any]] = None,
        quota: Optional[Dict[str, Any]] = None,
        productCode: Optional[str] = None,
        tags: Optional[List[Dict[str, str]]] = None,
    ):
        self.id = create_id()
        self.name = name
        self.description = description
        self.api_stages = apiStages or []
        self.throttle = throttle or {}
        self.quota = quota or {}
        self.product_code = productCode
        self.tags = tags

    def to_json(self) -> Dict[str, Any]:
        """Serialize the usage plan."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "apiStages": self.api_stages,
            "throttle": self.throttle,
            "quota": self.quota,
            "productCode": self.product_code,
            "tags": self.tags,
        }

    def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None:
        """Apply ``replace`` PATCH operations in place; other op types are ignored.

        The ``value`` key is read only for ``replace`` operations, because
        ``remove`` operations carry no value and must not raise ``KeyError``.
        Values are stored as given (no type conversion).
        """
        for op in patch_operations:
            if op["op"] != "replace":
                continue
            path = op["path"]
            value = op["value"]
            if "/name" in path:
                self.name = value
            if "/productCode" in path:
                self.product_code = value
            if "/description" in path:
                self.description = value
            if "/quota/limit" in path:
                self.quota["limit"] = value
            if "/quota/period" in path:
                self.quota["period"] = value
            if "/throttle/rateLimit" in path:
                self.throttle["rateLimit"] = value
            if "/throttle/burstLimit" in path:
                self.throttle["burstLimit"] = value
|
|
|
|
|
|
class RequestValidator(BaseModel):
    """A request validator attached to a REST API."""

    PROP_ID = "id"
    PROP_NAME = "name"
    PROP_VALIDATE_REQUEST_BODY = "validateRequestBody"
    PROP_VALIDATE_REQUEST_PARAMETERS = "validateRequestParameters"

    # operations
    OP_PATH = "path"
    OP_VALUE = "value"
    OP_REPLACE = "replace"
    OP_OP = "op"

    def __init__(
        self,
        _id: str,
        name: str,
        validateRequestBody: Optional[bool],
        validateRequestParameters: Any,
    ):
        self.id = _id
        self.name = name
        self.validate_request_body = validateRequestBody
        self.validate_request_parameters = validateRequestParameters

    def apply_patch_operations(self, operations: List[Dict[str, Any]]) -> None:
        """Apply ``replace`` PATCH operations in place; other op types are ignored.

        The operation's value is read only for ``replace`` ops, so value-less
        ops such as ``remove`` do not raise ``KeyError``.
        """
        for operation in operations:
            if operation[RequestValidator.OP_OP] != RequestValidator.OP_REPLACE:
                continue
            path = operation[RequestValidator.OP_PATH]
            value = operation[RequestValidator.OP_VALUE]
            if to_path(RequestValidator.PROP_NAME) in path:
                self.name = value
            if to_path(RequestValidator.PROP_VALIDATE_REQUEST_BODY) in path:
                self.validate_request_body = self._to_bool(value)
            if to_path(RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS) in path:
                self.validate_request_parameters = self._to_bool(value)

    @staticmethod
    def _to_bool(value: Any) -> bool:
        """Return True only for a case-insensitive "true" (bools are accepted too).

        This is an exact comparison. A membership test against the bare string
        "true" would be a substring check and accept "", "t", "rue", and so on.
        """
        return str(value).lower() == "true"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the validator."""
        return {
            RequestValidator.PROP_ID: self.id,
            RequestValidator.PROP_NAME: self.name,
            RequestValidator.PROP_VALIDATE_REQUEST_BODY: self.validate_request_body,
            RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS: self.validate_request_parameters,
        }
|
|
|
|
|
|
class UsagePlanKey(BaseModel):
    """Association of a key (id, name, type, value) with a usage plan."""

    def __init__(self, plan_id: str, plan_type: str, name: Optional[str], value: str):
        self.id, self.name, self.type, self.value = plan_id, name, plan_type, value

    def to_json(self) -> Dict[str, Any]:
        """Serialize the usage-plan key."""
        return dict(id=self.id, name=self.name, type=self.type, value=self.value)
|
|
|
|
|
|
class VpcLink(BaseModel):
    """A VPC link; reported with status "AVAILABLE" as soon as it is created."""

    def __init__(
        self,
        name: str,
        description: str,
        target_arns: List[str],
        tags: List[Dict[str, str]],
    ):
        self.id = create_id()
        self.name, self.description = name, description
        self.target_arns, self.tags = target_arns, tags
        self.status = "AVAILABLE"

    def to_json(self) -> Dict[str, Any]:
        """Serialize the VPC link."""
        return dict(
            id=self.id,
            name=self.name,
            description=self.description,
            targetArns=self.target_arns,
            tags=self.tags,
            status=self.status,
        )
|
|
|
|
|
|
class RestAPI(CloudFormationModel):
    """In-memory model of an API Gateway REST API.

    Holds the API's resources, deployments, stages, authorizers, models,
    request validators and gateway responses. Creating a stage registers
    ``responses`` callback mocks for the stage's public URLs.
    """

    PROP_ID = "id"
    PROP_NAME = "name"
    PROP_DESCRIPTION = "description"
    PROP_VERSION = "version"
    PROP_BINARY_MEDIA_TYPES = "binaryMediaTypes"
    PROP_CREATED_DATE = "createdDate"
    PROP_API_KEY_SOURCE = "apiKeySource"
    PROP_ENDPOINT_CONFIGURATION = "endpointConfiguration"
    PROP_TAGS = "tags"
    PROP_POLICY = "policy"
    PROP_DISABLE_EXECUTE_API_ENDPOINT = "disableExecuteApiEndpoint"
    PROP_MINIMUM_COMPRESSION_SIZE = "minimumCompressionSize"

    # operations
    OPERATION_ADD = "add"
    OPERATION_REPLACE = "replace"
    OPERATION_REMOVE = "remove"
    OPERATION_PATH = "path"
    OPERATION_VALUE = "value"
    OPERATION_OP = "op"

    def __init__(
        self,
        api_id: str,
        account_id: str,
        region_name: str,
        name: str,
        description: str,
        **kwargs: Any,
    ):
        self.id = api_id
        self.account_id = account_id
        self.region_name = region_name
        self.name = name
        self.description = description
        self.version = kwargs.get(RestAPI.PROP_VERSION) or "V1"
        self.binaryMediaTypes = kwargs.get(RestAPI.PROP_BINARY_MEDIA_TYPES) or []
        self.create_date = int(time.time())
        self.api_key_source = kwargs.get("api_key_source") or "HEADER"
        self.policy = kwargs.get(RestAPI.PROP_POLICY) or None
        self.endpoint_configuration = kwargs.get("endpoint_configuration") or {
            "types": ["EDGE"]
        }
        self.tags = kwargs.get(RestAPI.PROP_TAGS) or {}
        self.disableExecuteApiEndpoint = (
            kwargs.get("disable_execute_api_endpoint") or False
        )
        self.minimum_compression_size = kwargs.get("minimum_compression_size")
        self.deployments: Dict[str, Deployment] = {}
        self.authorizers: Dict[str, Authorizer] = {}
        self.gateway_responses: Dict[str, GatewayResponse] = {}
        self.stages: Dict[str, Stage] = {}
        self.resources: Dict[str, Resource] = {}
        self.models: Dict[str, Model] = {}
        self.request_validators: Dict[str, RequestValidator] = {}
        self.default = self.add_child("/")  # Add default child

    def __repr__(self) -> str:
        return str(self.id)

    def to_dict(self) -> Dict[str, Any]:
        """Return the API representation of this REST API."""
        return {
            self.PROP_ID: self.id,
            self.PROP_NAME: self.name,
            self.PROP_DESCRIPTION: self.description,
            self.PROP_VERSION: self.version,
            self.PROP_BINARY_MEDIA_TYPES: self.binaryMediaTypes,
            self.PROP_CREATED_DATE: self.create_date,
            self.PROP_API_KEY_SOURCE: self.api_key_source,
            self.PROP_ENDPOINT_CONFIGURATION: self.endpoint_configuration,
            self.PROP_TAGS: self.tags,
            self.PROP_POLICY: self.policy,
            self.PROP_DISABLE_EXECUTE_API_ENDPOINT: self.disableExecuteApiEndpoint,
            self.PROP_MINIMUM_COMPRESSION_SIZE: self.minimum_compression_size,
        }

    def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None:
        """Apply UpdateRestApi patch operations in place.

        ``replace`` supports name, description, apiKeySource, binaryMediaTypes
        (replacing the list with the single value), disableExecuteApiEndpoint
        and policy. ``add`` appends to binaryMediaTypes. ``remove`` removes a
        binary media type or clears the description.
        """
        for op in patch_operations:
            path = op[self.OPERATION_PATH]
            value = op.get(self.OPERATION_VALUE, "")
            operation = op[self.OPERATION_OP]
            if operation == self.OPERATION_REPLACE:
                if to_path(self.PROP_NAME) in path:
                    self.name = value
                if to_path(self.PROP_DESCRIPTION) in path:
                    self.description = value
                if to_path(self.PROP_API_KEY_SOURCE) in path:
                    self.api_key_source = value
                if to_path(self.PROP_BINARY_MEDIA_TYPES) in path:
                    self.binaryMediaTypes = [value]
                if to_path(self.PROP_DISABLE_EXECUTE_API_ENDPOINT) in path:
                    # Patch values are strings: bool("false") would be True.
                    self.disableExecuteApiEndpoint = str(value).lower() == "true"
                if to_path(self.PROP_POLICY) in path:
                    self.policy = value
            elif operation == self.OPERATION_ADD:
                if to_path(self.PROP_BINARY_MEDIA_TYPES) in path:
                    self.binaryMediaTypes.append(value)
            elif operation == self.OPERATION_REMOVE:
                if to_path(self.PROP_BINARY_MEDIA_TYPES) in path:
                    self.binaryMediaTypes.remove(value)
                if to_path(self.PROP_DESCRIPTION) in path:
                    self.description = ""

    @classmethod
    def has_cfn_attr(cls, attr: str) -> bool:
        return attr in ["RootResourceId"]

    def get_cfn_attribute(self, attribute_name: str) -> Any:
        """Resolve ``Fn::GetAtt``; only ``RootResourceId`` is supported."""
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "RootResourceId":
            for res_id, res_obj in self.resources.items():
                # The root resource is the parentless resource with path part "/".
                if res_obj.path_part == "/" and not res_obj.parent_id:
                    return res_id
            raise Exception(f"Unable to find root resource for API {self}")
        raise UnformattedGetAttTemplateException()

    @property
    def physical_resource_id(self) -> str:
        return self.id

    @staticmethod
    def cloudformation_name_type() -> str:
        return "RestApi"

    @staticmethod
    def cloudformation_type() -> str:
        return "AWS::ApiGateway::RestApi"

    @classmethod
    def create_from_cloudformation_json(  # type: ignore[misc]
        cls,
        resource_name: str,
        cloudformation_json: Dict[str, Any],
        account_id: str,
        region_name: str,
        **kwargs: Any,
    ) -> "RestAPI":
        """Create the API from Name, Description and EndpointConfiguration."""
        properties = cloudformation_json["Properties"]
        name = properties["Name"]
        desc = properties.get("Description", "")
        config = properties.get("EndpointConfiguration", None)
        backend = apigateway_backends[account_id][region_name]
        return backend.create_rest_api(
            name=name, description=desc, endpoint_configuration=config
        )

    def add_child(self, path: str, parent_id: Optional[str] = None) -> Resource:
        """Create a resource with path part ``path`` under ``parent_id``."""
        child_id = create_id()
        child = Resource(
            resource_id=child_id,
            account_id=self.account_id,
            region_name=self.region_name,
            api_id=self.id,
            path_part=path,
            parent_id=parent_id,
        )
        self.resources[child_id] = child
        return child

    def add_model(
        self,
        name: str,
        description: str,
        schema: str,
        content_type: str,
    ) -> "Model":
        """Create a model and store it keyed by ``name`` (replacing any existing one)."""
        model_id = create_id()
        new_model = Model(
            model_id=model_id,
            name=name,
            description=description,
            schema=schema,
            content_type=content_type,
        )
        self.models[name] = new_model
        return new_model

    def get_resource_for_path(self, path_after_stage_name: str) -> Resource:  # type: ignore[return]
        """Return the resource whose full path equals the argument.

        Falls through and returns ``None`` when no resource matches; callers
        such as ``put_rest_api`` rely on that.
        """
        for resource in self.resources.values():
            if resource.get_path() == path_after_stage_name:
                return resource
        # TODO deal with no matching resource

    def resource_callback(
        self, request: Any
    ) -> Tuple[int, Dict[str, str], Union[str, bytes]]:
        """``responses`` callback: dispatch a stage-URL request to its resource."""
        path = path_url(request.url)
        # Drop the leading stage segment: "/dev/a/b" -> "/a/b".
        path_after_stage_name = "/" + "/".join(path.split("/")[2:])

        resource = self.get_resource_for_path(path_after_stage_name)
        status_code, response = resource.get_response(request)
        return status_code, {}, response

    def update_integration_mocks(self, stage_name: str) -> None:
        """Register callback mocks for every resource method on the stage URL.

        Mocks are added for both the lower- and upper-case API id.
        """
        stage_url_lower = STAGE_URL.format(
            api_id=self.id.lower(), region_name=self.region_name, stage_name=stage_name
        )
        stage_url_upper = STAGE_URL.format(
            api_id=self.id.upper(), region_name=self.region_name, stage_name=stage_name
        )

        for resource in self.resources.values():
            path = resource.get_path()
            path = "" if path == "/" else path

            for http_method in resource.resource_methods.keys():
                for url in [stage_url_lower, stage_url_upper]:
                    callback_response = responses.CallbackResponse(
                        url=url + path,
                        method=http_method,
                        callback=self.resource_callback,
                        content_type="text/plain",
                    )
                    responses_mock.add(callback_response)

    def create_authorizer(
        self,
        authorizer_id: str,
        name: str,
        authorizer_type: str,
        provider_arns: Optional[List[str]],
        auth_type: Optional[str],
        authorizer_uri: Optional[str],
        authorizer_credentials: Optional[str],
        identity_source: Optional[str],
        identiy_validation_expression: Optional[str],
        authorizer_result_ttl: Optional[int],
    ) -> Authorizer:
        """Create and store an authorizer under ``authorizer_id``."""
        authorizer = Authorizer(
            authorizer_id=authorizer_id,
            name=name,
            authorizer_type=authorizer_type,
            provider_arns=provider_arns,
            auth_type=auth_type,
            authorizer_uri=authorizer_uri,
            authorizer_credentials=authorizer_credentials,
            identity_source=identity_source,
            identiy_validation_expression=identiy_validation_expression,
            authorizer_result_ttl=authorizer_result_ttl,
        )
        self.authorizers[authorizer_id] = authorizer
        return authorizer

    def create_stage(
        self,
        name: str,
        deployment_id: str,
        variables: Any,
        description: str,
        cacheClusterEnabled: Optional[bool],
        cacheClusterSize: Optional[str],
        tags: Optional[Dict[str, str]],
        tracing_enabled: Optional[bool],
    ) -> Stage:
        """Create a stage and register its URL mocks.

        Raises:
            ConflictException: if a stage with ``name`` already exists.
        """
        if name in self.stages:
            raise ConflictException("Stage already exists")
        if variables is None:
            variables = {}
        stage = Stage(
            name=name,
            deployment_id=deployment_id,
            variables=variables,
            description=description,
            cacheClusterSize=cacheClusterSize,
            cacheClusterEnabled=cacheClusterEnabled,
            tags=tags,
            tracing_enabled=tracing_enabled,
        )
        self.stages[name] = stage
        self.update_integration_mocks(name)
        return stage

    def create_deployment(
        self, name: str, description: str, stage_variables: Any = None
    ) -> Deployment:
        """Create a deployment.

        When ``name`` is non-empty, a stage of that name is (re)created to
        point at the new deployment and its URL mocks are registered.
        """
        if stage_variables is None:
            stage_variables = {}
        deployment_id = create_id()
        deployment = Deployment(deployment_id, name, description)
        self.deployments[deployment_id] = deployment
        if name:
            self.stages[name] = Stage(
                name=name, deployment_id=deployment_id, variables=stage_variables
            )
            self.update_integration_mocks(name)

        return deployment

    def get_deployment(self, deployment_id: str) -> Deployment:
        """Return a deployment by id.

        Raises:
            DeploymentNotFoundException: if the id is unknown.
        """
        if deployment_id not in self.deployments:
            raise DeploymentNotFoundException()
        return self.deployments[deployment_id]

    def get_authorizers(self) -> List[Authorizer]:
        return list(self.authorizers.values())

    def get_stages(self) -> List[Stage]:
        return list(self.stages.values())

    def get_deployments(self) -> List[Deployment]:
        return list(self.deployments.values())

    def delete_deployment(self, deployment_id: str) -> Deployment:
        """Remove and return a deployment.

        Raises:
            DeploymentNotFoundException: if the id is unknown.
            StageStillActive: if the deployment's stage still exists.
        """
        if deployment_id not in self.deployments:
            raise DeploymentNotFoundException()
        deployment = self.deployments[deployment_id]
        if deployment.stage_name and deployment.stage_name in self.stages:
            # Stage is still active
            raise StageStillActive()

        return self.deployments.pop(deployment_id)

    def create_request_validator(
        self,
        name: str,
        validateRequestBody: Optional[bool],
        validateRequestParameters: Any,
    ) -> RequestValidator:
        """Create and store a request validator with a fresh id."""
        validator_id = create_id()
        request_validator = RequestValidator(
            _id=validator_id,
            name=name,
            validateRequestBody=validateRequestBody,
            validateRequestParameters=validateRequestParameters,
        )
        self.request_validators[validator_id] = request_validator
        return request_validator

    def get_request_validators(self) -> List[RequestValidator]:
        return list(self.request_validators.values())

    def get_request_validator(self, validator_id: str) -> RequestValidator:
        """Return a request validator.

        Raises:
            RequestValidatorNotFound: if the id is unknown.
        """
        request_validator = self.request_validators.get(validator_id)
        if request_validator is None:
            raise RequestValidatorNotFound()
        return request_validator

    def delete_request_validator(self, validator_id: str) -> RequestValidator:
        """Remove and return a request validator.

        Raises:
            RequestValidatorNotFound: if the id is unknown (instead of KeyError).
        """
        if validator_id not in self.request_validators:
            raise RequestValidatorNotFound()
        return self.request_validators.pop(validator_id)

    def update_request_validator(
        self, validator_id: str, patch_operations: List[Dict[str, Any]]
    ) -> RequestValidator:
        """Patch a request validator in place and return it.

        Raises:
            RequestValidatorNotFound: if the id is unknown (instead of KeyError).
        """
        request_validator = self.get_request_validator(validator_id)
        request_validator.apply_patch_operations(patch_operations)
        return request_validator

    def put_gateway_response(
        self,
        response_type: str,
        status_code: int,
        response_parameters: Dict[str, Any],
        response_templates: Dict[str, str],
    ) -> "GatewayResponse":
        """Create or replace the gateway response for ``response_type``."""
        response = GatewayResponse(
            response_type=response_type,
            status_code=status_code,
            response_parameters=response_parameters,
            response_templates=response_templates,
        )
        self.gateway_responses[response_type] = response
        return response

    def get_gateway_response(self, response_type: str) -> "GatewayResponse":
        """Return a gateway response.

        Raises:
            GatewayResponseNotFound: if none exists for ``response_type``.
        """
        if response_type not in self.gateway_responses:
            raise GatewayResponseNotFound()
        return self.gateway_responses[response_type]

    def get_gateway_responses(self) -> List["GatewayResponse"]:
        return list(self.gateway_responses.values())

    def delete_gateway_response(self, response_type: str) -> None:
        """Delete a gateway response; unknown types are silently ignored."""
        self.gateway_responses.pop(response_type, None)
|
|
|
|
|
|
class DomainName(BaseModel):
    """A custom domain name whose generated endpoints are immediately AVAILABLE."""

    def __init__(self, domain_name: str, **kwargs: Any):
        self.domain_name = domain_name
        region = kwargs.get("region_name") or "us-east-1"
        # The regional name is generated before the CloudFront one (order of create_id calls).
        self.regional_domain_name = (
            f"d-{create_id()}.execute-api.{region}.amazonaws.com"
        )
        self.distribution_domain_name = f"d{create_id()}.cloudfront.net"
        self.domain_name_status = "AVAILABLE"
        self.status_message = "Domain Name Available"
        self.regional_hosted_zone_id = "Z2FDTNDATAQYW2"
        self.distribution_hosted_zone_id = "Z2FDTNDATAQYW2"
        self.certificate_upload_date = int(time.time())
        self.certificate_name = kwargs.get("certificate_name")
        self.certificate_arn = kwargs.get("certificate_arn")
        self.certificate_body = kwargs.get("certificate_body")
        self.tags = kwargs.get("tags")
        self.security_policy = kwargs.get("security_policy")
        self.certificate_chain = kwargs.get("certificate_chain")
        self.regional_certificate_name = kwargs.get("regional_certificate_name")
        self.certificate_private_key = kwargs.get("certificate_private_key")
        self.regional_certificate_arn = kwargs.get("regional_certificate_arn")
        self.endpoint_configuration = kwargs.get("endpoint_configuration")

    def to_json(self) -> Dict[str, Any]:
        """Return the API representation; falsy optional fields are omitted."""
        dct: Dict[str, Any] = {
            "domainName": self.domain_name,
            "regionalDomainName": self.regional_domain_name,
            "distributionDomainName": self.distribution_domain_name,
            "domainNameStatus": self.domain_name_status,
            "domainNameStatusMessage": self.status_message,
            "regionalHostedZoneId": self.regional_hosted_zone_id,
            "distributionHostedZoneId": self.distribution_hosted_zone_id,
            "certificateUploadDate": self.certificate_upload_date,
        }
        optional_fields = (
            ("certificateName", self.certificate_name),
            ("certificateArn", self.certificate_arn),
            ("certificateBody", self.certificate_body),
            ("tags", self.tags),
            ("securityPolicy", self.security_policy),
            ("certificateChain", self.certificate_chain),
            ("regionalCertificateName", self.regional_certificate_name),
            ("certificatePrivateKey", self.certificate_private_key),
            ("regionalCertificateArn", self.regional_certificate_arn),
            ("endpointConfiguration", self.endpoint_configuration),
        )
        dct.update((key, val) for key, val in optional_fields if val)
        return dct
|
|
|
|
|
|
class Model(BaseModel):
    """A named data model (schema) attached to a REST API."""

    def __init__(self, model_id: str, name: str, **kwargs: Any):
        self.id = model_id
        self.name = name
        self.description = kwargs.get("description")
        self.schema = kwargs.get("schema")
        self.content_type = kwargs.get("content_type")

    def to_json(self) -> Dict[str, Any]:
        """Return the API representation; falsy optional fields are omitted."""
        dct: Dict[str, Any] = {"id": self.id, "name": self.name}
        for key, val in (
            ("description", self.description),
            ("schema", self.schema),
            ("contentType", self.content_type),
        ):
            if val:
                dct[key] = val
        return dct
|
|
|
|
|
|
class BasePathMapping(BaseModel):
    """Maps a base path of a custom domain name to a REST API (and optional stage)."""

    # operations
    OPERATION_REPLACE = "replace"
    OPERATION_PATH = "path"
    OPERATION_VALUE = "value"
    OPERATION_OP = "op"

    def __init__(self, domain_name: str, rest_api_id: str, **kwargs: Any):
        self.domain_name = domain_name
        self.rest_api_id = rest_api_id
        # An empty/missing base path is reported as the literal "(none)".
        self.base_path = kwargs.get("basePath") or "(none)"
        self.stage = kwargs.get("stage")

    def to_json(self) -> Dict[str, Any]:
        """Return the API representation; ``stage`` is included only when set."""
        dct: Dict[str, Any] = dict(
            [
                ("domain_name", self.domain_name),
                ("restApiId", self.rest_api_id),
                ("basePath", self.base_path),
            ]
        )
        if self.stage is None:
            return dct
        dct["stage"] = self.stage
        return dct

    def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None:
        """Apply ``replace`` operations on /basePath, /restapiId and /stage.

        Every operation must carry ``path``, ``value`` and ``op`` (KeyError otherwise).
        """
        for patch in patch_operations:
            target = patch[self.OPERATION_PATH]
            new_value = patch[self.OPERATION_VALUE]
            if patch[self.OPERATION_OP] != self.OPERATION_REPLACE:
                continue
            if "/basePath" in target:
                self.base_path = new_value
            if "/restapiId" in target:
                self.rest_api_id = new_value
            if "/stage" in target:
                self.stage = new_value
|
|
|
|
|
|
class GatewayResponse(BaseModel):
    """A customised gateway response for one response type of a REST API."""

    def __init__(
        self,
        response_type: str,
        status_code: int,
        response_parameters: Dict[str, Any],
        response_templates: Dict[str, str],
    ):
        self.response_type = response_type
        # Responses stored here were explicitly put, so they are never defaults.
        self.default_response = False
        self.status_code = status_code
        self.response_parameters = response_parameters
        self.response_templates = response_templates

    def to_json(self) -> Dict[str, Any]:
        """Return the API representation; ``None`` fields are omitted."""
        dct: Dict[str, Any] = {
            "responseType": self.response_type,
            "defaultResponse": self.default_response,
        }
        candidates = {
            "statusCode": self.status_code,
            "responseParameters": self.response_parameters,
            "responseTemplates": self.response_templates,
        }
        dct.update({key: val for key, val in candidates.items() if val is not None})
        return dct
|
|
|
|
|
|
class APIGatewayBackend(BaseBackend):
|
|
"""
|
|
API Gateway mock.
|
|
|
|
The public URLs of an API integration are mocked as well, i.e. the following would be supported in Moto:
|
|
|
|
.. sourcecode:: python
|
|
|
|
client.put_integration(
|
|
restApiId=api_id,
|
|
...,
|
|
uri="http://httpbin.org/robots.txt",
|
|
integrationHttpMethod="GET"
|
|
)
|
|
deploy_url = f"https://{api_id}.execute-api.us-east-1.amazonaws.com/dev"
|
|
assert requests.get(deploy_url).content == b"a fake response"
|
|
|
|
Limitations:
|
|
- Integrations of type HTTP are supported
|
|
- Integrations of type AWS with service DynamoDB are supported
|
|
- Other types (AWS_PROXY, MOCK, etc) are ignored
|
|
- Other services are not yet supported
|
|
- The BasePath of an API is ignored
|
|
- TemplateMapping is not yet supported for requests/responses
|
|
- This only works when using the decorators, not in ServerMode
|
|
"""
|
|
|
|
def __init__(self, region_name: str, account_id: str):
    """Initialise empty per-account, per-region API Gateway state."""
    super().__init__(region_name, account_id)
    # REST APIs keyed by API id.
    self.apis: Dict[str, RestAPI] = {}
    # API keys and usage plans keyed by their ids.
    self.keys: Dict[str, ApiKey] = {}
    self.usage_plans: Dict[str, UsagePlan] = {}
    # usage plan id -> (key id -> UsagePlanKey)
    self.usage_plan_keys: Dict[str, Dict[str, UsagePlanKey]] = {}
    self.domain_names: Dict[str, DomainName] = {}
    self.models: Dict[str, Model] = {}
    # domain name -> (base path -> BasePathMapping)
    self.base_path_mappings: Dict[str, Dict[str, BasePathMapping]] = {}
    self.vpc_links: Dict[str, VpcLink] = {}
|
|
|
|
def create_rest_api(
    self,
    name: str,
    description: str,
    api_key_source: Optional[str] = None,
    endpoint_configuration: Optional[str] = None,
    tags: Optional[List[Dict[str, str]]] = None,
    policy: Optional[str] = None,
    minimum_compression_size: Optional[int] = None,
    disable_execute_api_endpoint: Optional[bool] = None,
) -> RestAPI:
    """Create a REST API with a fresh id, register it and return it."""
    new_id = create_id()
    options: Dict[str, Any] = dict(
        api_key_source=api_key_source,
        endpoint_configuration=endpoint_configuration,
        tags=tags,
        policy=policy,
        minimum_compression_size=minimum_compression_size,
        disable_execute_api_endpoint=disable_execute_api_endpoint,
    )
    self.apis[new_id] = api = RestAPI(
        new_id, self.account_id, self.region_name, name, description, **options
    )
    return api
|
|
|
|
def import_rest_api(
    self, api_doc: Dict[str, Any], fail_on_warnings: bool
) -> RestAPI:
    """
    Only a subset of the OpenAPI spec 3.x is currently implemented.

    Creates a new REST API named after ``info.title`` and populates it from
    the document. ``info.description`` is optional in OpenAPI, so it
    defaults to an empty string instead of raising KeyError.
    """
    if fail_on_warnings:
        try:
            validate(api_doc)  # type: ignore[arg-type]
        except OpenAPIValidationError as e:
            raise InvalidOpenAPIDocumentException(e)
        except AttributeError:
            # Call can fail in Python3.7 due to `typing_extensions 4.6.0` throwing an error
            # Easiest to just ignore this for now - Py3.7 is EOL soon anyway
            pass

    info = api_doc["info"]
    name = info["title"]
    description = info.get("description", "")
    api = self.create_rest_api(name=name, description=description)
    self.put_rest_api(api.id, api_doc, fail_on_warnings=fail_on_warnings)
    return api
|
|
|
|
def export_api(self, rest_api_id: str, export_type: str) -> Dict[str, Any]:
    """
    Not all fields are implemented yet.
    The export-type is currently ignored - we will only return the 'swagger'-format
    """
    try:
        api = self.get_rest_api(rest_api_id)
    except RestAPINotFound:
        raise StageNotFoundException
    if export_type not in ["swagger", "oas30"]:
        raise BadRequestException(f"No API exporter for type '{export_type}'")
    # %M is minutes; %m (used previously) is the month number.
    now = datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    resp: Dict[str, Any] = {
        "swagger": "2.0",
        "info": {"version": now, "title": api.name},
        "host": f"{api.id}.execute-api.{self.region_name}.amazonaws.com",
        "basePath": "/",
        "schemes": ["https"],
        "paths": {},
        "definitions": {"Empty": {"type": "object", "title": "Empty Schema"}},
    }
    for res in api.resources.values():
        path = res.get_path()
        resp["paths"][path] = {}
        for method_type, method in res.resource_methods.items():
            resp["paths"][path][method_type] = {
                "produces": ["application/json"],
                "responses": {},
            }
            # Every declared method response is exported with the empty schema.
            for code in method.method_responses:
                resp["paths"][path][method_type]["responses"][code] = {
                    "description": f"{code} response",
                    "schema": {"$ref": "#/definitions/Empty"},
                }
    return resp
|
|
|
|
def get_rest_api(self, function_id: str) -> RestAPI:
    """Return the REST API with this id; raise RestAPINotFound otherwise."""
    found = self.apis.get(function_id)
    if found is not None:
        return found
    raise RestAPINotFound()
|
|
|
|
def put_rest_api(
    self,
    function_id: str,
    api_doc: Dict[str, Any],
    fail_on_warnings: bool,
    mode: str = "merge",
) -> RestAPI:
    """
    Only a subset of the OpenAPI spec 3.x is currently implemented.

    ``mode`` is "merge" or "overwrite". Overwrite first discards every
    existing resource and recreates the root. For each path, missing ancestor
    resources are created before the resource itself. Methods without an
    ``x-amazon-apigateway-integration`` are added together with their
    response codes.

    Raises:
        InvalidOpenApiModeException: for an unknown ``mode``.
        InvalidOpenApiDocVersionException: for Swagger 2 / non-3.x documents.
        InvalidOpenAPIDocumentException: if validation fails and
            ``fail_on_warnings`` is set.
        RestAPINotFound: if ``function_id`` is unknown (in both modes).
    """
    if mode not in ["merge", "overwrite"]:
        raise InvalidOpenApiModeException()

    if api_doc.get("swagger") is not None or (
        api_doc.get("openapi") is not None and api_doc["openapi"][0] != "3"
    ):
        raise InvalidOpenApiDocVersionException()

    if fail_on_warnings:
        try:
            validate(api_doc)  # type: ignore[arg-type]
        except OpenAPIValidationError as e:
            raise InvalidOpenAPIDocumentException(e)
        except AttributeError:
            # Call can fail in Python3.7 due to `typing_extensions 4.6.0` throwing an error
            # Easiest to just ignore this for now - Py3.7 is EOL soon anyway
            pass

    # Single lookup: unknown ids raise RestAPINotFound rather than KeyError.
    api = self.get_rest_api(function_id)

    if mode == "overwrite":
        api.resources = {}
        api.default = api.add_child("/")  # Add default child

    # "paths" is optional in OpenAPI 3.1; treat a missing key as no paths.
    for path, resource_doc in sorted(
        api_doc.get("paths", {}).items(), key=lambda x: x[0]
    ):
        # We may want to create a path like /store/inventory
        # Ensure that /store exists first, so we can use it as a parent
        # skip first (empty), skip last (child) - only process ancestors
        ancestors = path.split("/")[1:-1]
        direct_parent = ""
        parent_id = api.get_resource_for_path("/").id
        for a in ancestors:
            res = api.get_resource_for_path(direct_parent + "/" + a)
            if res is None:
                res = self.create_resource(
                    function_id=function_id,
                    parent_resource_id=parent_id,
                    path_part=a,
                )
            parent_id = res.id
            direct_parent = direct_parent + "/" + a

        # Now that we know all ancestors are created, create the resource itself
        parent_path_part = path[0 : path.rfind("/")] or "/"
        parent_resource_id = api.get_resource_for_path(parent_path_part).id
        resource = self.create_resource(
            function_id=function_id,
            parent_resource_id=parent_resource_id,
            path_part=path[path.rfind("/") + 1 :],
        )

        for method_type, method_doc in resource_doc.items():
            method_type = method_type.upper()
            if method_doc.get("x-amazon-apigateway-integration") is None:
                self.put_method(function_id, resource.id, method_type, None)
                method_responses = method_doc.get("responses", {}).items()
                for response_code, _ in method_responses:
                    self.put_method_response(
                        function_id,
                        resource.id,
                        method_type,
                        response_code,
                        response_models=None,
                        response_parameters=None,
                    )

    return self.get_rest_api(function_id)
|
|
|
|
def update_rest_api(
    self, function_id: str, patch_operations: List[Dict[str, Any]]
) -> RestAPI:
    """Apply patch operations to a REST API and return it.

    Raises RestAPINotFound for an unknown id.
    """
    target = self.get_rest_api(function_id)
    target.apply_patch_operations(patch_operations)
    return target
|
|
|
|
def list_apis(self) -> List[RestAPI]:
    """Return every REST API in this backend, in insertion order."""
    return [*self.apis.values()]
|
|
|
|
def delete_rest_api(self, function_id: str) -> RestAPI:
    """Remove and return a REST API.

    Raises:
        RestAPINotFound: if the id is unknown (previously a bare KeyError).
    """
    self.get_rest_api(function_id)
    return self.apis.pop(function_id)
|
|
|
|
def get_resources(self, function_id: str) -> List[Resource]:
    """Return all resources of a REST API (RestAPINotFound if unknown)."""
    return list(self.get_rest_api(function_id).resources.values())
|
|
|
|
def get_resource(self, function_id: str, resource_id: str) -> Resource:
    """Return one resource of a REST API.

    Raises RestAPINotFound or ResourceIdNotFoundException.
    """
    resources = self.get_rest_api(function_id).resources
    if resource_id in resources:
        return resources[resource_id]
    raise ResourceIdNotFoundException
|
|
|
|
def create_resource(
    self, function_id: str, parent_resource_id: str, path_part: str
) -> Resource:
    """Create a child resource under ``parent_resource_id``.

    An empty ``path_part`` returns the existing root resource. Otherwise the
    path part must look like ``name``, ``{param}`` or ``{param+}``.

    Raises:
        InvalidResourcePathException: if ``path_part`` has invalid characters.
    """
    api = self.get_rest_api(function_id)
    if path_part:
        if re.match(r"^\{?[a-zA-Z0-9._\-\:]+\+?\}?$", path_part) is None:
            raise InvalidResourcePathException()
        return api.add_child(path=path_part, parent_id=parent_resource_id)
    # We're attempting to create the default resource, which already exists.
    return api.default
|
|
|
|
def delete_resource(self, function_id: str, resource_id: str) -> Resource:
    """Remove and return a resource.

    Raises:
        RestAPINotFound: if the API id is unknown.
        ResourceIdNotFoundException: if the resource id is unknown
            (consistent with ``get_resource``, instead of a bare KeyError).
    """
    api = self.get_rest_api(function_id)
    if resource_id not in api.resources:
        raise ResourceIdNotFoundException
    return api.resources.pop(resource_id)
|
|
|
|
def get_method(
    self, function_id: str, resource_id: str, method_type: str
) -> Method:
    """Return the ``method_type`` method defined on a resource."""
    return self.get_resource(function_id, resource_id).get_method(method_type)
|
|
|
|
def put_method(
    self,
    function_id: str,
    resource_id: str,
    method_type: str,
    authorization_type: Optional[str],
    api_key_required: Optional[bool] = None,
    request_parameters: Optional[Dict[str, Any]] = None,
    request_models: Optional[Dict[str, Any]] = None,
    operation_name: Optional[str] = None,
    authorizer_id: Optional[str] = None,
    authorization_scopes: Optional[str] = None,
    request_validator_id: Optional[str] = None,
) -> Method:
    """Add (or replace) a method on a resource and return it."""
    target_resource = self.get_resource(function_id, resource_id)
    return target_resource.add_method(
        method_type,
        authorization_type,
        api_key_required=api_key_required,
        request_parameters=request_parameters,
        request_models=request_models,
        operation_name=operation_name,
        authorizer_id=authorizer_id,
        authorization_scopes=authorization_scopes,
        request_validator_id=request_validator_id,
    )
|
|
|
|
def delete_method(
    self, function_id: str, resource_id: str, method_type: str
) -> None:
    """Delete the ``method_type`` method from a resource."""
    self.get_resource(function_id, resource_id).delete_method(method_type)
|
|
|
|
def get_authorizer(self, restapi_id: str, authorizer_id: str) -> Authorizer:
    """Return an authorizer; raise AuthorizerNotFoundException if unknown."""
    found = self.get_rest_api(restapi_id).authorizers.get(authorizer_id)
    if found is not None:
        return found
    raise AuthorizerNotFoundException()
|
|
|
|
def get_authorizers(self, restapi_id: str) -> List[Authorizer]:
    """Return all authorizers of a REST API."""
    return self.get_rest_api(restapi_id).get_authorizers()
|
|
|
|
def create_authorizer(
    self, restapi_id: str, name: str, authorizer_type: str, **kwargs: Any
) -> Authorizer:
    """Create an authorizer with a fresh id; unspecified options default to None."""
    api = self.get_rest_api(restapi_id)
    new_id = create_id()
    option_names = (
        "provider_arns",
        "auth_type",
        "authorizer_uri",
        "authorizer_credentials",
        "identity_source",
        "identiy_validation_expression",
        "authorizer_result_ttl",
    )
    options = {option: kwargs.get(option) for option in option_names}
    return api.create_authorizer(new_id, name, authorizer_type, **options)
|
|
|
|
def update_authorizer(
    self, restapi_id: str, authorizer_id: str, patch_operations: Any
) -> Authorizer:
    """Apply patch operations to an existing authorizer and return the result."""
    return self.get_authorizer(restapi_id, authorizer_id).apply_operations(
        patch_operations
    )
|
|
|
|
def delete_authorizer(self, restapi_id: str, authorizer_id: str) -> None:
    """Delete an authorizer.

    Raises:
        RestAPINotFound: if the API id is unknown.
        AuthorizerNotFoundException: if the authorizer id is unknown
            (consistent with ``get_authorizer``, instead of a bare KeyError).
    """
    api = self.get_rest_api(restapi_id)
    if authorizer_id not in api.authorizers:
        raise AuthorizerNotFoundException()
    del api.authorizers[authorizer_id]
|
|
|
|
def get_stage(self, function_id: str, stage_name: str) -> Stage:
    """Return a stage by name; raise StageNotFoundException if unknown."""
    found = self.get_rest_api(function_id).stages.get(stage_name)
    if found is not None:
        return found
    raise StageNotFoundException()
|
|
|
|
def get_stages(self, function_id: str) -> List[Stage]:
    """Return all stages of a REST API."""
    return self.get_rest_api(function_id).get_stages()
|
|
|
|
def create_stage(
    self,
    function_id: str,
    stage_name: str,
    deploymentId: str,
    variables: Optional[Any] = None,
    description: str = "",
    cacheClusterEnabled: Optional[bool] = None,
    cacheClusterSize: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
    tracing_enabled: Optional[bool] = None,
) -> Stage:
    """Create a stage on a REST API; missing variables default to an empty dict."""
    stage_variables = {} if variables is None else variables
    api = self.get_rest_api(function_id)
    return api.create_stage(
        stage_name,
        deploymentId,
        variables=stage_variables,
        description=description,
        cacheClusterEnabled=cacheClusterEnabled,
        cacheClusterSize=cacheClusterSize,
        tags=tags,
        tracing_enabled=tracing_enabled,
    )
|
|
|
|
def update_stage(
    self, function_id: str, stage_name: str, patch_operations: Any
) -> Stage:
    """Apply patch operations to a stage and return the result."""
    stage = self.get_stage(function_id, stage_name)
    if not stage:
        # NOTE(review): get_stage already raises for unknown names, so this
        # fallback only triggers for a falsy Stage object.
        api = self.get_rest_api(function_id)
        stage = Stage()
        api.stages[stage_name] = stage
    return stage.apply_operations(patch_operations)
|
|
|
|
def delete_stage(self, function_id: str, stage_name: str) -> None:
    """Delete a stage; raise StageNotFoundException if nothing was removed."""
    stages = self.get_rest_api(function_id).stages
    if not stages.pop(stage_name, None):
        raise StageNotFoundException()
|
|
|
|
def get_method_response(
    self, function_id: str, resource_id: str, method_type: str, response_code: str
) -> Optional[MethodResponse]:
    """Return the method response registered for ``response_code``, if any."""
    return self.get_method(function_id, resource_id, method_type).get_response(
        response_code
    )
|
|
|
|
def put_method_response(
    self,
    function_id: str,
    resource_id: str,
    method_type: str,
    response_code: str,
    response_models: Any,
    response_parameters: Any,
) -> MethodResponse:
    """Create (or replace) a method response on the given method."""
    target_method = self.get_method(function_id, resource_id, method_type)
    return target_method.create_response(
        response_code, response_models, response_parameters
    )
|
|
|
|
def delete_method_response(
    self, function_id: str, resource_id: str, method_type: str, response_code: str
) -> Optional[MethodResponse]:
    """Delete the ``response_code`` method response; returns ``Method.delete_response``'s result."""
    target_method = self.get_method(function_id, resource_id, method_type)
    removed = target_method.delete_response(response_code)
    return removed
|
|
|
|
def put_integration(
    self,
    function_id: str,
    resource_id: str,
    method_type: str,
    integration_type: str,
    uri: str,
    integration_method: str,
    credentials: Optional[str] = None,
    request_templates: Optional[Dict[str, Any]] = None,
    passthrough_behavior: Optional[str] = None,
    tls_config: Optional[Dict[str, Any]] = None,
    cache_namespace: Optional[str] = None,
    timeout_in_millis: Optional[str] = None,
    request_parameters: Optional[Dict[str, Any]] = None,
    content_handling: Optional[str] = None,
    connection_type: Optional[str] = None,
) -> Integration:
    """Validate an integration definition and attach it to a resource method.

    Checks run in this order; the first failure wins:

    * ``credentials``, when given, must start with an IAM ARN of this
      backend's account (``CrossAccountNotAllowed``).
    * HTTP, HTTP_PROXY, AWS and AWS_PROXY need ``integration_method``
      (``IntegrationMethodNotDefined``).
    * AWS_PROXY may not target an S3 apigateway ARN (``AwsProxyNotAllowed``).
      AWS targeting S3 requires ``credentials`` (``RoleNotSpecified``).
    * HTTP/HTTP_PROXY ``uri`` must pass ``_uri_validator``
      (``InvalidHttpEndpoint``).
    * AWS/AWS_PROXY ``uri`` must start with ``arn:aws:`` (``InvalidArn``) and
      look like ``arn:aws:apigateway:<x>:<y>:path|action/...``
      (``InvalidIntegrationArn``).

    Returns the result of ``resource.add_integration``.
    """
    resource = self.get_resource(function_id, resource_id)
    http_kinds = ("HTTP", "HTTP_PROXY")
    aws_kinds = ("AWS", "AWS_PROXY")
    s3_service_arn = "^arn:aws:apigateway:[a-zA-Z0-9-]+:s3"

    # The role must live in the same account as this backend.
    if credentials and not re.match(
        "^arn:aws:iam::" + str(self.account_id), credentials
    ):
        raise CrossAccountNotAllowed()

    if not integration_method and integration_type in http_kinds + aws_kinds:
        raise IntegrationMethodNotDefined()

    # S3 service integrations: never through AWS_PROXY, and AWS needs a role.
    # The uri is only matched for these types, as before.
    if integration_type == "AWS_PROXY" and re.match(s3_service_arn, uri):
        raise AwsProxyNotAllowed()
    if (
        integration_type == "AWS"
        and re.match(s3_service_arn, uri)
        and not credentials
    ):
        raise RoleNotSpecified()

    if integration_type in http_kinds and not self._uri_validator(uri):
        raise InvalidHttpEndpoint()

    if integration_type in aws_kinds:
        if not re.match("^arn:aws:", uri):
            raise InvalidArn()
        if not re.match(
            "^arn:aws:apigateway:[a-zA-Z0-9-]+:[a-zA-Z0-9-.]+:(path|action)/",
            uri,
        ):
            raise InvalidIntegrationArn()

    return resource.add_integration(
        method_type,
        integration_type,
        uri,
        integration_method=integration_method,
        request_templates=request_templates,
        passthrough_behavior=passthrough_behavior,
        tls_config=tls_config,
        cache_namespace=cache_namespace,
        timeout_in_millis=timeout_in_millis,
        request_parameters=request_parameters,
        content_handling=content_handling,
        credentials=credentials,
        connection_type=connection_type,
    )
|
|
|
|
def get_integration(
    self, function_id: str, resource_id: str, method_type: str
) -> Integration:
    """Return the integration of ``method_type`` on the resolved resource.

    ``Resource.get_integration`` may yield a non-Integration value (hence the
    type-ignore); callers should be prepared for that.
    """
    target_resource = self.get_resource(function_id, resource_id)
    found = target_resource.get_integration(method_type)
    return found  # type: ignore[return-value]
|
|
|
|
def delete_integration(
    self, function_id: str, resource_id: str, method_type: str
) -> Integration:
    """Delete the ``method_type`` integration; returns ``Resource.delete_integration``'s result."""
    return self.get_resource(function_id, resource_id).delete_integration(
        method_type
    )
|
|
|
|
def put_integration_response(
    self,
    function_id: str,
    resource_id: str,
    method_type: str,
    status_code: str,
    selection_pattern: str,
    response_templates: Dict[str, str],
    response_parameters: Dict[str, str],
    content_handling: str,
) -> IntegrationResponse:
    """Create an integration response on the method's integration.

    Raises:
        NoIntegrationResponseDefined: if ``get_integration`` yields a falsy
            value, i.e. the method has no integration to attach to.
    """
    integration = self.get_integration(function_id, resource_id, method_type)
    if not integration:
        raise NoIntegrationResponseDefined()

    return integration.create_integration_response(
        status_code,
        selection_pattern,
        response_templates,
        response_parameters,
        content_handling,
    )
|
|
|
|
def get_integration_response(
    self, function_id: str, resource_id: str, method_type: str, status_code: str
) -> IntegrationResponse:
    """Return the integration response for ``status_code``.

    Raises:
        NoIntegrationDefined: if the method has no integration. Previously
            this dereferenced ``None`` and surfaced as an AttributeError.
    """
    integration = self.get_integration(function_id, resource_id, method_type)
    # get_integration may return None (see its type-ignore).
    if integration is None:
        raise NoIntegrationDefined()
    return integration.get_integration_response(status_code)
|
|
|
|
def delete_integration_response(
    self, function_id: str, resource_id: str, method_type: str, status_code: str
) -> IntegrationResponse:
    """Delete the integration response for ``status_code`` and return its result.

    Raises:
        NoIntegrationDefined: if the method has no integration. Previously
            this dereferenced ``None`` and surfaced as an AttributeError.
    """
    integration = self.get_integration(function_id, resource_id, method_type)
    # get_integration may return None (see its type-ignore).
    if integration is None:
        raise NoIntegrationDefined()
    return integration.delete_integration_response(status_code)
|
|
|
|
def create_deployment(
    self,
    function_id: str,
    name: str,
    description: str = "",
    stage_variables: Any = None,
) -> Deployment:
    """Create a deployment of REST API ``function_id``.

    Raises:
        NoMethodDefined: if no resource of the API has any method.
        NoIntegrationDefined: if none of those methods has an integration.
    """
    variables = {} if stage_variables is None else stage_variables
    rest_api = self.get_rest_api(function_id)

    # Flatten the methods of every resource into one list.
    all_methods = [
        method
        for resource in self.get_resources(function_id)
        for method in resource.resource_methods.values()
    ]
    if not any(all_methods):
        raise NoMethodDefined()

    integrations = [
        m.method_integration for m in all_methods if m.method_integration
    ]
    if not any(integrations):
        raise NoIntegrationDefined()

    return rest_api.create_deployment(name, description, variables)
|
|
|
|
def get_deployment(self, function_id: str, deployment_id: str) -> Deployment:
    """Look up deployment ``deployment_id`` on REST API ``function_id``."""
    return self.get_rest_api(function_id).get_deployment(deployment_id)
|
|
|
|
def get_deployments(self, function_id: str) -> List[Deployment]:
    """Return every deployment of REST API ``function_id``."""
    return self.get_rest_api(function_id).get_deployments()
|
|
|
|
def delete_deployment(self, function_id: str, deployment_id: str) -> Deployment:
    """Delete deployment ``deployment_id`` and return the REST API's result."""
    rest_api = self.get_rest_api(function_id)
    removed = rest_api.delete_deployment(deployment_id)
    return removed
|
|
|
|
def create_api_key(self, payload: Dict[str, Any]) -> ApiKey:
    """Create an API key from ``payload`` and register it under its id.

    Raises:
        ApiKeyValueMinLength: if a supplied ``value`` is shorter than 20.
        ApiKeyAlreadyExists: if a supplied ``value`` matches an existing key.
    """
    requested_value = payload.get("value")
    if requested_value:
        if len(requested_value) < 20:
            raise ApiKeyValueMinLength()
        if any(existing.value == requested_value for existing in self.get_api_keys()):
            raise ApiKeyAlreadyExists()

    api_key = ApiKey(**payload)
    self.keys[api_key.id] = api_key
    return api_key
|
|
|
|
def get_api_keys(self) -> List[ApiKey]:
    """Return all registered API keys, in insertion order, as a new list."""
    return [*self.keys.values()]
|
|
|
|
def get_api_key(self, api_key_id: str) -> ApiKey:
    """Return the API key with id ``api_key_id``.

    Raises:
        ApiKeyNotFoundException: if no key has that id.
    """
    if api_key_id in self.keys:
        return self.keys[api_key_id]
    raise ApiKeyNotFoundException()
|
|
|
|
def update_api_key(self, api_key_id: str, patch_operations: Any) -> ApiKey:
    """Apply ``patch_operations`` to an API key and return the result.

    Raises:
        ApiKeyNotFoundException: if no key has that id. The key is now
            resolved through ``get_api_key`` instead of indexing ``self.keys``
            directly, which raised a bare KeyError.
    """
    key = self.get_api_key(api_key_id)
    return key.update_operations(patch_operations)
|
|
|
|
def delete_api_key(self, api_key_id: str) -> None:
    """Remove the API key with id ``api_key_id``.

    Raises:
        ApiKeyNotFoundException: if no key has that id (was a bare KeyError).
    """
    if api_key_id not in self.keys:
        raise ApiKeyNotFoundException()
    del self.keys[api_key_id]
|
|
|
|
def create_usage_plan(self, payload: Any) -> UsagePlan:
    """Build a ``UsagePlan`` from ``payload``, register it by id, and return it."""
    new_plan = UsagePlan(**payload)
    self.usage_plans[new_plan.id] = new_plan
    return new_plan
|
|
|
|
def get_usage_plans(self, api_key_id: Optional[str] = None) -> List[UsagePlan]:
    """Return usage plans, optionally only those linked to ``api_key_id``.

    A plan is kept when its entry in ``usage_plan_keys`` maps
    ``api_key_id`` to a truthy value.
    """
    all_plans = list(self.usage_plans.values())
    if api_key_id is None:
        return all_plans
    return [
        plan
        for plan in all_plans
        if self.usage_plan_keys.get(plan.id, {}).get(api_key_id)
    ]
|
|
|
|
def get_usage_plan(self, usage_plan_id: str) -> UsagePlan:
    """Return the usage plan with id ``usage_plan_id``.

    Raises:
        UsagePlanNotFoundException: if no plan has that id.
    """
    if usage_plan_id in self.usage_plans:
        return self.usage_plans[usage_plan_id]
    raise UsagePlanNotFoundException()
|
|
|
|
def update_usage_plan(self, usage_plan_id: str, patch_operations: Any) -> UsagePlan:
    """Patch usage plan ``usage_plan_id`` in place and return it.

    Raises:
        UsagePlanNotFoundException: if no plan has that id.
    """
    if usage_plan_id not in self.usage_plans:
        raise UsagePlanNotFoundException()
    plan = self.usage_plans[usage_plan_id]
    plan.apply_patch_operations(patch_operations)
    return plan
|
|
|
|
def delete_usage_plan(self, usage_plan_id: str) -> None:
    """Remove usage plan ``usage_plan_id``.

    Raises:
        UsagePlanNotFoundException: if no plan has that id (was a bare
            KeyError), matching ``get_usage_plan``/``update_usage_plan``.
    """
    if usage_plan_id not in self.usage_plans:
        raise UsagePlanNotFoundException()
    del self.usage_plans[usage_plan_id]
|
|
|
|
def create_usage_plan_key(
    self, usage_plan_id: str, payload: Dict[str, Any]
) -> UsagePlanKey:
    """Attach the API key ``payload["keyId"]`` to usage plan ``usage_plan_id``.

    The plan's key map is created on first use. The new ``UsagePlanKey``
    copies the API key's name and value, is stored under its own ``id``, and
    is returned.

    Raises:
        ApiKeyNotFoundException: if ``keyId`` is not a known API key.
    """
    self.usage_plan_keys.setdefault(usage_plan_id, {})

    api_key_id = payload["keyId"]
    if api_key_id not in self.keys:
        raise ApiKeyNotFoundException()
    source_key = self.keys[api_key_id]

    plan_key = UsagePlanKey(
        plan_id=api_key_id,
        plan_type=payload["keyType"],
        name=source_key.name,
        value=source_key.value,
    )
    self.usage_plan_keys[usage_plan_id][plan_key.id] = plan_key
    return plan_key
|
|
|
|
def get_usage_plan_keys(self, usage_plan_id: str) -> List[UsagePlanKey]:
    """Return the keys attached to ``usage_plan_id``; empty if none were ever added."""
    return list(self.usage_plan_keys.get(usage_plan_id, {}).values())
|
|
|
|
def get_usage_plan_key(self, usage_plan_id: str, key_id: str) -> UsagePlanKey:
    """Return the key ``key_id`` attached to usage plan ``usage_plan_id``.

    Raises:
        ApiKeyNotFoundException: if ``key_id`` is not a known API key.
        UsagePlanNotFoundException: if the plan has no attached keys or does
            not contain ``key_id``.
    """
    if key_id not in self.keys:
        raise ApiKeyNotFoundException()

    plan_keys = self.usage_plan_keys.get(usage_plan_id)
    if plan_keys is None or key_id not in plan_keys:
        raise UsagePlanNotFoundException()
    return plan_keys[key_id]
|
|
|
|
def delete_usage_plan_key(self, usage_plan_id: str, key_id: str) -> None:
    """Detach key ``key_id`` from usage plan ``usage_plan_id``.

    Raises:
        UsagePlanNotFoundException: if the plan has no attached keys or does
            not contain ``key_id``. This was a bare KeyError, and now matches
            ``get_usage_plan_key``.
    """
    plan_keys = self.usage_plan_keys.get(usage_plan_id)
    if plan_keys is None or key_id not in plan_keys:
        raise UsagePlanNotFoundException()
    del plan_keys[key_id]
|
|
|
|
def _uri_validator(self, uri: str) -> bool:
|
|
try:
|
|
result = urlparse(uri)
|
|
return all([result.scheme, result.netloc, result.path or "/"])
|
|
except Exception:
|
|
return False
|
|
|
|
def create_domain_name(
    self,
    domain_name: str,
    certificate_name: str,
    tags: List[Dict[str, str]],
    certificate_arn: str,
    certificate_body: str,
    certificate_private_key: str,
    certificate_chain: str,
    regional_certificate_name: str,
    regional_certificate_arn: str,
    endpoint_configuration: Any,
    security_policy: str,
) -> DomainName:
    """Register a custom domain name and return it.

    An existing entry with the same ``domain_name`` is replaced.

    Raises:
        InvalidDomainName: if ``domain_name`` is empty.
    """
    if not domain_name:
        raise InvalidDomainName()

    domain = DomainName(
        domain_name=domain_name,
        certificate_name=certificate_name,
        certificate_private_key=certificate_private_key,
        certificate_arn=certificate_arn,
        certificate_body=certificate_body,
        certificate_chain=certificate_chain,
        regional_certificate_name=regional_certificate_name,
        regional_certificate_arn=regional_certificate_arn,
        endpoint_configuration=endpoint_configuration,
        tags=tags,
        security_policy=security_policy,
        region_name=self.region_name,
    )
    self.domain_names[domain_name] = domain
    return domain
|
|
|
|
def get_domain_names(self) -> List[DomainName]:
    """Return all registered custom domain names as a new list."""
    return [*self.domain_names.values()]
|
|
|
|
def get_domain_name(self, domain_name: str) -> DomainName:
    """Return the registered custom domain ``domain_name``.

    Raises:
        DomainNameNotFound: if it is not registered.
    """
    found = self.domain_names.get(domain_name)
    if found is not None:
        return found
    raise DomainNameNotFound()
|
|
|
|
def delete_domain_name(self, domain_name: str) -> None:
    """Remove custom domain ``domain_name`` together with its base path mappings.

    Raises:
        DomainNameNotFound: if the domain is not registered.
    """
    if self.domain_names.pop(domain_name, None) is None:
        raise DomainNameNotFound()
    # Mappings are keyed by domain name; without this, recreating a domain
    # with the same name would resurrect the deleted domain's mappings.
    self.base_path_mappings.pop(domain_name, None)
|
|
|
|
def create_model(
    self,
    rest_api_id: str,
    name: str,
    content_type: str,
    description: str,
    schema: str,
) -> Model:
    """Add a model to REST API ``rest_api_id`` and return it.

    Raises:
        InvalidRestApiId: if ``rest_api_id`` is empty.
        InvalidModelName: if ``name`` is empty.
    """
    if not rest_api_id:
        raise InvalidRestApiId()
    if not name:
        raise InvalidModelName()

    rest_api = self.get_rest_api(rest_api_id)
    return rest_api.add_model(
        name=name,
        description=description,
        schema=schema,
        content_type=content_type,
    )
|
|
|
|
def get_models(self, rest_api_id: str) -> List[Model]:
    """Return all models of REST API ``rest_api_id``.

    Raises:
        InvalidRestApiId: if ``rest_api_id`` is empty.
    """
    if not rest_api_id:
        raise InvalidRestApiId()
    return list(self.get_rest_api(rest_api_id).models.values())
|
|
|
|
def get_model(self, rest_api_id: str, model_name: str) -> Model:
    """Return model ``model_name`` of REST API ``rest_api_id``.

    Raises:
        InvalidRestApiId: if ``rest_api_id`` is empty.
        ModelNotFound: if the API has no such model.
    """
    if not rest_api_id:
        raise InvalidRestApiId()
    found = self.get_rest_api(rest_api_id).models.get(model_name)
    if found is not None:
        return found
    raise ModelNotFound()
|
|
|
|
def get_request_validators(self, restapi_id: str) -> List[RequestValidator]:
    """Return the request validators of REST API ``restapi_id``."""
    return self.get_rest_api(restapi_id).get_request_validators()
|
|
|
|
def create_request_validator(
    self, restapi_id: str, name: str, body: Optional[bool], params: Any
) -> RequestValidator:
    """Create a request validator on REST API ``restapi_id``.

    ``body`` and ``params`` map to ``validateRequestBody`` and
    ``validateRequestParameters`` respectively.
    """
    rest_api = self.get_rest_api(restapi_id)
    validator = rest_api.create_request_validator(
        name=name,
        validateRequestBody=body,
        validateRequestParameters=params,
    )
    return validator
|
|
|
|
def get_request_validator(
    self, restapi_id: str, validator_id: str
) -> RequestValidator:
    """Return request validator ``validator_id`` of REST API ``restapi_id``."""
    return self.get_rest_api(restapi_id).get_request_validator(validator_id)
|
|
|
|
def delete_request_validator(self, restapi_id: str, validator_id: str) -> None:
    """Delete request validator ``validator_id`` from REST API ``restapi_id``."""
    self.get_rest_api(restapi_id).delete_request_validator(validator_id)
|
|
|
|
def update_request_validator(
    self, restapi_id: str, validator_id: str, patch_operations: Any
) -> RequestValidator:
    """Patch request validator ``validator_id`` and return the REST API's result."""
    rest_api = self.get_rest_api(restapi_id)
    updated = rest_api.update_request_validator(validator_id, patch_operations)
    return updated
|
|
|
|
def create_base_path_mapping(
    self, domain_name: str, rest_api_id: str, base_path: str, stage: str
) -> BasePathMapping:
    """Map ``base_path`` on ``domain_name`` to a REST API (and optional stage).

    Raises:
        DomainNameNotFound: if the domain is not registered.
        InvalidBasePathException: if ``base_path`` contains a "/".
        InvalidRestApiIdForBasePathMappingException: if the API is unknown.
        InvalidStageException: if ``stage`` is given but not on the API.
        BasePathConflictException: if the domain already has a truthy mapping
            for the same base path, unless that base path is "(none)".
    """
    if domain_name not in self.domain_names:
        raise DomainNameNotFound()
    if base_path and "/" in base_path:
        raise InvalidBasePathException()
    if rest_api_id not in self.apis:
        raise InvalidRestApiIdForBasePathMappingException()
    if stage and self.apis[rest_api_id].stages.get(stage) is None:
        raise InvalidStageException()

    mapping = BasePathMapping(
        domain_name=domain_name,
        rest_api_id=rest_api_id,
        basePath=base_path,
        stage=stage,
    )
    # Key by the mapping's own (possibly normalised) base path.
    mapping_key = mapping.base_path

    domain_mappings = self.base_path_mappings.get(domain_name)
    if domain_mappings is None:
        domain_mappings = {}
        self.base_path_mappings[domain_name] = domain_mappings
    elif domain_mappings.get(mapping_key) and mapping_key != "(none)":
        raise BasePathConflictException()

    domain_mappings[mapping_key] = mapping
    return mapping
|
|
|
|
def get_base_path_mappings(self, domain_name: str) -> List[BasePathMapping]:
    """Return all base path mappings of ``domain_name``.

    A registered domain without any mapping yields an empty list. Because
    the per-domain dict is only created by ``create_base_path_mapping``,
    indexing it directly raised a KeyError here.

    Raises:
        DomainNameNotFound: if the domain is not registered.
    """
    if domain_name not in self.domain_names:
        raise DomainNameNotFound()
    return list(self.base_path_mappings.get(domain_name, {}).values())
|
|
|
|
def get_base_path_mapping(
    self, domain_name: str, base_path: str
) -> BasePathMapping:
    """Return the mapping for ``base_path`` on ``domain_name``.

    Raises:
        DomainNameNotFound: if the domain is not registered.
        BasePathNotFoundException: if the domain has no such mapping,
            including when it has no mappings at all (was a KeyError).
    """
    if domain_name not in self.domain_names:
        raise DomainNameNotFound()
    domain_mappings = self.base_path_mappings.get(domain_name, {})
    if base_path not in domain_mappings:
        raise BasePathNotFoundException()
    return domain_mappings[base_path]
|
|
|
|
def delete_base_path_mapping(self, domain_name: str, base_path: str) -> None:
    """Remove the mapping for ``base_path`` on ``domain_name``.

    Raises:
        DomainNameNotFound: if the domain is not registered.
        BasePathNotFoundException: if the domain has no such mapping,
            including when it has no mappings at all (was a KeyError).
    """
    if domain_name not in self.domain_names:
        raise DomainNameNotFound()
    domain_mappings = self.base_path_mappings.get(domain_name, {})
    if base_path not in domain_mappings:
        raise BasePathNotFoundException()
    del domain_mappings[base_path]
|
|
|
|
def update_base_path_mapping(
|
|
self, domain_name: str, base_path: str, patch_operations: Any
|
|
) -> BasePathMapping:
|
|
|
|
if domain_name not in self.domain_names:
|
|
raise DomainNameNotFound()
|
|
|
|
if base_path not in self.base_path_mappings[domain_name]:
|
|
raise BasePathNotFoundException()
|
|
|
|
base_path_mapping = self.get_base_path_mapping(domain_name, base_path)
|
|
|
|
rest_api_ids = [
|
|
op["value"] for op in patch_operations if op["path"] == "/restapiId"
|
|
]
|
|
if len(rest_api_ids) == 0:
|
|
modified_rest_api_id = base_path_mapping.rest_api_id
|
|
else:
|
|
modified_rest_api_id = rest_api_ids[-1]
|
|
|
|
stages = [op["value"] for op in patch_operations if op["path"] == "/stage"]
|
|
if len(stages) == 0:
|
|
modified_stage = base_path_mapping.stage
|
|
else:
|
|
modified_stage = stages[-1]
|
|
|
|
base_paths = [
|
|
op["value"] for op in patch_operations if op["path"] == "/basePath"
|
|
]
|
|
if len(base_paths) == 0:
|
|
modified_base_path = base_path_mapping.base_path
|
|
else:
|
|
modified_base_path = base_paths[-1]
|
|
|
|
rest_api = self.apis.get(modified_rest_api_id)
|
|
if rest_api is None:
|
|
raise InvalidRestApiIdForBasePathMappingException()
|
|
if modified_stage and rest_api.stages.get(modified_stage) is None:
|
|
raise InvalidStageException()
|
|
|
|
base_path_mapping.apply_patch_operations(patch_operations)
|
|
|
|
if base_path != modified_base_path:
|
|
self.base_path_mappings[domain_name].pop(base_path)
|
|
self.base_path_mappings[domain_name][modified_base_path] = base_path_mapping
|
|
|
|
return base_path_mapping
|
|
|
|
def create_vpc_link(
    self,
    name: str,
    description: str,
    target_arns: List[str],
    tags: List[Dict[str, str]],
) -> VpcLink:
    """Create a VPC link, register it under its id, and return it."""
    link = VpcLink(name, description=description, target_arns=target_arns, tags=tags)
    self.vpc_links[link.id] = link
    return link
|
|
|
|
def delete_vpc_link(self, vpc_link_id: str) -> None:
    """Remove VPC link ``vpc_link_id``; unknown ids are silently ignored."""
    if vpc_link_id in self.vpc_links:
        del self.vpc_links[vpc_link_id]
|
|
|
|
def get_vpc_link(self, vpc_link_id: str) -> VpcLink:
    """Return VPC link ``vpc_link_id``.

    Raises:
        VpcLinkNotFound: if no link has that id.
    """
    if vpc_link_id in self.vpc_links:
        return self.vpc_links[vpc_link_id]
    raise VpcLinkNotFound()
|
|
|
|
def get_vpc_links(self) -> List[VpcLink]:
    """Return every registered VPC link.

    Pagination is not implemented; all links are returned at once.
    """
    return [*self.vpc_links.values()]
|
|
|
|
def put_gateway_response(
    self,
    rest_api_id: str,
    response_type: str,
    status_code: int,
    response_parameters: Any,
    response_templates: Any,
) -> GatewayResponse:
    """Create or update the ``response_type`` gateway response of a REST API.

    All settings are forwarded to the REST API's ``put_gateway_response``,
    whose result is returned.
    """
    return self.get_rest_api(rest_api_id).put_gateway_response(
        response_type,
        status_code=status_code,
        response_parameters=response_parameters,
        response_templates=response_templates,
    )
|
|
|
|
def get_gateway_response(
    self, rest_api_id: str, response_type: str
) -> GatewayResponse:
    """Return the ``response_type`` gateway response of REST API ``rest_api_id``."""
    rest_api = self.get_rest_api(rest_api_id)
    gateway_response = rest_api.get_gateway_response(response_type)
    return gateway_response
|
|
|
|
def get_gateway_responses(self, rest_api_id: str) -> List[GatewayResponse]:
    """Return every gateway response of REST API ``rest_api_id``.

    Pagination is not implemented; all responses are returned at once.
    """
    return self.get_rest_api(rest_api_id).get_gateway_responses()
|
|
|
|
def delete_gateway_response(self, rest_api_id: str, response_type: str) -> None:
    """Delete the ``response_type`` gateway response of REST API ``rest_api_id``."""
    rest_api = self.get_rest_api(rest_api_id)
    rest_api.delete_gateway_response(response_type)
|
|
|
|
|
|
# Module-level registry of APIGatewayBackend instances for the "apigateway"
# service. Instance creation and keying are handled by BackendDict
# (presumably per account/region — see moto.core.BackendDict to confirm).
apigateway_backends = BackendDict(APIGatewayBackend, "apigateway")
|