APIGateway improvements (#4874)

This commit is contained in:
Bert Blommers 2022-02-18 22:31:33 -01:00 committed by GitHub
parent ecd7cf9d92
commit 876c783a24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1400 additions and 962 deletions

View File

@ -22,7 +22,7 @@
## apigateway ## apigateway
<details> <details>
<summary>59% implemented</summary> <summary>62% implemented</summary>
- [X] create_api_key - [X] create_api_key
- [X] create_authorizer - [X] create_authorizer
@ -47,7 +47,7 @@
- [ ] delete_documentation_part - [ ] delete_documentation_part
- [ ] delete_documentation_version - [ ] delete_documentation_version
- [X] delete_domain_name - [X] delete_domain_name
- [ ] delete_gateway_response - [X] delete_gateway_response
- [X] delete_integration - [X] delete_integration
- [X] delete_integration_response - [X] delete_integration_response
- [X] delete_method - [X] delete_method
@ -81,8 +81,8 @@
- [X] get_domain_name - [X] get_domain_name
- [X] get_domain_names - [X] get_domain_names
- [ ] get_export - [ ] get_export
- [ ] get_gateway_response - [X] get_gateway_response
- [ ] get_gateway_responses - [X] get_gateway_responses
- [X] get_integration - [X] get_integration
- [X] get_integration_response - [X] get_integration_response
- [X] get_method - [X] get_method
@ -112,7 +112,7 @@
- [ ] import_api_keys - [ ] import_api_keys
- [ ] import_documentation_parts - [ ] import_documentation_parts
- [ ] import_rest_api - [ ] import_rest_api
- [ ] put_gateway_response - [X] put_gateway_response
- [X] put_integration - [X] put_integration
- [X] put_integration_response - [X] put_integration_response
- [X] put_method - [X] put_method

View File

@ -50,7 +50,7 @@ apigateway
- [ ] delete_documentation_part - [ ] delete_documentation_part
- [ ] delete_documentation_version - [ ] delete_documentation_version
- [X] delete_domain_name - [X] delete_domain_name
- [ ] delete_gateway_response - [X] delete_gateway_response
- [X] delete_integration - [X] delete_integration
- [X] delete_integration_response - [X] delete_integration_response
- [X] delete_method - [X] delete_method
@ -84,8 +84,12 @@ apigateway
- [X] get_domain_name - [X] get_domain_name
- [X] get_domain_names - [X] get_domain_names
- [ ] get_export - [ ] get_export
- [ ] get_gateway_response - [X] get_gateway_response
- [ ] get_gateway_responses - [X] get_gateway_responses
Pagination is not yet implemented
- [X] get_integration - [X] get_integration
- [X] get_integration_response - [X] get_integration_response
- [X] get_method - [X] get_method
@ -119,7 +123,7 @@ apigateway
- [ ] import_api_keys - [ ] import_api_keys
- [ ] import_documentation_parts - [ ] import_documentation_parts
- [ ] import_rest_api - [ ] import_rest_api
- [ ] put_gateway_response - [X] put_gateway_response
- [X] put_integration - [X] put_integration
- [X] put_integration_response - [X] put_integration_response
- [X] put_method - [X] put_method

View File

@ -1,26 +1,34 @@
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
class BadRequestException(JsonRESTError): class ApiGatewayException(JsonRESTError):
pass pass
class NotFoundException(JsonRESTError): class BadRequestException(ApiGatewayException):
def __init__(self, message):
super().__init__("BadRequestException", message)
class NotFoundException(ApiGatewayException):
def __init__(self, message):
super().__init__("NotFoundException", message)
class AccessDeniedException(ApiGatewayException):
pass pass
class AccessDeniedException(JsonRESTError): class ConflictException(ApiGatewayException):
pass
class ConflictException(JsonRESTError):
code = 409 code = 409
def __init__(self, message):
super().__init__("ConflictException", message)
class AwsProxyNotAllowed(BadRequestException): class AwsProxyNotAllowed(BadRequestException):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
"BadRequestException",
"Integrations of type 'AWS_PROXY' currently only supports Lambda function and Firehose stream invocations.", "Integrations of type 'AWS_PROXY' currently only supports Lambda function and Firehose stream invocations.",
) )
@ -34,98 +42,87 @@ class CrossAccountNotAllowed(AccessDeniedException):
class RoleNotSpecified(BadRequestException): class RoleNotSpecified(BadRequestException):
def __init__(self): def __init__(self):
super().__init__( super().__init__("Role ARN must be specified for AWS integrations")
"BadRequestException", "Role ARN must be specified for AWS integrations"
)
class IntegrationMethodNotDefined(BadRequestException): class IntegrationMethodNotDefined(BadRequestException):
def __init__(self): def __init__(self):
super().__init__( super().__init__("Enumeration value for HttpMethod must be non-empty")
"BadRequestException", "Enumeration value for HttpMethod must be non-empty"
)
class InvalidResourcePathException(BadRequestException): class InvalidResourcePathException(BadRequestException):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
"BadRequestException",
"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace.", "Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace.",
) )
class InvalidHttpEndpoint(BadRequestException): class InvalidHttpEndpoint(BadRequestException):
def __init__(self): def __init__(self):
super().__init__( super().__init__("Invalid HTTP endpoint specified for URI")
"BadRequestException", "Invalid HTTP endpoint specified for URI"
)
class InvalidArn(BadRequestException): class InvalidArn(BadRequestException):
def __init__(self): def __init__(self):
super().__init__("BadRequestException", "Invalid ARN specified in the request") super().__init__("Invalid ARN specified in the request")
class InvalidIntegrationArn(BadRequestException): class InvalidIntegrationArn(BadRequestException):
def __init__(self): def __init__(self):
super().__init__( super().__init__("AWS ARN for integration must contain path or action")
"BadRequestException", "AWS ARN for integration must contain path or action"
)
class InvalidRequestInput(BadRequestException): class InvalidRequestInput(BadRequestException):
def __init__(self): def __init__(self):
super().__init__("BadRequestException", "Invalid request input") super().__init__("Invalid request input")
class NoIntegrationDefined(NotFoundException): class NoIntegrationDefined(NotFoundException):
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "No integration defined for method") super().__init__("No integration defined for method")
class NoIntegrationResponseDefined(NotFoundException): class NoIntegrationResponseDefined(NotFoundException):
code = 404 code = 404
def __init__(self, code=None): def __init__(self, code=None):
super().__init__("NotFoundException", "Invalid Response status code specified") super().__init__("Invalid Response status code specified")
class NoMethodDefined(BadRequestException): class NoMethodDefined(BadRequestException):
def __init__(self): def __init__(self):
super().__init__( super().__init__("The REST API doesn't contain any methods")
"BadRequestException", "The REST API doesn't contain any methods"
)
class AuthorizerNotFoundException(NotFoundException): class AuthorizerNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "Invalid Authorizer identifier specified") super().__init__("Invalid Authorizer identifier specified")
class StageNotFoundException(NotFoundException): class StageNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "Invalid stage identifier specified") super().__init__("Invalid stage identifier specified")
class ApiKeyNotFoundException(NotFoundException): class ApiKeyNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "Invalid API Key identifier specified") super().__init__("Invalid API Key identifier specified")
class UsagePlanNotFoundException(NotFoundException): class UsagePlanNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "Invalid Usage Plan ID specified") super().__init__("Invalid Usage Plan ID specified")
class ApiKeyAlreadyExists(JsonRESTError): class ApiKeyAlreadyExists(ApiGatewayException):
code = 409 code = 409
def __init__(self): def __init__(self):
@ -136,67 +133,63 @@ class InvalidDomainName(BadRequestException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("BadRequestException", "No Domain Name specified") super().__init__("No Domain Name specified")
class DomainNameNotFound(NotFoundException): class DomainNameNotFound(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__( super().__init__("Invalid domain name identifier specified")
"NotFoundException", "Invalid domain name identifier specified"
)
class InvalidRestApiId(BadRequestException): class InvalidRestApiId(BadRequestException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("BadRequestException", "No Rest API Id specified") super().__init__("No Rest API Id specified")
class InvalidModelName(BadRequestException): class InvalidModelName(BadRequestException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("BadRequestException", "No Model Name specified") super().__init__("No Model Name specified")
class RestAPINotFound(NotFoundException): class RestAPINotFound(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "Invalid Rest API Id specified") super().__init__("Invalid Rest API Id specified")
class RequestValidatorNotFound(BadRequestException): class RequestValidatorNotFound(BadRequestException):
code = 400 code = 400
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "Invalid Request Validator Id specified") super().__init__("Invalid Request Validator Id specified")
class ModelNotFound(NotFoundException): class ModelNotFound(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "Invalid Model Name specified") super().__init__("Invalid Model Name specified")
class ApiKeyValueMinLength(BadRequestException): class ApiKeyValueMinLength(BadRequestException):
code = 400 code = 400
def __init__(self): def __init__(self):
super().__init__( super().__init__("API Key value should be at least 20 characters")
"BadRequestException", "API Key value should be at least 20 characters"
)
class MethodNotFoundException(NotFoundException): class MethodNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "Invalid Method identifier specified") super().__init__("Invalid Method identifier specified")
class InvalidBasePathException(BadRequestException): class InvalidBasePathException(BadRequestException):
@ -204,44 +197,63 @@ class InvalidBasePathException(BadRequestException):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
"BadRequestException",
"API Gateway V1 doesn't support the slash character (/) in base path mappings. " "API Gateway V1 doesn't support the slash character (/) in base path mappings. "
"To create a multi-level base path mapping, use API Gateway V2.", "To create a multi-level base path mapping, use API Gateway V2.",
) )
class DeploymentNotFoundException(NotFoundException):
def __init__(self):
super().__init__("Invalid Deployment identifier specified")
class InvalidRestApiIdForBasePathMappingException(BadRequestException): class InvalidRestApiIdForBasePathMappingException(BadRequestException):
code = 400 code = 400
def __init__(self): def __init__(self):
super().__init__("BadRequestException", "Invalid REST API identifier specified") super().__init__("Invalid REST API identifier specified")
class InvalidStageException(BadRequestException): class InvalidStageException(BadRequestException):
code = 400 code = 400
def __init__(self): def __init__(self):
super().__init__("BadRequestException", "Invalid stage identifier specified") super().__init__("Invalid stage identifier specified")
class BasePathConflictException(ConflictException): class BasePathConflictException(ConflictException):
def __init__(self): def __init__(self):
super().__init__( super().__init__("Base path already exists for this domain name")
"ConflictException", "Base path already exists for this domain name"
)
class BasePathNotFoundException(NotFoundException): class BasePathNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__( super().__init__("Invalid base path mapping identifier specified")
"NotFoundException", "Invalid base path mapping identifier specified"
)
class VpcLinkNotFound(NotFoundException): class VpcLinkNotFound(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
super().__init__("NotFoundException", "VPCLink not found") super().__init__("VPCLink not found")
class ValidationException(ApiGatewayException):
code = 400
def __init__(self, message):
super().__init__("ValidationException", message)
class StageStillActive(BadRequestException):
def __init__(self):
super().__init__(
"Active stages pointing to this deployment must be moved or deleted"
)
class GatewayResponseNotFound(NotFoundException):
def __init__(self):
super().__init__("GatewayResponse not found")

View File

@ -20,6 +20,8 @@ from .integration_parsers.aws_parser import TypeAwsParser
from .integration_parsers.http_parser import TypeHttpParser from .integration_parsers.http_parser import TypeHttpParser
from .integration_parsers.unknown_parser import TypeUnknownParser from .integration_parsers.unknown_parser import TypeUnknownParser
from .exceptions import ( from .exceptions import (
ConflictException,
DeploymentNotFoundException,
ApiKeyNotFoundException, ApiKeyNotFoundException,
UsagePlanNotFoundException, UsagePlanNotFoundException,
AwsProxyNotAllowed, AwsProxyNotAllowed,
@ -49,7 +51,10 @@ from .exceptions import (
InvalidStageException, InvalidStageException,
BasePathConflictException, BasePathConflictException,
BasePathNotFoundException, BasePathNotFoundException,
StageStillActive,
VpcLinkNotFound, VpcLinkNotFound,
ValidationException,
GatewayResponseNotFound,
) )
from ..core.models import responses_mock from ..core.models import responses_mock
from moto.apigateway.exceptions import MethodNotFoundException from moto.apigateway.exceptions import MethodNotFoundException
@ -119,6 +124,7 @@ class Integration(BaseModel, dict):
request_templates=None, request_templates=None,
tls_config=None, tls_config=None,
cache_namespace=None, cache_namespace=None,
timeout_in_millis=None,
): ):
super().__init__() super().__init__()
self["type"] = integration_type self["type"] = integration_type
@ -131,6 +137,7 @@ class Integration(BaseModel, dict):
] = None # prevent json serialization from including them if none provided ] = None # prevent json serialization from including them if none provided
self["tlsConfig"] = tls_config self["tlsConfig"] = tls_config
self["cacheNamespace"] = cache_namespace self["cacheNamespace"] = cache_namespace
self["timeoutInMillis"] = timeout_in_millis
def create_integration_response( def create_integration_response(
self, status_code, selection_pattern, response_templates, content_handling self, status_code, selection_pattern, response_templates, content_handling
@ -361,6 +368,7 @@ class Resource(CloudFormationModel):
integration_method=None, integration_method=None,
tls_config=None, tls_config=None,
cache_namespace=None, cache_namespace=None,
timeout_in_millis=None,
): ):
integration_method = integration_method or method_type integration_method = integration_method or method_type
integration = Integration( integration = Integration(
@ -370,6 +378,7 @@ class Resource(CloudFormationModel):
request_templates=request_templates, request_templates=request_templates,
tls_config=tls_config, tls_config=tls_config,
cache_namespace=cache_namespace, cache_namespace=cache_namespace,
timeout_in_millis=timeout_in_millis,
) )
self.resource_methods[method_type]["methodIntegration"] = integration self.resource_methods[method_type]["methodIntegration"] = integration
return integration return integration
@ -451,6 +460,7 @@ class Stage(BaseModel, dict):
self["description"] = description self["description"] = description
self["cacheClusterEnabled"] = cacheClusterEnabled self["cacheClusterEnabled"] = cacheClusterEnabled
if self["cacheClusterEnabled"]: if self["cacheClusterEnabled"]:
self["cacheClusterStatus"] = "AVAILABLE"
self["cacheClusterSize"] = str(0.5) self["cacheClusterSize"] = str(0.5)
if cacheClusterSize is not None: if cacheClusterSize is not None:
self["cacheClusterSize"] = str(cacheClusterSize) self["cacheClusterSize"] = str(cacheClusterSize)
@ -465,25 +475,39 @@ class Stage(BaseModel, dict):
self._apply_operation_to_variables(op) self._apply_operation_to_variables(op)
elif "/cacheClusterEnabled" in op["path"]: elif "/cacheClusterEnabled" in op["path"]:
self["cacheClusterEnabled"] = self._str2bool(op["value"]) self["cacheClusterEnabled"] = self._str2bool(op["value"])
if "cacheClusterSize" not in self and self["cacheClusterEnabled"]: if self["cacheClusterEnabled"]:
self["cacheClusterSize"] = str(0.5) self["cacheClusterStatus"] = "AVAILABLE"
if "cacheClusterSize" not in self:
self["cacheClusterSize"] = str(0.5)
else:
self["cacheClusterStatus"] = "NOT_AVAILABLE"
elif "/cacheClusterSize" in op["path"]: elif "/cacheClusterSize" in op["path"]:
self["cacheClusterSize"] = str(float(op["value"])) self["cacheClusterSize"] = str(op["value"])
elif "/description" in op["path"]: elif "/description" in op["path"]:
self["description"] = op["value"] self["description"] = op["value"]
elif "/deploymentId" in op["path"]: elif "/deploymentId" in op["path"]:
self["deploymentId"] = op["value"] self["deploymentId"] = op["value"]
elif op["op"] == "replace": elif op["op"] == "replace":
# Method Settings drop into here if op["path"] == "/tracingEnabled":
# (e.g., path could be '/*/*/logging/loglevel') self["tracingEnabled"] = self._str2bool(op["value"])
split_path = op["path"].split("/", 3) elif op["path"].startswith("/accessLogSettings/"):
if len(split_path) != 4: self["accessLogSettings"] = self.get("accessLogSettings", {})
continue self["accessLogSettings"][op["path"].split("/")[-1]] = op["value"]
self._patch_method_setting( else:
"/".join(split_path[1:3]), split_path[3], op["value"] # (e.g., path could be '/*/*/logging/loglevel')
) split_path = op["path"].split("/", 3)
if len(split_path) != 4:
continue
self._patch_method_setting(
"/".join(split_path[1:3]), split_path[3], op["value"]
)
elif op["op"] == "remove":
if op["path"] == "/accessLogSettings":
self["accessLogSettings"] = None
else: else:
raise Exception('Patch operation "%s" not implemented' % op["op"]) raise ValidationException(
"Member must satisfy enum value set: [add, remove, move, test, replace, copy]"
)
return self return self
def _patch_method_setting(self, resource_path_and_method, key, value): def _patch_method_setting(self, resource_path_and_method, key, value):
@ -768,6 +792,7 @@ class RestAPI(CloudFormationModel):
self.minimum_compression_size = kwargs.get("minimum_compression_size") self.minimum_compression_size = kwargs.get("minimum_compression_size")
self.deployments = {} self.deployments = {}
self.authorizers = {} self.authorizers = {}
self.gateway_responses = {}
self.stages = {} self.stages = {}
self.resources = {} self.resources = {}
self.models = {} self.models = {}
@ -972,6 +997,8 @@ class RestAPI(CloudFormationModel):
tags=None, tags=None,
tracing_enabled=None, tracing_enabled=None,
): ):
if name in self.stages:
raise ConflictException("Stage already exists")
if variables is None: if variables is None:
variables = {} variables = {}
stage = Stage( stage = Stage(
@ -994,9 +1021,10 @@ class RestAPI(CloudFormationModel):
deployment_id = create_id() deployment_id = create_id()
deployment = Deployment(deployment_id, name, description) deployment = Deployment(deployment_id, name, description)
self.deployments[deployment_id] = deployment self.deployments[deployment_id] = deployment
self.stages[name] = Stage( if name:
name=name, deployment_id=deployment_id, variables=stage_variables self.stages[name] = Stage(
) name=name, deployment_id=deployment_id, variables=stage_variables
)
self.update_integration_mocks(name) self.update_integration_mocks(name)
return deployment return deployment
@ -1014,6 +1042,13 @@ class RestAPI(CloudFormationModel):
return list(self.deployments.values()) return list(self.deployments.values())
def delete_deployment(self, deployment_id): def delete_deployment(self, deployment_id):
if deployment_id not in self.deployments:
raise DeploymentNotFoundException()
deployment = self.deployments[deployment_id]
if deployment["stageName"] and deployment["stageName"] in self.stages:
# Stage is still active
raise StageStillActive()
return self.deployments.pop(deployment_id) return self.deployments.pop(deployment_id)
def create_request_validator( def create_request_validator(
@ -1046,6 +1081,29 @@ class RestAPI(CloudFormationModel):
self.request_validators[validator_id].apply_patch_operations(patch_operations) self.request_validators[validator_id].apply_patch_operations(patch_operations)
return self.request_validators[validator_id] return self.request_validators[validator_id]
def put_gateway_response(
self, response_type, status_code, response_parameters, response_templates
):
response = GatewayResponse(
response_type=response_type,
status_code=status_code,
response_parameters=response_parameters,
response_templates=response_templates,
)
self.gateway_responses[response_type] = response
return response
def get_gateway_response(self, response_type):
if response_type not in self.gateway_responses:
raise GatewayResponseNotFound()
return self.gateway_responses[response_type]
def get_gateway_responses(self):
return list(self.gateway_responses.values())
def delete_gateway_response(self, response_type):
self.gateway_responses.pop(response_type, None)
class DomainName(BaseModel, dict): class DomainName(BaseModel, dict):
def __init__(self, domain_name, **kwargs): def __init__(self, domain_name, **kwargs):
@ -1136,6 +1194,21 @@ class BasePathMapping(BaseModel, dict):
self["stage"] = value self["stage"] = value
class GatewayResponse(BaseModel, dict):
def __init__(
self, response_type, status_code, response_parameters, response_templates
):
super().__init__()
self["responseType"] = response_type
if status_code is not None:
self["statusCode"] = status_code
if response_parameters is not None:
self["responseParameters"] = response_parameters
if response_templates is not None:
self["responseTemplates"] = response_templates
self["defaultResponse"] = False
class APIGatewayBackend(BaseBackend): class APIGatewayBackend(BaseBackend):
""" """
API Gateway mock. API Gateway mock.
@ -1423,6 +1496,7 @@ class APIGatewayBackend(BaseBackend):
request_templates=None, request_templates=None,
tls_config=None, tls_config=None,
cache_namespace=None, cache_namespace=None,
timeout_in_millis=None,
): ):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
if credentials and not re.match( if credentials and not re.match(
@ -1462,6 +1536,7 @@ class APIGatewayBackend(BaseBackend):
request_templates=request_templates, request_templates=request_templates,
tls_config=tls_config, tls_config=tls_config,
cache_namespace=cache_namespace, cache_namespace=cache_namespace,
timeout_in_millis=timeout_in_millis,
) )
return integration return integration
@ -1915,5 +1990,37 @@ class APIGatewayBackend(BaseBackend):
""" """
return list(self.vpc_links.values()) return list(self.vpc_links.values())
def put_gateway_response(
self,
rest_api_id,
response_type,
status_code,
response_parameters,
response_templates,
):
api = self.get_rest_api(rest_api_id)
response = api.put_gateway_response(
response_type,
status_code=status_code,
response_parameters=response_parameters,
response_templates=response_templates,
)
return response
def get_gateway_response(self, rest_api_id, response_type):
api = self.get_rest_api(rest_api_id)
return api.get_gateway_response(response_type)
def get_gateway_responses(self, rest_api_id):
"""
Pagination is not yet implemented
"""
api = self.get_rest_api(rest_api_id)
return api.get_gateway_responses()
def delete_gateway_response(self, rest_api_id, response_type):
api = self.get_rest_api(rest_api_id)
api.delete_gateway_response(response_type)
apigateway_backends = BackendDict(APIGatewayBackend, "apigateway") apigateway_backends = BackendDict(APIGatewayBackend, "apigateway")

View File

@ -1,31 +1,13 @@
import json import json
from functools import wraps
from urllib.parse import unquote from urllib.parse import unquote
from moto.utilities.utils import merge_multiple_dicts from moto.utilities.utils import merge_multiple_dicts
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import apigateway_backends from .models import apigateway_backends
from .exceptions import ( from .exceptions import (
ApiKeyNotFoundException, ApiGatewayException,
UsagePlanNotFoundException,
BadRequestException,
CrossAccountNotAllowed,
AuthorizerNotFoundException,
StageNotFoundException,
ApiKeyAlreadyExists,
DomainNameNotFound,
InvalidDomainName,
InvalidRestApiId,
InvalidModelName,
RestAPINotFound,
ModelNotFound,
ApiKeyValueMinLength,
InvalidRequestInput, InvalidRequestInput,
NoIntegrationDefined,
NoIntegrationResponseDefined,
NotFoundException,
ConflictException,
InvalidRestApiIdForBasePathMappingException,
InvalidStageException,
) )
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"] API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
@ -33,6 +15,17 @@ AUTHORIZER_TYPES = ["TOKEN", "REQUEST", "COGNITO_USER_POOLS"]
ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"] ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"]
def error_handler(f):
@wraps(f)
def _wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except ApiGatewayException as e:
return e.code, e.get_headers(), e.get_body()
return _wrapper
class APIGatewayResponse(BaseResponse): class APIGatewayResponse(BaseResponse):
def error(self, type_, message, status=400): def error(self, type_, message, status=400):
headers = self.response_headers or {} headers = self.response_headers or {}
@ -117,6 +110,7 @@ class APIGatewayResponse(BaseResponse):
value = op["value"] value = op["value"]
return self.__validate_api_key_source(value) return self.__validate_api_key_source(value)
@error_handler
def restapis_individual(self, request, full_url, headers): def restapis_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
@ -130,16 +124,7 @@ class APIGatewayResponse(BaseResponse):
response = self.__validte_rest_patch_operations(patch_operations) response = self.__validte_rest_patch_operations(patch_operations)
if response is not None: if response is not None:
return response return response
try: rest_api = self.backend.update_rest_api(function_id, patch_operations)
rest_api = self.backend.update_rest_api(function_id, patch_operations)
except RestAPINotFound as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
return 200, {}, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
@ -155,25 +140,37 @@ class APIGatewayResponse(BaseResponse):
json.dumps({"item": [resource.to_dict() for resource in resources]}), json.dumps({"item": [resource.to_dict() for resource in resources]}),
) )
@error_handler
def gateway_response(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "PUT":
return self.put_gateway_response()
elif request.method == "GET":
return self.get_gateway_response()
elif request.method == "DELETE":
return self.delete_gateway_response()
def gateway_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self.get_gateway_responses()
@error_handler
def resource_individual(self, request, full_url, headers): def resource_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
resource_id = self.path.split("/")[-1] resource_id = self.path.split("/")[-1]
try: if self.method == "GET":
if self.method == "GET": resource = self.backend.get_resource(function_id, resource_id)
resource = self.backend.get_resource(function_id, resource_id) elif self.method == "POST":
elif self.method == "POST": path_part = self._get_param("pathPart")
path_part = self._get_param("pathPart") resource = self.backend.create_resource(function_id, resource_id, path_part)
resource = self.backend.create_resource( elif self.method == "DELETE":
function_id, resource_id, path_part resource = self.backend.delete_resource(function_id, resource_id)
) return 200, {}, json.dumps(resource.to_dict())
elif self.method == "DELETE":
resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict())
except BadRequestException as e:
return self.error("BadRequestException", e.message)
@error_handler
def resource_methods(self, request, full_url, headers): def resource_methods(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -182,11 +179,8 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6] method_type = url_path_parts[6]
if self.method == "GET": if self.method == "GET":
try: method = self.backend.get_method(function_id, resource_id, method_type)
method = self.backend.get_method(function_id, resource_id, method_type) return 200, {}, json.dumps(method)
return 200, {}, json.dumps(method)
except NotFoundException as nfe:
return self.error("NotFoundException", nfe.message)
elif self.method == "PUT": elif self.method == "PUT":
authorization_type = self._get_param("authorizationType") authorization_type = self._get_param("authorizationType")
api_key_required = self._get_param("apiKeyRequired") api_key_required = self._get_param("apiKeyRequired")
@ -308,54 +302,48 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(authorizer_response) return 200, {}, json.dumps(authorizer_response)
@error_handler
def request_validators(self, request, full_url, headers): def request_validators(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
restapi_id = url_path_parts[2] restapi_id = url_path_parts[2]
try:
if self.method == "GET": if self.method == "GET":
validators = self.backend.get_request_validators(restapi_id) validators = self.backend.get_request_validators(restapi_id)
res = json.dumps( res = json.dumps(
{"item": [validator.to_dict() for validator in validators]} {"item": [validator.to_dict() for validator in validators]}
) )
return 200, {}, res return 200, {}, res
if self.method == "POST": if self.method == "POST":
name = self._get_param("name") name = self._get_param("name")
body = self._get_bool_param("validateRequestBody") body = self._get_bool_param("validateRequestBody")
params = self._get_bool_param("validateRequestParameters") params = self._get_bool_param("validateRequestParameters")
validator = self.backend.create_request_validator( validator = self.backend.create_request_validator(
restapi_id, name, body, params restapi_id, name, body, params
) )
return 200, {}, json.dumps(validator) return 200, {}, json.dumps(validator)
except BadRequestException as e:
return self.error("BadRequestException", e.message)
except CrossAccountNotAllowed as e:
return self.error("AccessDeniedException", e.message)
@error_handler
def request_validator_individual(self, request, full_url, headers): def request_validator_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
restapi_id = url_path_parts[2] restapi_id = url_path_parts[2]
validator_id = url_path_parts[4] validator_id = url_path_parts[4]
try:
if self.method == "GET":
validator = self.backend.get_request_validator(restapi_id, validator_id)
return 200, {}, json.dumps(validator)
if self.method == "DELETE":
self.backend.delete_request_validator(restapi_id, validator_id)
return 202, {}, ""
if self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
validator = self.backend.update_request_validator(
restapi_id, validator_id, patch_operations
)
return 200, {}, json.dumps(validator)
except BadRequestException as e:
return self.error("BadRequestException", e.message)
except CrossAccountNotAllowed as e:
return self.error("AccessDeniedException", e.message)
if self.method == "GET":
validator = self.backend.get_request_validator(restapi_id, validator_id)
return 200, {}, json.dumps(validator)
if self.method == "DELETE":
self.backend.delete_request_validator(restapi_id, validator_id)
return 202, {}, ""
if self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
validator = self.backend.update_request_validator(
restapi_id, validator_id, patch_operations
)
return 200, {}, json.dumps(validator)
@error_handler
def authorizers(self, request, full_url, headers): def authorizers(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -363,18 +351,7 @@ class APIGatewayResponse(BaseResponse):
authorizer_id = url_path_parts[4] authorizer_id = url_path_parts[4]
if self.method == "GET": if self.method == "GET":
try: authorizer_response = self.backend.get_authorizer(restapi_id, authorizer_id)
authorizer_response = self.backend.get_authorizer(
restapi_id, authorizer_id
)
except AuthorizerNotFoundException as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "PATCH": elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations") patch_operations = self._get_param("patchOperations")
authorizer_response = self.backend.update_authorizer( authorizer_response = self.backend.update_authorizer(
@ -385,6 +362,7 @@ class APIGatewayResponse(BaseResponse):
return 202, {}, "{}" return 202, {}, "{}"
return 200, {}, json.dumps(authorizer_response) return 200, {}, json.dumps(authorizer_response)
@error_handler
def restapis_stages(self, request, full_url, headers): def restapis_stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -435,28 +413,27 @@ class APIGatewayResponse(BaseResponse):
stage["tags"].pop(tag, None) stage["tags"].pop(tag, None)
return 200, {}, json.dumps({"item": ""}) return 200, {}, json.dumps({"item": ""})
@error_handler
def stages(self, request, full_url, headers): def stages(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
function_id = url_path_parts[2] function_id = url_path_parts[2]
stage_name = url_path_parts[4] stage_name = url_path_parts[4]
try: if self.method == "GET":
if self.method == "GET": stage_response = self.backend.get_stage(function_id, stage_name)
stage_response = self.backend.get_stage(function_id, stage_name)
elif self.method == "PATCH": elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations") patch_operations = self._get_param("patchOperations")
stage_response = self.backend.update_stage( stage_response = self.backend.update_stage(
function_id, stage_name, patch_operations function_id, stage_name, patch_operations
) )
elif self.method == "DELETE": elif self.method == "DELETE":
self.backend.delete_stage(function_id, stage_name) self.backend.delete_stage(function_id, stage_name)
return 202, {}, "{}" return 202, {}, "{}"
return 200, {}, json.dumps(stage_response) return 200, {}, json.dumps(stage_response)
except StageNotFoundException as error:
return error.code, {}, error.get_body()
@error_handler
def integrations(self, request, full_url, headers): def integrations(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -464,50 +441,47 @@ class APIGatewayResponse(BaseResponse):
resource_id = url_path_parts[4] resource_id = url_path_parts[4]
method_type = url_path_parts[6] method_type = url_path_parts[6]
try: integration_response = {}
integration_response = {}
if self.method == "GET": if self.method == "GET":
integration_response = self.backend.get_integration( integration_response = self.backend.get_integration(
function_id, resource_id, method_type function_id, resource_id, method_type
) )
elif self.method == "PUT": elif self.method == "PUT":
integration_type = self._get_param("type") integration_type = self._get_param("type")
uri = self._get_param("uri") uri = self._get_param("uri")
credentials = self._get_param("credentials") credentials = self._get_param("credentials")
request_templates = self._get_param("requestTemplates") request_templates = self._get_param("requestTemplates")
tls_config = self._get_param("tlsConfig") tls_config = self._get_param("tlsConfig")
cache_namespace = self._get_param("cacheNamespace") cache_namespace = self._get_param("cacheNamespace")
self.backend.get_method(function_id, resource_id, method_type) timeout_in_millis = self._get_param("timeoutInMillis")
self.backend.get_method(function_id, resource_id, method_type)
integration_http_method = self._get_param( integration_http_method = self._get_param(
"httpMethod" "httpMethod"
) # default removed because it's a required parameter ) # default removed because it's a required parameter
integration_response = self.backend.put_integration( integration_response = self.backend.put_integration(
function_id, function_id,
resource_id, resource_id,
method_type, method_type,
integration_type, integration_type,
uri, uri,
credentials=credentials, credentials=credentials,
integration_method=integration_http_method, integration_method=integration_http_method,
request_templates=request_templates, request_templates=request_templates,
tls_config=tls_config, tls_config=tls_config,
cache_namespace=cache_namespace, cache_namespace=cache_namespace,
) timeout_in_millis=timeout_in_millis,
elif self.method == "DELETE": )
integration_response = self.backend.delete_integration( elif self.method == "DELETE":
function_id, resource_id, method_type integration_response = self.backend.delete_integration(
) function_id, resource_id, method_type
)
return 200, {}, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
except BadRequestException as e:
return self.error("BadRequestException", e.message)
except CrossAccountNotAllowed as e:
return self.error("AccessDeniedException", e.message)
@error_handler
def integration_responses(self, request, full_url, headers): def integration_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -516,94 +490,69 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6] method_type = url_path_parts[6]
status_code = url_path_parts[9] status_code = url_path_parts[9]
try: if self.method == "GET":
if self.method == "GET": integration_response = self.backend.get_integration_response(
integration_response = self.backend.get_integration_response( function_id, resource_id, method_type, status_code
function_id, resource_id, method_type, status_code )
) elif self.method == "PUT":
elif self.method == "PUT": if not self.body:
if not self.body: raise InvalidRequestInput()
raise InvalidRequestInput()
selection_pattern = self._get_param("selectionPattern") selection_pattern = self._get_param("selectionPattern")
response_templates = self._get_param("responseTemplates") response_templates = self._get_param("responseTemplates")
content_handling = self._get_param("contentHandling") content_handling = self._get_param("contentHandling")
integration_response = self.backend.put_integration_response( integration_response = self.backend.put_integration_response(
function_id, function_id,
resource_id, resource_id,
method_type, method_type,
status_code, status_code,
selection_pattern, selection_pattern,
response_templates, response_templates,
content_handling, content_handling,
) )
elif self.method == "DELETE": elif self.method == "DELETE":
integration_response = self.backend.delete_integration_response( integration_response = self.backend.delete_integration_response(
function_id, resource_id, method_type, status_code function_id, resource_id, method_type, status_code
) )
return 200, {}, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
except BadRequestException as e:
return self.error("BadRequestException", e.message)
except (NoIntegrationDefined, NoIntegrationResponseDefined) as e:
return self.error("NotFoundException", e.message)
@error_handler
def deployments(self, request, full_url, headers): def deployments(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
try: if self.method == "GET":
if self.method == "GET": deployments = self.backend.get_deployments(function_id)
deployments = self.backend.get_deployments(function_id) return 200, {}, json.dumps({"item": deployments})
return 200, {}, json.dumps({"item": deployments}) elif self.method == "POST":
elif self.method == "POST": name = self._get_param("stageName")
name = self._get_param("stageName") description = self._get_param("description")
description = self._get_param("description", if_none="") stage_variables = self._get_param("variables", if_none={})
stage_variables = self._get_param("variables", if_none={}) deployment = self.backend.create_deployment(
deployment = self.backend.create_deployment( function_id, name, description, stage_variables
function_id, name, description, stage_variables )
) return 200, {}, json.dumps(deployment)
return 200, {}, json.dumps(deployment)
except BadRequestException as e:
return self.error("BadRequestException", e.message)
except NotFoundException as e:
return self.error("NotFoundException", e.message)
@error_handler
def individual_deployment(self, request, full_url, headers): def individual_deployment(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
function_id = url_path_parts[2] function_id = url_path_parts[2]
deployment_id = url_path_parts[4] deployment_id = url_path_parts[4]
deployment = None
if self.method == "GET": if self.method == "GET":
deployment = self.backend.get_deployment(function_id, deployment_id) deployment = self.backend.get_deployment(function_id, deployment_id)
return 200, {}, json.dumps(deployment)
elif self.method == "DELETE": elif self.method == "DELETE":
deployment = self.backend.delete_deployment(function_id, deployment_id) deployment = self.backend.delete_deployment(function_id, deployment_id)
return 200, {}, json.dumps(deployment) return 202, {}, json.dumps(deployment)
@error_handler
def apikeys(self, request, full_url, headers): def apikeys(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == "POST": if self.method == "POST":
try: apikey_response = self.backend.create_api_key(json.loads(self.body))
apikey_response = self.backend.create_api_key(json.loads(self.body))
except ApiKeyAlreadyExists as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
except ApiKeyValueMinLength as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
return 201, {}, json.dumps(apikey_response) return 201, {}, json.dumps(apikey_response)
elif self.method == "GET": elif self.method == "GET":
@ -611,6 +560,7 @@ class APIGatewayResponse(BaseResponse):
apikeys_response = self.backend.get_api_keys(include_values=include_values) apikeys_response = self.backend.get_api_keys(include_values=include_values)
return 200, {}, json.dumps({"item": apikeys_response}) return 200, {}, json.dumps({"item": apikeys_response})
@error_handler
def apikey_individual(self, request, full_url, headers): def apikey_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -620,12 +570,9 @@ class APIGatewayResponse(BaseResponse):
status_code = 200 status_code = 200
if self.method == "GET": if self.method == "GET":
include_value = self._get_bool_param("includeValue") include_value = self._get_bool_param("includeValue")
try: apikey_response = self.backend.get_api_key(
apikey_response = self.backend.get_api_key( apikey, include_value=include_value
apikey, include_value=include_value )
)
except ApiKeyNotFoundException as e:
return self.error("NotFoundException", e.message)
elif self.method == "PATCH": elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations") patch_operations = self._get_param("patchOperations")
apikey_response = self.backend.update_api_key(apikey, patch_operations) apikey_response = self.backend.update_api_key(apikey, patch_operations)
@ -645,6 +592,7 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps({"item": usage_plans_response}) return 200, {}, json.dumps({"item": usage_plans_response})
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)
@error_handler
def usage_plan_individual(self, request, full_url, headers): def usage_plan_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -652,16 +600,7 @@ class APIGatewayResponse(BaseResponse):
usage_plan = url_path_parts[2] usage_plan = url_path_parts[2]
if self.method == "GET": if self.method == "GET":
try: usage_plan_response = self.backend.get_usage_plan(usage_plan)
usage_plan_response = self.backend.get_usage_plan(usage_plan)
except (UsagePlanNotFoundException) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "DELETE": elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan(usage_plan) usage_plan_response = self.backend.delete_usage_plan(usage_plan)
elif self.method == "PATCH": elif self.method == "PATCH":
@ -671,6 +610,7 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)
@error_handler
def usage_plan_keys(self, request, full_url, headers): def usage_plan_keys(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -678,23 +618,15 @@ class APIGatewayResponse(BaseResponse):
usage_plan_id = url_path_parts[2] usage_plan_id = url_path_parts[2]
if self.method == "POST": if self.method == "POST":
try: usage_plan_response = self.backend.create_usage_plan_key(
usage_plan_response = self.backend.create_usage_plan_key( usage_plan_id, json.loads(self.body)
usage_plan_id, json.loads(self.body) )
)
except ApiKeyNotFoundException as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
return 201, {}, json.dumps(usage_plan_response) return 201, {}, json.dumps(usage_plan_response)
elif self.method == "GET": elif self.method == "GET":
usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id) usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id)
return 200, {}, json.dumps({"item": usage_plans_response}) return 200, {}, json.dumps({"item": usage_plans_response})
@error_handler
def usage_plan_key_individual(self, request, full_url, headers): def usage_plan_key_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -703,183 +635,133 @@ class APIGatewayResponse(BaseResponse):
key_id = url_path_parts[4] key_id = url_path_parts[4]
if self.method == "GET": if self.method == "GET":
try: usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id)
usage_plan_response = self.backend.get_usage_plan_key(
usage_plan_id, key_id
)
except (UsagePlanNotFoundException, ApiKeyNotFoundException) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "DELETE": elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan_key( usage_plan_response = self.backend.delete_usage_plan_key(
usage_plan_id, key_id usage_plan_id, key_id
) )
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)
@error_handler
def domain_names(self, request, full_url, headers): def domain_names(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
try: if self.method == "GET":
if self.method == "GET": domain_names = self.backend.get_domain_names()
domain_names = self.backend.get_domain_names() return 200, {}, json.dumps({"item": domain_names})
return 200, {}, json.dumps({"item": domain_names})
elif self.method == "POST": elif self.method == "POST":
domain_name = self._get_param("domainName") domain_name = self._get_param("domainName")
certificate_name = self._get_param("certificateName") certificate_name = self._get_param("certificateName")
tags = self._get_param("tags") tags = self._get_param("tags")
certificate_arn = self._get_param("certificateArn") certificate_arn = self._get_param("certificateArn")
certificate_body = self._get_param("certificateBody") certificate_body = self._get_param("certificateBody")
certificate_private_key = self._get_param("certificatePrivateKey") certificate_private_key = self._get_param("certificatePrivateKey")
certificate_chain = self._get_param("certificateChain") certificate_chain = self._get_param("certificateChain")
regional_certificate_name = self._get_param("regionalCertificateName") regional_certificate_name = self._get_param("regionalCertificateName")
regional_certificate_arn = self._get_param("regionalCertificateArn") regional_certificate_arn = self._get_param("regionalCertificateArn")
endpoint_configuration = self._get_param("endpointConfiguration") endpoint_configuration = self._get_param("endpointConfiguration")
security_policy = self._get_param("securityPolicy") security_policy = self._get_param("securityPolicy")
generate_cli_skeleton = self._get_param("generateCliSkeleton") generate_cli_skeleton = self._get_param("generateCliSkeleton")
domain_name_resp = self.backend.create_domain_name( domain_name_resp = self.backend.create_domain_name(
domain_name, domain_name,
certificate_name, certificate_name,
tags, tags,
certificate_arn, certificate_arn,
certificate_body, certificate_body,
certificate_private_key, certificate_private_key,
certificate_chain, certificate_chain,
regional_certificate_name, regional_certificate_name,
regional_certificate_arn, regional_certificate_arn,
endpoint_configuration, endpoint_configuration,
security_policy, security_policy,
generate_cli_skeleton, generate_cli_skeleton,
)
return 200, {}, json.dumps(domain_name_resp)
except InvalidDomainName as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
) )
return 200, {}, json.dumps(domain_name_resp)
@error_handler
def domain_name_induvidual(self, request, full_url, headers): def domain_name_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
domain_name = url_path_parts[2] domain_name = url_path_parts[2]
domain_names = {} domain_names = {}
try:
if self.method == "GET":
if domain_name is not None:
domain_names = self.backend.get_domain_name(domain_name)
elif self.method == "DELETE":
if domain_name is not None:
self.backend.delete_domain_name(domain_name)
elif self.method == "PATCH":
if domain_name is not None:
patch_operations = self._get_param("patchOperations")
self.backend.update_domain_name(domain_name, patch_operations)
else:
msg = (
'Method "%s" for API GW domain names not implemented' % self.method
)
return 404, {}, json.dumps({"error": msg})
return 200, {}, json.dumps(domain_names)
except DomainNameNotFound as error:
return self.error("NotFoundException", error.message)
if self.method == "GET":
if domain_name is not None:
domain_names = self.backend.get_domain_name(domain_name)
elif self.method == "DELETE":
if domain_name is not None:
self.backend.delete_domain_name(domain_name)
elif self.method == "PATCH":
if domain_name is not None:
patch_operations = self._get_param("patchOperations")
self.backend.update_domain_name(domain_name, patch_operations)
else:
msg = 'Method "%s" for API GW domain names not implemented' % self.method
return 404, {}, json.dumps({"error": msg})
return 200, {}, json.dumps(domain_names)
@error_handler
def models(self, request, full_url, headers): def models(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0] rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0]
try: if self.method == "GET":
if self.method == "GET": models = self.backend.get_models(rest_api_id)
models = self.backend.get_models(rest_api_id) return 200, {}, json.dumps({"item": models})
return 200, {}, json.dumps({"item": models})
elif self.method == "POST": elif self.method == "POST":
name = self._get_param("name") name = self._get_param("name")
description = self._get_param("description") description = self._get_param("description")
schema = self._get_param("schema") schema = self._get_param("schema")
content_type = self._get_param("contentType") content_type = self._get_param("contentType")
cli_input_json = self._get_param("cliInputJson") cli_input_json = self._get_param("cliInputJson")
generate_cli_skeleton = self._get_param("generateCliSkeleton") generate_cli_skeleton = self._get_param("generateCliSkeleton")
model = self.backend.create_model( model = self.backend.create_model(
rest_api_id, rest_api_id,
name, name,
content_type, content_type,
description, description,
schema, schema,
cli_input_json, cli_input_json,
generate_cli_skeleton, generate_cli_skeleton,
)
return 200, {}, json.dumps(model)
except (InvalidRestApiId, InvalidModelName, RestAPINotFound) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
) )
return 200, {}, json.dumps(model)
@error_handler
def model_induvidual(self, request, full_url, headers): def model_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
rest_api_id = url_path_parts[2] rest_api_id = url_path_parts[2]
model_name = url_path_parts[4] model_name = url_path_parts[4]
model_info = {} model_info = {}
try: if self.method == "GET":
if self.method == "GET": model_info = self.backend.get_model(rest_api_id, model_name)
model_info = self.backend.get_model(rest_api_id, model_name) return 200, {}, json.dumps(model_info)
return 200, {}, json.dumps(model_info)
except (
ModelNotFound,
RestAPINotFound,
InvalidRestApiId,
InvalidModelName,
) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
@error_handler
def base_path_mappings(self, request, full_url, headers): def base_path_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
domain_name = url_path_parts[2] domain_name = url_path_parts[2]
try: if self.method == "GET":
if self.method == "GET": base_path_mappings = self.backend.get_base_path_mappings(domain_name)
base_path_mappings = self.backend.get_base_path_mappings(domain_name) return 200, {}, json.dumps({"item": base_path_mappings})
return 200, {}, json.dumps({"item": base_path_mappings}) elif self.method == "POST":
elif self.method == "POST": base_path = self._get_param("basePath")
base_path = self._get_param("basePath") rest_api_id = self._get_param("restApiId")
rest_api_id = self._get_param("restApiId") stage = self._get_param("stage")
stage = self._get_param("stage")
base_path_mapping_resp = self.backend.create_base_path_mapping( base_path_mapping_resp = self.backend.create_base_path_mapping(
domain_name, rest_api_id, base_path, stage, domain_name, rest_api_id, base_path, stage,
) )
return 201, {}, json.dumps(base_path_mapping_resp) return 201, {}, json.dumps(base_path_mapping_resp)
except BadRequestException as e:
return self.error("BadRequestException", e.message)
except NotFoundException as e:
return self.error("NotFoundException", e.message, 404)
except ConflictException as e:
return self.error("ConflictException", e.message, 409)
@error_handler
def base_path_mapping_individual(self, request, full_url, headers): def base_path_mapping_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -888,42 +770,33 @@ class APIGatewayResponse(BaseResponse):
domain_name = url_path_parts[2] domain_name = url_path_parts[2]
base_path = unquote(url_path_parts[4]) base_path = unquote(url_path_parts[4])
try: if self.method == "GET":
if self.method == "GET": base_path_mapping = self.backend.get_base_path_mapping(
base_path_mapping = self.backend.get_base_path_mapping( domain_name, base_path
domain_name, base_path )
)
return 200, {}, json.dumps(base_path_mapping)
elif self.method == "DELETE":
self.backend.delete_base_path_mapping(domain_name, base_path)
return 202, {}, ""
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
base_path_mapping = self.backend.update_base_path_mapping(
domain_name, base_path, patch_operations
)
return 200, {}, json.dumps(base_path_mapping) return 200, {}, json.dumps(base_path_mapping)
except NotFoundException as e: elif self.method == "DELETE":
return self.error("NotFoundException", e.message, 404) self.backend.delete_base_path_mapping(domain_name, base_path)
except InvalidRestApiIdForBasePathMappingException as e: return 202, {}, ""
return self.error("BadRequestException", e.message) elif self.method == "PATCH":
except InvalidStageException as e: patch_operations = self._get_param("patchOperations")
return self.error("BadRequestException", e.message) base_path_mapping = self.backend.update_base_path_mapping(
domain_name, base_path, patch_operations
)
return 200, {}, json.dumps(base_path_mapping)
@error_handler
def vpc_link(self, request, full_url, headers): def vpc_link(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
vpc_link_id = url_path_parts[-1] vpc_link_id = url_path_parts[-1]
try: if self.method == "DELETE":
if self.method == "DELETE": self.backend.delete_vpc_link(vpc_link_id=vpc_link_id)
self.backend.delete_vpc_link(vpc_link_id=vpc_link_id) return 200, {}, "{}"
return 200, {}, "{}" if self.method == "GET":
if self.method == "GET": vpc_link = self.backend.get_vpc_link(vpc_link_id=vpc_link_id)
vpc_link = self.backend.get_vpc_link(vpc_link_id=vpc_link_id) return 200, {}, json.dumps(vpc_link)
return 200, {}, json.dumps(vpc_link)
except NotFoundException as e:
return self.error("NotFoundException", e.message, 404)
def vpc_links(self, request, full_url, headers): def vpc_links(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -940,3 +813,40 @@ class APIGatewayResponse(BaseResponse):
name=name, description=description, target_arns=target_arns, tags=tags name=name, description=description, target_arns=target_arns, tags=tags
) )
return 200, {}, json.dumps(vpc_link) return 200, {}, json.dumps(vpc_link)
def put_gateway_response(self):
rest_api_id = self.path.split("/")[-3]
response_type = self.path.split("/")[-1]
params = json.loads(self.body)
status_code = params.get("statusCode")
response_parameters = params.get("responseParameters")
response_templates = params.get("responseTemplates")
response = self.backend.put_gateway_response(
rest_api_id=rest_api_id,
response_type=response_type,
status_code=status_code,
response_parameters=response_parameters,
response_templates=response_templates,
)
return 200, {}, json.dumps(response)
def get_gateway_response(self):
rest_api_id = self.path.split("/")[-3]
response_type = self.path.split("/")[-1]
response = self.backend.get_gateway_response(
rest_api_id=rest_api_id, response_type=response_type,
)
return 200, {}, json.dumps(response)
def get_gateway_responses(self):
rest_api_id = self.path.split("/")[-2]
responses = self.backend.get_gateway_responses(rest_api_id=rest_api_id,)
return 200, {}, json.dumps(dict(item=responses))
def delete_gateway_response(self):
rest_api_id = self.path.split("/")[-3]
response_type = self.path.split("/")[-1]
self.backend.delete_gateway_response(
rest_api_id=rest_api_id, response_type=response_type,
)
return 202, {}, json.dumps(dict())

View File

@ -35,6 +35,8 @@ url_paths = {
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$": response.usage_plan_key_individual, "{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$": response.usage_plan_key_individual,
"{0}/restapis/(?P<function_id>[^/]+)/requestvalidators$": response.request_validators, "{0}/restapis/(?P<function_id>[^/]+)/requestvalidators$": response.request_validators,
"{0}/restapis/(?P<api_id>[^/]+)/requestvalidators/(?P<validator_id>[^/]+)/?$": response.request_validator_individual, "{0}/restapis/(?P<api_id>[^/]+)/requestvalidators/(?P<validator_id>[^/]+)/?$": response.request_validator_individual,
"{0}/restapis/(?P<api_id>[^/]+)/gatewayresponses/?$": response.gateway_responses,
"{0}/restapis/(?P<api_id>[^/]+)/gatewayresponses/(?P<response_type>[^/]+)/?$": response.gateway_response,
"{0}/vpclinks$": response.vpc_links, "{0}/vpclinks$": response.vpc_links,
"{0}/vpclinks/(?P<vpclink_id>[^/]+)": response.vpc_link, "{0}/vpclinks/(?P<vpclink_id>[^/]+)": response.vpc_link,
} }

View File

@ -151,7 +151,7 @@ class DeliveryStream(
del self.destinations[0][destination_name]["S3Configuration"] del self.destinations[0][destination_name]["S3Configuration"]
self.delivery_stream_status = "ACTIVE" self.delivery_stream_status = "ACTIVE"
self.delivery_stream_arn = f"arn:aws:firehose:{region}:{ACCOUNT_ID}:/delivery_stream/{delivery_stream_name}" self.delivery_stream_arn = f"arn:aws:firehose:{region}:{ACCOUNT_ID}:deliverystream/{delivery_stream_name}"
self.create_timestamp = datetime.now(timezone.utc).isoformat() self.create_timestamp = datetime.now(timezone.utc).isoformat()
self.version_id = "1" # Used to track updates of destination configs self.version_id = "1" # Used to track updates of destination configs

View File

@ -3,7 +3,6 @@ TestAccAWSEc2TransitGatewayPeeringAttachmentAccepter
TestAccAWSEc2TransitGatewayRouteTableAssociation TestAccAWSEc2TransitGatewayRouteTableAssociation
TestAccAWSEc2TransitGatewayVpcAttachment TestAccAWSEc2TransitGatewayVpcAttachment
TestAccAWSFms TestAccAWSFms
TestAccAWSIAMRolePolicy
TestAccAWSSecurityGroup_forceRevokeRules_ TestAccAWSSecurityGroup_forceRevokeRules_
TestAccAWSDefaultSecurityGroup_Classic_ TestAccAWSDefaultSecurityGroup_Classic_
TestAccDataSourceAwsNetworkInterface_CarrierIPAssociation TestAccDataSourceAwsNetworkInterface_CarrierIPAssociation

View File

@ -94,9 +94,8 @@ TestAccAWSUserGroupMembership
TestAccAWSUserPolicyAttachment TestAccAWSUserPolicyAttachment
TestAccAWSUserSSHKey TestAccAWSUserSSHKey
TestAccAWSVpc_ TestAccAWSVpc_
TestAccAWSAPIGatewayStage_basic TestAccAWSAPIGatewayGatewayResponse
TestAccAWSAPIGatewayStage_accessLogSettings_kinesis TestAccAWSAPIGatewayStage
TestAccAWSAPIGatewayStage_accessLogSettings
TestAccAWSSsmDocumentDataSource TestAccAWSSsmDocumentDataSource
TestAccAwsEc2ManagedPrefixList TestAccAwsEc2ManagedPrefixList
TestAccAWSEgressOnlyInternetGateway TestAccAWSEgressOnlyInternetGateway

View File

@ -5,7 +5,7 @@ from freezegun import freeze_time
import sure # noqa # pylint: disable=unused-import import sure # noqa # pylint: disable=unused-import
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto import mock_apigateway, mock_cognitoidp, settings from moto import mock_apigateway, mock_cognitoidp
from moto.core import ACCOUNT_ID from moto.core import ACCOUNT_ID
import pytest import pytest
@ -576,17 +576,15 @@ def test_integrations():
uri=test_uri, uri=test_uri,
requestTemplates=templates, requestTemplates=templates,
integrationHttpMethod="POST", integrationHttpMethod="POST",
timeoutInMillis=29000,
) )
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response["ResponseMetadata"].should.equal({"HTTPStatusCode": 200})
response = client.get_integration( response = client.get_integration(
restApiId=api_id, resourceId=root_id, httpMethod="POST" restApiId=api_id, resourceId=root_id, httpMethod="POST"
) )
response["uri"].should.equal(test_uri) response["uri"].should.equal(test_uri)
response["requestTemplates"].should.equal(templates) response["requestTemplates"].should.equal(templates)
response.should.have.key("timeoutInMillis").equals(29000)
@mock_apigateway @mock_apigateway
@ -980,354 +978,6 @@ def test_delete_authorizer():
) )
@mock_apigateway
def test_update_stage_configuration():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(
restApiId=api_id, stageName=stage_name, description="1.0.1"
)
deployment_id = response["id"]
response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id)
# createdDate is hard to match against, remove it
response.pop("createdDate", None)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"id": deployment_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
"description": "1.0.1",
}
)
response = client.create_deployment(
restApiId=api_id, stageName=stage_name, description="1.0.2"
)
deployment_id2 = response["id"]
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage["stageName"].should.equal(stage_name)
stage["deploymentId"].should.equal(deployment_id2)
stage.shouldnt.have.key("cacheClusterSize")
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "replace", "path": "/cacheClusterEnabled", "value": "True"}
],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage.should.have.key("cacheClusterSize").which.should.equal("0.5")
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "replace", "path": "/cacheClusterSize", "value": "1.6"}
],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage.should.have.key("cacheClusterSize").which.should.equal("1.6")
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "replace", "path": "/deploymentId", "value": deployment_id},
{"op": "replace", "path": "/variables/environment", "value": "dev"},
{"op": "replace", "path": "/variables/region", "value": "eu-west-1"},
{"op": "replace", "path": "/*/*/caching/dataEncrypted", "value": "True"},
{"op": "replace", "path": "/cacheClusterEnabled", "value": "True"},
{
"op": "replace",
"path": "/description",
"value": "stage description update",
},
{"op": "replace", "path": "/cacheClusterSize", "value": "1.6"},
],
)
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "remove", "path": "/variables/region", "value": "eu-west-1"}
],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage["description"].should.match("stage description update")
stage["cacheClusterSize"].should.equal("1.6")
stage["variables"]["environment"].should.match("dev")
stage["variables"].should_not.have.key("region")
stage["cacheClusterEnabled"].should.be.true
stage["deploymentId"].should.match(deployment_id)
stage["methodSettings"].should.have.key("*/*")
stage["methodSettings"]["*/*"].should.have.key(
"cacheDataEncrypted"
).which.should.be.true
try:
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "add", "path": "/notasetting", "value": "eu-west-1"}
],
)
assert False.should.be.ok # Fail, should not be here
except Exception:
assert True.should.be.ok
@mock_apigateway
def test_non_existent_stage():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
client.get_stage.when.called_with(restApiId=api_id, stageName="xxx").should.throw(
ClientError
)
@mock_apigateway
def test_create_stage():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id)
# createdDate is hard to match against, remove it
response.pop("createdDate", None)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"id": deployment_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
"description": "",
}
)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id2 = response["id"]
response = client.get_deployments(restApiId=api_id)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response["items"][0].pop("createdDate")
response["items"][1].pop("createdDate")
response["items"][0]["id"].should.match(
r"{0}|{1}".format(deployment_id2, deployment_id)
)
response["items"][1]["id"].should.match(
r"{0}|{1}".format(deployment_id2, deployment_id)
)
new_stage_name = "current"
response = client.create_stage(
restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id2
)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"stageName": new_stage_name,
"deploymentId": deployment_id2,
"methodSettings": {},
"variables": {},
"ResponseMetadata": {"HTTPStatusCode": 200},
"description": "",
"cacheClusterEnabled": False,
}
)
stage = client.get_stage(restApiId=api_id, stageName=new_stage_name)
stage["stageName"].should.equal(new_stage_name)
stage["deploymentId"].should.equal(deployment_id2)
new_stage_name_with_vars = "stage_with_vars"
response = client.create_stage(
restApiId=api_id,
stageName=new_stage_name_with_vars,
deploymentId=deployment_id2,
variables={"env": "dev"},
)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"stageName": new_stage_name_with_vars,
"deploymentId": deployment_id2,
"methodSettings": {},
"variables": {"env": "dev"},
"ResponseMetadata": {"HTTPStatusCode": 200},
"description": "",
"cacheClusterEnabled": False,
}
)
stage = client.get_stage(restApiId=api_id, stageName=new_stage_name_with_vars)
stage["stageName"].should.equal(new_stage_name_with_vars)
stage["deploymentId"].should.equal(deployment_id2)
stage["variables"].should.have.key("env").which.should.match("dev")
new_stage_name = "stage_with_vars_and_cache_settings"
response = client.create_stage(
restApiId=api_id,
stageName=new_stage_name,
deploymentId=deployment_id2,
variables={"env": "dev"},
cacheClusterEnabled=True,
description="hello moto",
)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"stageName": new_stage_name,
"deploymentId": deployment_id2,
"methodSettings": {},
"variables": {"env": "dev"},
"ResponseMetadata": {"HTTPStatusCode": 200},
"description": "hello moto",
"cacheClusterEnabled": True,
"cacheClusterSize": "0.5",
}
)
stage = client.get_stage(restApiId=api_id, stageName=new_stage_name)
stage["cacheClusterSize"].should.equal("0.5")
new_stage_name = "stage_with_vars_and_cache_settings_and_size"
response = client.create_stage(
restApiId=api_id,
stageName=new_stage_name,
deploymentId=deployment_id2,
variables={"env": "dev"},
cacheClusterEnabled=True,
cacheClusterSize="1.6",
description="hello moto",
)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"stageName": new_stage_name,
"deploymentId": deployment_id2,
"methodSettings": {},
"variables": {"env": "dev"},
"ResponseMetadata": {"HTTPStatusCode": 200},
"description": "hello moto",
"cacheClusterEnabled": True,
"cacheClusterSize": "1.6",
}
)
stage = client.get_stage(restApiId=api_id, stageName=new_stage_name)
stage["stageName"].should.equal(new_stage_name)
stage["deploymentId"].should.equal(deployment_id2)
stage["variables"].should.have.key("env").which.should.match("dev")
stage["cacheClusterSize"].should.equal("1.6")
@mock_apigateway
def test_create_deployment_requires_REST_methods():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
with pytest.raises(ClientError) as ex:
client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
ex.value.response["Error"]["Code"].should.equal("BadRequestException")
ex.value.response["Error"]["Message"].should.equal(
"The REST API doesn't contain any methods"
)
@mock_apigateway
def test_create_deployment_requires_REST_method_integrations():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
resources = client.get_resources(restApiId=api_id)
root_id = [resource for resource in resources["items"] if resource["path"] == "/"][
0
]["id"]
client.put_method(
restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="NONE"
)
with pytest.raises(ClientError) as ex:
client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
ex.value.response["Error"]["Message"].should.equal(
"No integration defined for method"
)
@mock_apigateway
def test_create_simple_deployment_with_get_method():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
deployment = client.create_deployment(restApiId=api_id, stageName=stage_name)
assert "id" in deployment
@mock_apigateway
def test_create_simple_deployment_with_post_method():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id, httpMethod="POST")
deployment = client.create_deployment(restApiId=api_id, stageName=stage_name)
assert "id" in deployment
@mock_apigateway @mock_apigateway
def test_put_integration_response_with_response_template(): def test_put_integration_response_with_response_template():
client = boto3.client("apigateway", region_name="us-west-2") client = boto3.client("apigateway", region_name="us-west-2")
@ -1553,101 +1203,6 @@ def test_put_integration_validation():
) )
@mock_apigateway
def test_delete_stage():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
deployment_id1 = client.create_deployment(restApiId=api_id, stageName=stage_name)[
"id"
]
deployment_id2 = client.create_deployment(restApiId=api_id, stageName=stage_name)[
"id"
]
new_stage_name = "current"
client.create_stage(
restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id1
)
new_stage_name_with_vars = "stage_with_vars"
client.create_stage(
restApiId=api_id,
stageName=new_stage_name_with_vars,
deploymentId=deployment_id2,
variables={"env": "dev"},
)
stages = client.get_stages(restApiId=api_id)["item"]
sorted([stage["stageName"] for stage in stages]).should.equal(
sorted([new_stage_name, new_stage_name_with_vars, stage_name])
)
# delete stage
response = client.delete_stage(restApiId=api_id, stageName=new_stage_name_with_vars)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(202)
# verify other stage still exists
stages = client.get_stages(restApiId=api_id)["item"]
sorted([stage["stageName"] for stage in stages]).should.equal(
sorted([new_stage_name, stage_name])
)
@mock_apigateway
def test_deployment():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id)
# createdDate is hard to match against, remove it
response.pop("createdDate", None)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"id": deployment_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
"description": "",
}
)
response = client.get_deployments(restApiId=api_id)
response["items"][0].pop("createdDate")
response["items"].should.equal([{"id": deployment_id, "description": ""}])
client.delete_deployment(restApiId=api_id, deploymentId=deployment_id)
response = client.get_deployments(restApiId=api_id)
len(response["items"]).should.equal(0)
# test deployment stages
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage["stageName"].should.equal(stage_name)
stage["deploymentId"].should.equal(deployment_id)
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "replace", "path": "/description", "value": "_new_description_"}
],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage["stageName"].should.equal(stage_name)
stage["deploymentId"].should.equal(deployment_id)
stage["description"].should.equal("_new_description_")
@mock_apigateway @mock_apigateway
def test_create_domain_names(): def test_create_domain_names():
client = boto3.client("apigateway", region_name="us-west-2") client = boto3.client("apigateway", region_name="us-west-2")
@ -1908,11 +1463,9 @@ def test_create_api_key():
response = client.get_api_keys() response = client.get_api_keys()
len(response["items"]).should.equal(1) len(response["items"]).should.equal(1)
client.create_api_key.when.called_with(**payload).should.throw(ClientError)
@mock_apigateway @mock_apigateway
def test_create_api_headers(): def test_create_api_key_twice():
region_name = "us-west-2" region_name = "us-west-2"
client = boto3.client("apigateway", region_name=region_name) client = boto3.client("apigateway", region_name=region_name)
@ -1924,8 +1477,6 @@ def test_create_api_headers():
with pytest.raises(ClientError) as ex: with pytest.raises(ClientError) as ex:
client.create_api_key(**payload) client.create_api_key(**payload)
ex.value.response["Error"]["Code"].should.equal("ConflictException") ex.value.response["Error"]["Code"].should.equal("ConflictException")
if not settings.TEST_SERVER_MODE:
ex.value.response["ResponseMetadata"]["HTTPHeaders"].should.equal({})
@mock_apigateway @mock_apigateway
@ -2242,20 +1793,6 @@ def test_get_integration_response_unknown_response():
ex.value.response["Error"]["Code"].should.equal("NotFoundException") ex.value.response["Error"]["Code"].should.equal("NotFoundException")
@mock_apigateway
def test_delete_stage_unknown_stage():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
with pytest.raises(ClientError) as ex:
client.delete_stage(restApiId=api_id, stageName="unknown")
ex.value.response["Error"]["Message"].should.equal(
"Invalid stage identifier specified"
)
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
@mock_apigateway @mock_apigateway
def test_get_api_key_unknown_apikey(): def test_get_api_key_unknown_apikey():
client = boto3.client("apigateway", region_name="us-east-1") client = boto3.client("apigateway", region_name="us-east-1")

View File

@ -0,0 +1,215 @@
import boto3
import sure # noqa # pylint: disable=unused-import
from botocore.exceptions import ClientError
from moto import mock_apigateway
import pytest
from .test_apigateway import create_method_integration
@mock_apigateway
def test_create_deployment_requires_REST_methods():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
with pytest.raises(ClientError) as ex:
client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
ex.value.response["Error"]["Code"].should.equal("BadRequestException")
ex.value.response["Error"]["Message"].should.equal(
"The REST API doesn't contain any methods"
)
@mock_apigateway
def test_create_deployment_requires_REST_method_integrations():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
resources = client.get_resources(restApiId=api_id)
root_id = [r for r in resources["items"] if r["path"] == "/"][0]["id"]
client.put_method(
restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="NONE"
)
with pytest.raises(ClientError) as ex:
client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
ex.value.response["Error"]["Message"].should.equal(
"No integration defined for method"
)
@mock_apigateway
def test_create_simple_deployment_with_get_method():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
deployment = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment.should.have.key("id")
@mock_apigateway
def test_create_simple_deployment_with_post_method():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id, httpMethod="POST")
deployment = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment.should.have.key("id")
@mock_apigateway
def test_create_deployment_minimal():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id)
response.should.have.key("id").equals(deployment_id)
response.should.have.key("ResponseMetadata").should.have.key(
"HTTPStatusCode"
).equals(200)
@mock_apigateway
def test_create_deployment_with_empty_stage():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName="")
deployment_id = response["id"]
response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id)
response.should.have.key("id")
response.should.have.key("createdDate")
response.shouldnt.have.key("stageName")
# This should not create an empty stage
stages = client.get_stages(restApiId=api_id)["item"]
stages.should.equal([])
@mock_apigateway
def test_get_deployments():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
response = client.get_deployments(restApiId=api_id)
response.should.have.key("items").length_of(1)
response["items"][0].pop("createdDate")
response["items"].should.equal([{"id": deployment_id}])
@mock_apigateway
def test_create_multiple_deployments():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id2 = response["id"]
response = client.get_deployments(restApiId=api_id)
response["items"][0]["id"].should.match(
r"{0}|{1}".format(deployment_id2, deployment_id)
)
response["items"][1]["id"].should.match(
r"{0}|{1}".format(deployment_id2, deployment_id)
)
@mock_apigateway
def test_delete_deployment__requires_stage_to_be_deleted():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
# Can't delete deployment immediately
with pytest.raises(ClientError) as exc:
client.delete_deployment(restApiId=api_id, deploymentId=deployment_id)
err = exc.value.response["Error"]
err["Code"].should.equal("BadRequestException")
err["Message"].should.equal(
"Active stages pointing to this deployment must be moved or deleted"
)
# Deployment still exists
deployments = client.get_deployments(restApiId=api_id)["items"]
deployments.should.have.length_of(1)
# Stage still exists
stages = client.get_stages(restApiId=api_id)["item"]
stages.should.have.length_of(1)
# Delete stage first
resp = client.delete_stage(restApiId=api_id, stageName=stage_name)
resp["ResponseMetadata"].should.have.key("HTTPStatusCode").equals(202)
# Deployment still exists
deployments = client.get_deployments(restApiId=api_id)["items"]
print(deployments)
deployments.should.have.length_of(1)
# Now delete deployment
resp = client.delete_deployment(restApiId=api_id, deploymentId=deployment_id)
resp["ResponseMetadata"].should.have.key("HTTPStatusCode").equals(202)
# Deployment is gone
deployments = client.get_deployments(restApiId=api_id)["items"]
deployments.should.have.length_of(0)
# Stage is gone
stages = client.get_stages(restApiId=api_id)["item"]
stages.should.have.length_of(0)
@mock_apigateway
def test_delete_unknown_deployment():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
with pytest.raises(ClientError) as exc:
client.delete_deployment(restApiId=api_id, deploymentId="unknown")
err = exc.value.response["Error"]
err["Code"].should.equal("NotFoundException")
err["Message"].should.equal("Invalid Deployment identifier specified")

View File

@ -0,0 +1,144 @@
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_apigateway
@mock_apigateway
def test_put_gateway_response_minimal():
client = boto3.client("apigateway", region_name="us-east-2")
api_id = client.create_rest_api(name="my_api", description="d")["id"]
resp = client.put_gateway_response(restApiId=api_id, responseType="DEFAULT_4XX")
resp.should.have.key("responseType").equals("DEFAULT_4XX")
resp.should.have.key("defaultResponse").equals(False)
@mock_apigateway
def test_put_gateway_response():
client = boto3.client("apigateway", region_name="us-east-2")
api_id = client.create_rest_api(name="my_api", description="d")["id"]
resp = client.put_gateway_response(
restApiId=api_id,
responseType="DEFAULT_4XX",
statusCode="401",
responseParameters={"gatewayresponse.header.Authorization": "'Basic'"},
responseTemplates={
"application/xml": "#set($inputRoot = $input.path('$'))\n{ }"
},
)
resp.should.have.key("responseType").equals("DEFAULT_4XX")
resp.should.have.key("defaultResponse").equals(False)
resp.should.have.key("statusCode").equals("401")
resp.should.have.key("responseParameters").equals(
{"gatewayresponse.header.Authorization": "'Basic'"}
)
resp.should.have.key("responseTemplates").equals(
{"application/xml": "#set($inputRoot = $input.path('$'))\n{ }"}
)
@mock_apigateway
def test_get_gateway_response_minimal():
client = boto3.client("apigateway", region_name="ap-southeast-1")
api_id = client.create_rest_api(name="my_api", description="d")["id"]
client.put_gateway_response(restApiId=api_id, responseType="DEFAULT_4XX")
resp = client.get_gateway_response(restApiId=api_id, responseType="DEFAULT_4XX")
resp.should.have.key("responseType").equals("DEFAULT_4XX")
resp.should.have.key("defaultResponse").equals(False)
@mock_apigateway
def test_get_gateway_response():
client = boto3.client("apigateway", region_name="us-east-2")
api_id = client.create_rest_api(name="my_api", description="d")["id"]
client.put_gateway_response(
restApiId=api_id,
responseType="DEFAULT_4XX",
statusCode="401",
responseParameters={"gatewayresponse.header.Authorization": "'Basic'"},
responseTemplates={
"application/xml": "#set($inputRoot = $input.path('$'))\n{ }"
},
)
resp = client.get_gateway_response(restApiId=api_id, responseType="DEFAULT_4XX")
resp.should.have.key("responseType").equals("DEFAULT_4XX")
resp.should.have.key("defaultResponse").equals(False)
resp.should.have.key("statusCode").equals("401")
resp.should.have.key("responseParameters").equals(
{"gatewayresponse.header.Authorization": "'Basic'"}
)
resp.should.have.key("responseTemplates").equals(
{"application/xml": "#set($inputRoot = $input.path('$'))\n{ }"}
)
@mock_apigateway
def test_get_gateway_response_unknown():
client = boto3.client("apigateway", region_name="us-east-2")
api_id = client.create_rest_api(name="my_api", description="d")["id"]
with pytest.raises(ClientError) as exc:
client.get_gateway_response(restApiId=api_id, responseType="DEFAULT_4XX")
err = exc.value.response["Error"]
err["Code"].should.equal("NotFoundException")
@mock_apigateway
def test_get_gateway_responses_empty():
client = boto3.client("apigateway", region_name="ap-southeast-1")
api_id = client.create_rest_api(name="my_api", description="d")["id"]
resp = client.get_gateway_responses(restApiId=api_id)
resp.should.have.key("items").equals([])
@mock_apigateway
def test_get_gateway_responses():
client = boto3.client("apigateway", region_name="ap-southeast-1")
api_id = client.create_rest_api(name="my_api", description="d")["id"]
client.put_gateway_response(restApiId=api_id, responseType="DEFAULT_4XX")
client.put_gateway_response(
restApiId=api_id, responseType="DEFAULT_5XX", statusCode="503"
)
resp = client.get_gateway_responses(restApiId=api_id)
resp.should.have.key("items").length_of(2)
resp["items"].should.contain(
{"responseType": "DEFAULT_4XX", "defaultResponse": False}
)
resp["items"].should.contain(
{"responseType": "DEFAULT_5XX", "defaultResponse": False, "statusCode": "503"}
)
@mock_apigateway
def test_delete_gateway_response():
client = boto3.client("apigateway", region_name="ap-southeast-1")
api_id = client.create_rest_api(name="my_api", description="d")["id"]
client.put_gateway_response(restApiId=api_id, responseType="DEFAULT_4XX")
client.put_gateway_response(
restApiId=api_id, responseType="DEFAULT_5XX", statusCode="503"
)
resp = client.get_gateway_responses(restApiId=api_id)
resp.should.have.key("items").length_of(2)
resp = client.delete_gateway_response(restApiId=api_id, responseType="DEFAULT_5XX")
resp = client.get_gateway_responses(restApiId=api_id)
resp.should.have.key("items").length_of(1)
resp["items"].should.contain(
{"responseType": "DEFAULT_4XX", "defaultResponse": False}
)

View File

@ -0,0 +1,509 @@
import boto3
import pytest
import sure # noqa # pylint: disable=unused-import
from botocore.exceptions import ClientError
from moto import mock_apigateway
from .test_apigateway import create_method_integration
@mock_apigateway
def test_create_stage_minimal():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
new_stage_name = "current"
response = client.create_stage(
restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id
)
response.should.have.key("stageName").equals(new_stage_name)
response.should.have.key("deploymentId").equals(deployment_id)
response.should.have.key("methodSettings").equals({})
response.should.have.key("variables").equals({})
response.should.have.key("ResponseMetadata").should.have.key(
"HTTPStatusCode"
).equals(200)
response.should.have.key("description").equals("")
response.shouldnt.have.key("cacheClusterStatus")
response.should.have.key("cacheClusterEnabled").equals(False)
stage = client.get_stage(restApiId=api_id, stageName=new_stage_name)
stage["stageName"].should.equal(new_stage_name)
stage["deploymentId"].should.equal(deployment_id)
@mock_apigateway
def test_create_stage_with_env_vars():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
new_stage_name_with_vars = "stage_with_vars"
response = client.create_stage(
restApiId=api_id,
stageName=new_stage_name_with_vars,
deploymentId=deployment_id,
variables={"env": "dev"},
)
response.should.have.key("stageName").equals(new_stage_name_with_vars)
response.should.have.key("deploymentId").equals(deployment_id)
response.should.have.key("methodSettings").equals({})
response.should.have.key("variables").equals({"env": "dev"})
response.should.have.key("ResponseMetadata").should.have.key(
"HTTPStatusCode"
).equals(200)
response.should.have.key("description").equals("")
response.shouldnt.have.key("cacheClusterStatus")
response.should.have.key("cacheClusterEnabled").equals(False)
stage = client.get_stage(restApiId=api_id, stageName=new_stage_name_with_vars)
stage["stageName"].should.equal(new_stage_name_with_vars)
stage["deploymentId"].should.equal(deployment_id)
stage["variables"].should.have.key("env").which.should.match("dev")
@mock_apigateway
def test_create_stage_with_vars_and_cache():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
new_stage_name = "stage_with_vars_and_cache_settings"
response = client.create_stage(
restApiId=api_id,
stageName=new_stage_name,
deploymentId=deployment_id,
variables={"env": "dev"},
cacheClusterEnabled=True,
description="hello moto",
)
response.should.have.key("stageName").equals(new_stage_name)
response.should.have.key("deploymentId").equals(deployment_id)
response.should.have.key("methodSettings").equals({})
response.should.have.key("variables").equals({"env": "dev"})
response.should.have.key("ResponseMetadata").should.have.key(
"HTTPStatusCode"
).equals(200)
response.should.have.key("description").equals("hello moto")
response.should.have.key("cacheClusterStatus").equals("AVAILABLE")
response.should.have.key("cacheClusterEnabled").equals(True)
response.should.have.key("cacheClusterSize").equals("0.5")
stage = client.get_stage(restApiId=api_id, stageName=new_stage_name)
stage["cacheClusterSize"].should.equal("0.5")
@mock_apigateway
def test_create_stage_with_cache_settings():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
deployment_id = response["id"]
new_stage_name = "stage_with_vars_and_cache_settings_and_size"
response = client.create_stage(
restApiId=api_id,
stageName=new_stage_name,
deploymentId=deployment_id,
variables={"env": "dev"},
cacheClusterEnabled=True,
cacheClusterSize="1.6",
tracingEnabled=True,
description="hello moto",
)
response.should.have.key("stageName").equals(new_stage_name)
response.should.have.key("deploymentId").equals(deployment_id)
response.should.have.key("methodSettings").equals({})
response.should.have.key("variables").equals({"env": "dev"})
response.should.have.key("ResponseMetadata").should.have.key(
"HTTPStatusCode"
).equals(200)
response.should.have.key("description").equals("hello moto")
response.should.have.key("cacheClusterStatus").equals("AVAILABLE")
response.should.have.key("cacheClusterEnabled").equals(True)
response.should.have.key("cacheClusterSize").equals("1.6")
response.should.have.key("tracingEnabled").equals(True)
stage = client.get_stage(restApiId=api_id, stageName=new_stage_name)
stage["stageName"].should.equal(new_stage_name)
stage["deploymentId"].should.equal(deployment_id)
stage["variables"].should.have.key("env").which.should.match("dev")
stage["cacheClusterSize"].should.equal("1.6")
@mock_apigateway
def test_recreate_stage_from_deployment():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
depl_id1 = client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
with pytest.raises(ClientError) as exc:
client.create_stage(
restApiId=api_id, stageName=stage_name, deploymentId=depl_id1
)
err = exc.value.response["Error"]
err["Code"].should.equal("ConflictException")
err["Message"].should.equal("Stage already exists")
@mock_apigateway
def test_create_stage_twice():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
depl_id1 = client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
new_stage_name = "current"
client.create_stage(
restApiId=api_id, stageName=new_stage_name, deploymentId=depl_id1
)
with pytest.raises(ClientError) as exc:
client.create_stage(
restApiId=api_id, stageName=new_stage_name, deploymentId=depl_id1
)
err = exc.value.response["Error"]
err["Code"].should.equal("ConflictException")
err["Message"].should.equal("Stage already exists")
@mock_apigateway
def test_delete_stage():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
depl_id1 = client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
depl_id2 = client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
new_stage_name = "current"
client.create_stage(
restApiId=api_id, stageName=new_stage_name, deploymentId=depl_id1
)
new_stage_name_with_vars = "stage_with_vars"
client.create_stage(
restApiId=api_id,
stageName=new_stage_name_with_vars,
deploymentId=depl_id2,
variables={"env": "dev"},
)
stages = client.get_stages(restApiId=api_id)["item"]
stage_names = [stage["stageName"] for stage in stages]
stage_names.should.have.length_of(3)
stage_names.should.contain(stage_name)
stage_names.should.contain(new_stage_name)
stage_names.should.contain(new_stage_name_with_vars)
# delete stage
response = client.delete_stage(restApiId=api_id, stageName=new_stage_name_with_vars)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(202)
# verify other stage still exists
stages = client.get_stages(restApiId=api_id)["item"]
stage_names = [stage["stageName"] for stage in stages]
stage_names.should.have.length_of(2)
stage_names.shouldnt.contain(new_stage_name_with_vars)
@mock_apigateway
def test_delete_stage_created_by_deployment():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
depl_id1 = client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
# Sanity check that the deployment exists
depls = client.get_deployments(restApiId=api_id)["items"]
depls.should.have.length_of(1)
set(depls[0].keys()).should.equal({"id", "createdDate"})
# Sanity check that the stage exists
stage = client.get_stages(restApiId=api_id)["item"][0]
stage.should.have.key("deploymentId").equals(depl_id1)
stage.should.have.key("stageName").equals(stage_name)
# delete stage
response = client.delete_stage(restApiId=api_id, stageName=stage_name)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(202)
# verify no stage exists
stages = client.get_stages(restApiId=api_id)
stages.should.have.key("item").equals([])
# verify deployment still exists, unchanged
depls = client.get_deployments(restApiId=api_id)["items"]
depls.should.have.length_of(1)
set(depls[0].keys()).should.equal({"id", "createdDate"})
@mock_apigateway
def test_delete_stage_unknown_stage():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
with pytest.raises(ClientError) as exc:
client.delete_stage(restApiId=api_id, stageName="unknown")
err = exc.value.response["Error"]
err["Message"].should.equal("Invalid stage identifier specified")
err["Code"].should.equal("NotFoundException")
@mock_apigateway
def test_update_stage_configuration():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(
restApiId=api_id, stageName=stage_name, description="1.0.1"
)
deployment_id = response["id"]
response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id)
response.should.have.key("id").equals(deployment_id)
response.should.have.key("ResponseMetadata").should.have.key(
"HTTPStatusCode"
).equals(200)
response.should.have.key("description").equals("1.0.1")
response = client.create_deployment(
restApiId=api_id, stageName=stage_name, description="1.0.2"
)
deployment_id2 = response["id"]
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage["stageName"].should.equal(stage_name)
stage["deploymentId"].should.equal(deployment_id2)
stage.shouldnt.have.key("cacheClusterSize")
stage.shouldnt.have.key("cacheClusterStatus")
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "replace", "path": "/cacheClusterEnabled", "value": "True"}
],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage.should.have.key("cacheClusterSize").which.should.equal("0.5")
stage.should.have.key("cacheClusterStatus").equals("AVAILABLE")
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "replace", "path": "/cacheClusterSize", "value": "1.6"}
],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage.should.have.key("cacheClusterSize").which.should.equal("1.6")
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "replace", "path": "/deploymentId", "value": deployment_id},
{"op": "replace", "path": "/variables/environment", "value": "dev"},
{"op": "replace", "path": "/variables/region", "value": "eu-west-1"},
{"op": "replace", "path": "/*/*/caching/dataEncrypted", "value": "True"},
{"op": "replace", "path": "/cacheClusterEnabled", "value": "True"},
{
"op": "replace",
"path": "/description",
"value": "stage description update",
},
{"op": "replace", "path": "/cacheClusterSize", "value": "1.6"},
],
)
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "remove", "path": "/variables/region", "value": "eu-west-1"}
],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage["description"].should.match("stage description update")
stage["cacheClusterSize"].should.equal("1.6")
stage["variables"]["environment"].should.match("dev")
stage["variables"].should_not.have.key("region")
stage["cacheClusterEnabled"].should.be.true
stage["deploymentId"].should.match(deployment_id)
stage["methodSettings"].should.have.key("*/*")
stage["methodSettings"]["*/*"].should.have.key(
"cacheDataEncrypted"
).which.should.be.true
@mock_apigateway
def test_update_stage_add_access_log_settings():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
client.create_deployment(
restApiId=api_id, stageName=stage_name, description="1.0.1"
)
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{
"op": "replace",
"path": "/accessLogSettings/destinationArn",
"value": "arn:aws:logs:us-east-1:123456789012:log-group:foo-bar-x0hyv",
},
{
"op": "replace",
"path": "/accessLogSettings/format",
"value": "$context.identity.sourceIp msg",
},
],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage.should.have.key("accessLogSettings").equals(
{
"format": "$context.identity.sourceIp msg",
"destinationArn": "arn:aws:logs:us-east-1:123456789012:log-group:foo-bar-x0hyv",
}
)
@mock_apigateway
def test_update_stage_tracing_disabled():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
response = client.create_deployment(restApiId=api_id, stageName=stage_name)
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "replace", "path": "/tracingEnabled", "value": "false"}
],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage.should.have.key("tracingEnabled").equals(False)
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[{"op": "replace", "path": "/tracingEnabled", "value": "true"}],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage.should.have.key("tracingEnabled").equals(True)
@mock_apigateway
def test_update_stage_remove_access_log_settings():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
client.create_deployment(
restApiId=api_id, stageName=stage_name, description="1.0.1"
)
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[{"op": "remove", "path": "/accessLogSettings"}],
)
stage = client.get_stage(restApiId=api_id, stageName=stage_name)
stage.shouldnt.have.key("accessLogSettings")
@mock_apigateway
def test_update_stage_configuration_unknown_operation():
client = boto3.client("apigateway", region_name="us-west-2")
stage_name = "staging"
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
create_method_integration(client, api_id)
client.create_deployment(
restApiId=api_id, stageName=stage_name, description="1.0.1"
)
with pytest.raises(ClientError) as exc:
client.update_stage(
restApiId=api_id,
stageName=stage_name,
patchOperations=[
{"op": "unknown_op", "path": "/notasetting", "value": "eu-west-1"}
],
)
err = exc.value.response["Error"]
err["Code"].should.equal("ValidationException")
err["Message"].should.equal(
"Member must satisfy enum value set: [add, remove, move, test, replace, copy]"
)
@mock_apigateway
def test_non_existent_stage():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
with pytest.raises(ClientError) as exc:
client.get_stage(restApiId=api_id, stageName="xxx")
err = exc.value.response["Error"]
err["Code"].should.equal("NotFoundException")

View File

@ -248,7 +248,7 @@ def test_describe_delivery_stream():
assert description["DeliveryStreamName"] == stream_name assert description["DeliveryStreamName"] == stream_name
assert ( assert (
description["DeliveryStreamARN"] description["DeliveryStreamARN"]
== f"arn:aws:firehose:{TEST_REGION}:{ACCOUNT_ID}:/delivery_stream/{stream_name}" == f"arn:aws:firehose:{TEST_REGION}:{ACCOUNT_ID}:deliverystream/{stream_name}"
) )
assert description["DeliveryStreamStatus"] == "ACTIVE" assert description["DeliveryStreamStatus"] == "ACTIVE"
assert description["DeliveryStreamType"] == "KinesisStreamAsSource" assert description["DeliveryStreamType"] == "KinesisStreamAsSource"
@ -281,7 +281,7 @@ def test_describe_delivery_stream():
assert description["DeliveryStreamName"] == stream_name assert description["DeliveryStreamName"] == stream_name
assert ( assert (
description["DeliveryStreamARN"] description["DeliveryStreamARN"]
== f"arn:aws:firehose:{TEST_REGION}:{ACCOUNT_ID}:/delivery_stream/{stream_name}" == f"arn:aws:firehose:{TEST_REGION}:{ACCOUNT_ID}:deliverystream/{stream_name}"
) )
assert description["DeliveryStreamStatus"] == "ACTIVE" assert description["DeliveryStreamStatus"] == "ACTIVE"
assert description["DeliveryStreamType"] == "KinesisStreamAsSource" assert description["DeliveryStreamType"] == "KinesisStreamAsSource"