diff --git a/moto/apigateway/exceptions.py b/moto/apigateway/exceptions.py index 5df362a20..fc6da5f1f 100644 --- a/moto/apigateway/exceptions.py +++ b/moto/apigateway/exceptions.py @@ -192,6 +192,15 @@ class RestAPINotFound(NotFoundException): ) +class RequestValidatorNotFound(BadRequestException): + code = 400 + + def __init__(self): + super(RequestValidatorNotFound, self).__init__( + "NotFoundException", "Invalid Request Validator Id specified" + ) + + class ModelNotFound(NotFoundException): code = 404 diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index cebf6df5a..e55fdbafd 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -17,7 +17,7 @@ except ImportError: from urllib.parse import urlparse import responses from moto.core import ACCOUNT_ID, BaseBackend, BaseModel, CloudFormationModel -from .utils import create_id +from .utils import create_id, to_path from moto.core.utils import path_url from .exceptions import ( ApiKeyNotFoundException, @@ -41,6 +41,7 @@ from .exceptions import ( InvalidRestApiId, InvalidModelName, RestAPINotFound, + RequestValidatorNotFound, ModelNotFound, ApiKeyValueMinLength, ) @@ -652,6 +653,52 @@ class UsagePlan(BaseModel, dict): self["throttle"]["burstLimit"] = value +class RequestValidator(BaseModel, dict): + PROP_ID = "id" + PROP_NAME = "name" + PROP_VALIDATE_REQUEST_BODY = "validateRequestBody" + PROP_VALIDATE_REQUEST_PARAMETERS = "validateRequestParameters" + + # operations + OP_PATH = "path" + OP_VALUE = "value" + OP_REPLACE = "replace" + OP_OP = "op" + + def __init__(self, id, name, validateRequestBody, validateRequestParameters): + super(RequestValidator, self).__init__() + self[RequestValidator.PROP_ID] = id + self[RequestValidator.PROP_NAME] = name + self[RequestValidator.PROP_VALIDATE_REQUEST_BODY] = validateRequestBody + self[ + RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS + ] = validateRequestParameters + + def apply_patch_operations(self, operations): + for operation in operations: + path = operation[RequestValidator.OP_PATH] + value = operation[RequestValidator.OP_VALUE] + if operation[RequestValidator.OP_OP] == RequestValidator.OP_REPLACE: + if to_path(RequestValidator.PROP_NAME) in path: + self[RequestValidator.PROP_NAME] = value + if to_path(RequestValidator.PROP_VALIDATE_REQUEST_BODY) in path: + self[ + RequestValidator.PROP_VALIDATE_REQUEST_BODY + ] = value.lower() in ("true") + if to_path(RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS) in path: + self[ + RequestValidator.PROP_VALIDATE_REQUEST_PARAMETERS + ] = value.lower() in ("true") + + def to_dict(self): + return { + "id": self["id"], + "name": self["name"], + "validateRequestBody": self["validateRequestBody"], + "validateRequestParameters": self["validateRequestParameters"], + } + + class UsagePlanKey(BaseModel, dict): def __init__(self, id, type, name, value): super(UsagePlanKey, self).__init__() @@ -708,6 +755,7 @@ class RestAPI(CloudFormationModel): self.stages = {} self.resources = {} self.models = {} + self.request_validators = {} self.add_child("/") # Add default child def __repr__(self): @@ -948,6 +996,36 @@ class RestAPI(CloudFormationModel): def delete_deployment(self, deployment_id): return self.deployments.pop(deployment_id) + def create_request_validator( + self, name, validateRequestBody, validateRequestParameters + ): + validator_id = create_id() + request_validator = RequestValidator( + id=validator_id, + name=name, + validateRequestBody=validateRequestBody, + validateRequestParameters=validateRequestParameters, + ) + self.request_validators[validator_id] = request_validator + return request_validator + + def get_request_validators(self): + return list(self.request_validators.values()) + + def get_request_validator(self, validator_id): + reqeust_validator = self.request_validators.get(validator_id) + if reqeust_validator is None: + raise RequestValidatorNotFound() + return reqeust_validator + + def delete_request_validator(self, validator_id): + reqeust_validator = self.request_validators.pop(validator_id) + return reqeust_validator + + def update_request_validator(self, validator_id, patch_operations): + self.request_validators[validator_id].apply_patch_operations(patch_operations) + return self.request_validators[validator_id] + class DomainName(BaseModel, dict): def __init__(self, domain_name, **kwargs): diff --git a/moto/apigateway/responses.py b/moto/apigateway/responses.py index ad24706e2..4b1af8e66 100644 --- a/moto/apigateway/responses.py +++ b/moto/apigateway/responses.py @@ -306,6 +306,58 @@ class APIGatewayResponse(BaseResponse): return 200, {}, json.dumps(authorizer_response) + def request_validators(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + url_path_parts = self.path.split("/") + restapi_id = url_path_parts[2] + try: + restApi = self.backend.get_rest_api(restapi_id) + if self.method == "GET": + validators = restApi.get_request_validators() + res = json.dumps( + {"item": [validator.to_dict() for validator in validators]} + ) + return 200, {}, res + if self.method == "POST": + name = self._get_param("name") + validateRequestBody = self._get_bool_param("validateRequestBody") + validateRequestParameters = self._get_bool_param( + "validateRequestParameters" + ) + validator = restApi.create_request_validator( + name=name, + validateRequestBody=validateRequestBody, + validateRequestParameters=validateRequestParameters, + ) + return 200, {}, json.dumps(validator) + except BadRequestException as e: + return self.error("BadRequestException", e.message) + except CrossAccountNotAllowed as e: + return self.error("AccessDeniedException", e.message) + + def request_validator_individual(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + url_path_parts = self.path.split("/") + restapi_id = url_path_parts[2] + validator_id = url_path_parts[4] + try: + restApi = self.backend.get_rest_api(restapi_id) + if self.method == "GET": + return 200, {}, json.dumps(restApi.get_request_validator(validator_id)) + if self.method == "DELETE": + restApi.delete_request_validator(validator_id) + return 202, {}, "" + if self.method == "PATCH": + patch_operations = self._get_param("patchOperations") + validator = restApi.update_request_validator( + validator_id, patch_operations + ) + return 200, {}, json.dumps(validator) + except BadRequestException as e: + return self.error("BadRequestException", e.message) + except CrossAccountNotAllowed as e: + return self.error("AccessDeniedException", e.message) + def authorizers(self, request, full_url, headers): self.setup_class(request, full_url, headers) url_path_parts = self.path.split("/") diff --git a/moto/apigateway/urls.py b/moto/apigateway/urls.py index b6cc567a3..a6d1e87c7 100644 --- a/moto/apigateway/urls.py +++ b/moto/apigateway/urls.py @@ -31,4 +31,6 @@ url_paths = { "{0}/usageplans/(?P[^/]+)/?$": response.usage_plan_individual, "{0}/usageplans/(?P[^/]+)/keys$": response.usage_plan_keys, "{0}/usageplans/(?P[^/]+)/keys/(?P[^/]+)/?$": response.usage_plan_key_individual, + "{0}/restapis/(?P[^/]+)/requestvalidators$": response.request_validators, + "{0}/restapis/(?P[^/]+)/requestvalidators/(?P[^/]+)/?$": response.request_validator_individual, } diff --git a/moto/apigateway/utils.py b/moto/apigateway/utils.py index d583f64a1..acd492525 100644 --- a/moto/apigateway/utils.py +++ b/moto/apigateway/utils.py @@ -7,3 +7,7 @@ def create_id(): size = 10 chars = list(range(10)) + list(string.ascii_lowercase) return "".join(str(random.choice(chars)) for x in range(size)) + + +def to_path(prop): + return "/" + prop diff --git a/moto/core/models.py b/moto/core/models.py index 3857908f7..3a487f34e 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -375,7 +375,6 @@ class BotocoreStubber: def __call__(self, event_name, request, **kwargs): if not self.enabled: return None - response = None response_callback = None found_index = None diff --git a/tests/test_apigateway/test_apigateway_validators.py b/tests/test_apigateway/test_apigateway_validators.py new file mode 100644 index 000000000..4413907ea --- /dev/null +++ b/tests/test_apigateway/test_apigateway_validators.py @@ -0,0 +1,170 @@ +import boto3 +import sure # noqa +from moto import mock_apigateway +from moto.apigateway.exceptions import RequestValidatorNotFound +from botocore.exceptions import ClientError +import pytest + +ID = "id" +NAME = "name" +VALIDATE_REQUEST_BODY = "validateRequestBody" +VALIDATE_REQUEST_PARAMETERS = "validateRequestParameters" +PARAM_NAME = "my-validator" +RESPONSE_METADATA = "ResponseMetadata" + + +@mock_apigateway +def test_create_request_validator(): + client = create_client() + api_id = create_rest_api_id(client) + response = create_validator(client, api_id) + response.pop(RESPONSE_METADATA) + response.pop(ID) + response.should.equal( + { + NAME: PARAM_NAME, + VALIDATE_REQUEST_BODY: True, + VALIDATE_REQUEST_PARAMETERS: True, + } + ) + + +@mock_apigateway +def test_get_request_validators(): + + client = create_client() + api_id = create_rest_api_id(client) + response = client.get_request_validators(restApiId=api_id) + + validators = response["items"] + validators.should.have.length_of(0) + + response.pop(RESPONSE_METADATA) + response.should.equal({"items": []}) + + response = create_validator(client, api_id) + validator_id1 = response[ID] + response = create_validator(client, api_id) + validator_id2 = response[ID] + response = client.get_request_validators(restApiId=api_id) + + validators = response["items"] + validators.should.have.length_of(2) + + response.pop(RESPONSE_METADATA) + response.should.equal( + { + "items": [ + { + ID: validator_id1, + NAME: PARAM_NAME, + VALIDATE_REQUEST_BODY: True, + VALIDATE_REQUEST_PARAMETERS: True, + }, + { + ID: validator_id2, + NAME: PARAM_NAME, + VALIDATE_REQUEST_BODY: True, + VALIDATE_REQUEST_PARAMETERS: True, + }, + ] + } + ) + + +@mock_apigateway +def test_get_request_validator(): + client = create_client() + api_id = create_rest_api_id(client) + response = create_validator(client, api_id) + validator_id = response[ID] + response = client.get_request_validator( + restApiId=api_id, requestValidatorId=validator_id + ) + response.pop(RESPONSE_METADATA) + response.should.equal( + { + ID: validator_id, + NAME: PARAM_NAME, + VALIDATE_REQUEST_BODY: True, + VALIDATE_REQUEST_PARAMETERS: True, + } + ) + + +@mock_apigateway +def test_delete_request_validator(): + client = create_client() + api_id = create_rest_api_id(client) + response = create_validator(client, api_id) + # test get single validator by + validator_id = response[ID] + response = client.get_request_validator( + restApiId=api_id, requestValidatorId=validator_id + ) + + response.pop(RESPONSE_METADATA) + response.should.equal( + { + ID: validator_id, + NAME: PARAM_NAME, + VALIDATE_REQUEST_BODY: True, + VALIDATE_REQUEST_PARAMETERS: True, + } + ) + + # delete validator + response = client.delete_request_validator( + restApiId=api_id, requestValidatorId=validator_id + ) + with pytest.raises(ClientError) as ex: + client.get_request_validator(restApiId=api_id, requestValidatorId=validator_id) + err = ex.value.response["Error"] + err["Code"].should.equal("BadRequestException") + err["Message"].should.equal("Invalid Request Validator Id specified") + + +@mock_apigateway +def test_update_request_validator(): + client = create_client() + api_id = create_rest_api_id(client) + response = create_validator(client, api_id) + + validator_id = response[ID] + response = client.update_request_validator( + restApiId=api_id, + requestValidatorId=validator_id, + patchOperations=[ + {"op": "replace", "path": "/name", "value": PARAM_NAME + PARAM_NAME}, + {"op": "replace", "path": "/validateRequestBody", "value": "False"}, + {"op": "replace", "path": "/validateRequestParameters", "value": "False"}, + ], + ) + response.pop(RESPONSE_METADATA) + response.should.equal( + { + ID: validator_id, + NAME: PARAM_NAME + PARAM_NAME, + VALIDATE_REQUEST_BODY: False, + VALIDATE_REQUEST_PARAMETERS: False, + } + ) + + +def create_validator(client, api_id): + response = client.create_request_validator( + restApiId=api_id, + name=PARAM_NAME, + validateRequestBody=True, + validateRequestParameters=True, + ) + return response + + +def create_client(): + return boto3.client("apigateway", region_name="us-west-2") + + +def create_rest_api_id(client): + response = client.create_rest_api(name="my_api", description="this is my api") + return response[ID]