From 2f50f9cb240197d5375023fb9d22f46b986b55b7 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Mon, 15 Feb 2021 09:39:35 +0000 Subject: [PATCH] Add validation to API Key min length value (#3652) * api gateway - add api key minimum length validation * api gateway - support includeValue query parameter on api key apis * [apigateway] code refactoring * Cleanup - remove duplicate utility methods * APIGateway - Dont send headers with error messsage Co-authored-by: Jovan Zivanov --- moto/apigateway/exceptions.py | 9 +++ moto/apigateway/models.py | 32 +++++++-- moto/apigateway/responses.py | 66 +++++++++--------- tests/test_apigateway/test_apigateway.py | 88 ++++++++++++++++++++++-- 4 files changed, 152 insertions(+), 43 deletions(-) diff --git a/moto/apigateway/exceptions.py b/moto/apigateway/exceptions.py index 4d3475d0e..2ed24ea2f 100644 --- a/moto/apigateway/exceptions.py +++ b/moto/apigateway/exceptions.py @@ -182,3 +182,12 @@ class ModelNotFound(RESTError): super(ModelNotFound, self).__init__( "NotFoundException", "Invalid Model Name specified" ) + + +class ApiKeyValueMinLength(RESTError): + code = 400 + + def __init__(self): + super(ApiKeyValueMinLength, self).__init__( + "BadRequestException", "API Key value should be at least 20 characters" + ) diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index c8d2ae5f2..b6a8847db 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -4,6 +4,8 @@ from __future__ import unicode_literals import random import string import re +from copy import copy + import requests import time @@ -40,6 +42,7 @@ from .exceptions import ( InvalidModelName, RestAPINotFound, ModelNotFound, + ApiKeyValueMinLength, ) from ..core.models import responses_mock @@ -1026,18 +1029,37 @@ class APIGatewayBackend(BaseBackend): def create_apikey(self, payload): if payload.get("value") is not None: - for api_key in self.get_apikeys(): + if len(payload.get("value", [])) < 20: + raise ApiKeyValueMinLength() + for api_key in self.get_apikeys(include_values=True): if api_key.get("value") == payload["value"]: raise ApiKeyAlreadyExists() key = ApiKey(**payload) self.keys[key["id"]] = key return key - def get_apikeys(self): - return list(self.keys.values()) + def get_apikeys(self, include_values=False): + api_keys = list(self.keys.values()) - def get_apikey(self, api_key_id): - return self.keys[api_key_id] + if not include_values: + keys = [] + for api_key in list(self.keys.values()): + new_key = copy(api_key) + del new_key["value"] + keys.append(new_key) + api_keys = keys + + return api_keys + + def get_apikey(self, api_key_id, include_value=False): + api_key = self.keys[api_key_id] + + if not include_value: + new_key = copy(api_key) + del new_key["value"] + api_key = new_key + + return api_key def update_apikey(self, api_key_id, patch_operations): key = self.keys[api_key_id] diff --git a/moto/apigateway/responses.py b/moto/apigateway/responses.py index 0454ae58e..5d4e9a1a4 100644 --- a/moto/apigateway/responses.py +++ b/moto/apigateway/responses.py @@ -18,6 +18,7 @@ from .exceptions import ( InvalidModelName, RestAPINotFound, ModelNotFound, + ApiKeyValueMinLength, ) API_KEY_SOURCES = ["AUTHORIZER", "HEADER"] @@ -33,17 +34,6 @@ class APIGatewayResponse(BaseResponse): json.dumps({"__type": type_, "message": message}), ) - def _get_param(self, key): - return json.loads(self.body).get(key) if self.body else None - - def _get_param_with_default_value(self, key, default): - jsonbody = json.loads(self.body) - - if key in jsonbody: - return jsonbody.get(key) - else: - return default - @property def backend(self): return apigateway_backends[self.region] @@ -197,18 +187,16 @@ class APIGatewayResponse(BaseResponse): name = self._get_param("name") authorizer_type = self._get_param("type") - provider_arns = self._get_param_with_default_value("providerARNs", None) - auth_type = self._get_param_with_default_value("authType", None) - authorizer_uri = self._get_param_with_default_value("authorizerUri", None) - authorizer_credentials = self._get_param_with_default_value( - "authorizerCredentials", None + provider_arns = self._get_param("providerARNs") + auth_type = self._get_param("authType") + authorizer_uri = self._get_param("authorizerUri") + authorizer_credentials = self._get_param("authorizerCredentials") + identity_source = self._get_param("identitySource") + identiy_validation_expression = self._get_param( + "identityValidationExpression" ) - identity_source = self._get_param_with_default_value("identitySource", None) - identiy_validation_expression = self._get_param_with_default_value( - "identityValidationExpression", None - ) - authorizer_result_ttl = self._get_param_with_default_value( - "authorizerResultTtlInSeconds", 300 + authorizer_result_ttl = self._get_param( + "authorizerResultTtlInSeconds", if_none=300 ) # Param validation @@ -278,14 +266,10 @@ class APIGatewayResponse(BaseResponse): if self.method == "POST": stage_name = self._get_param("stageName") deployment_id = self._get_param("deploymentId") - stage_variables = self._get_param_with_default_value("variables", {}) - description = self._get_param_with_default_value("description", "") - cacheClusterEnabled = self._get_param_with_default_value( - "cacheClusterEnabled", False - ) - cacheClusterSize = self._get_param_with_default_value( - "cacheClusterSize", None - ) + stage_variables = self._get_param("variables", if_none={}) + description = self._get_param("description", if_none="") + cacheClusterEnabled = self._get_param("cacheClusterEnabled", if_none=False) + cacheClusterSize = self._get_param("cacheClusterSize") stage_response = self.backend.create_stage( function_id, @@ -417,8 +401,8 @@ class APIGatewayResponse(BaseResponse): return 200, {}, json.dumps({"item": deployments}) elif self.method == "POST": name = self._get_param("stageName") - description = self._get_param_with_default_value("description", "") - stage_variables = self._get_param_with_default_value("variables", {}) + description = self._get_param("description", if_none="") + stage_variables = self._get_param("variables", if_none={}) deployment = self.backend.create_deployment( function_id, name, description, stage_variables ) @@ -454,9 +438,20 @@ class APIGatewayResponse(BaseResponse): error.message, error.error_type ), ) + + except ApiKeyValueMinLength as error: + return ( + error.code, + {}, + '{{"message":"{0}","code":"{1}"}}'.format( + error.message, error.error_type + ), + ) return 201, {}, json.dumps(apikey_response) + elif self.method == "GET": - apikeys_response = self.backend.get_apikeys() + include_values = self._get_bool_param("includeValues") + apikeys_response = self.backend.get_apikeys(include_values=include_values) return 200, {}, json.dumps({"item": apikeys_response}) def apikey_individual(self, request, full_url, headers): @@ -467,7 +462,10 @@ class APIGatewayResponse(BaseResponse): status_code = 200 if self.method == "GET": - apikey_response = self.backend.get_apikey(apikey) + include_value = self._get_bool_param("includeValue") + apikey_response = self.backend.get_apikey( + apikey, include_value=include_value + ) elif self.method == "PATCH": patch_operations = self._get_param("patchOperations") apikey_response = self.backend.update_apikey(apikey, patch_operations) diff --git a/tests/test_apigateway/test_apigateway.py b/tests/test_apigateway/test_apigateway.py index 3b4c8ca64..7aa7e052d 100644 --- a/tests/test_apigateway/test_apigateway.py +++ b/tests/test_apigateway/test_apigateway.py @@ -1828,12 +1828,92 @@ def test_http_proxying_integration(): requests.get(deploy_url).content.should.equal(b"a fake response") +@mock_apigateway +def test_api_key_value_min_length(): + region_name = "us-east-1" + client = boto3.client("apigateway", region_name=region_name) + + apikey_value = "12345" + apikey_name = "TESTKEY1" + payload = {"value": apikey_value, "name": apikey_name} + + with pytest.raises(ClientError) as e: + client.create_api_key(**payload) + ex = e.value + ex.operation_name.should.equal("CreateApiKey") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("BadRequestException") + ex.response["Error"]["Message"].should.equal( + "API Key value should be at least 20 characters" + ) + + +@mock_apigateway +def test_get_api_key_include_value(): + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) + + apikey_value = "01234567890123456789" + apikey_name = "TESTKEY1" + payload = {"value": apikey_value, "name": apikey_name} + + response = client.create_api_key(**payload) + api_key_id_one = response["id"] + + response = client.get_api_key(apiKey=api_key_id_one, includeValue=True) + response.should.have.key("value") + + response = client.get_api_key(apiKey=api_key_id_one) + response.should_not.have.key("value") + + response = client.get_api_key(apiKey=api_key_id_one, includeValue=True) + response.should.have.key("value") + + response = client.get_api_key(apiKey=api_key_id_one, includeValue=False) + response.should_not.have.key("value") + + response = client.get_api_key(apiKey=api_key_id_one, includeValue=True) + response.should.have.key("value") + + +@mock_apigateway +def test_get_api_keys_include_values(): + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) + + apikey_value = "01234567890123456789" + apikey_name = "TESTKEY1" + payload = {"value": apikey_value, "name": apikey_name} + + apikey_value2 = "01234567890123456789123" + apikey_name2 = "TESTKEY1" + payload2 = {"value": apikey_value2, "name": apikey_name2} + + client.create_api_key(**payload) + client.create_api_key(**payload2) + + response = client.get_api_keys() + len(response["items"]).should.equal(2) + for api_key in response["items"]: + api_key.should_not.have.key("value") + + response = client.get_api_keys(includeValues=True) + len(response["items"]).should.equal(2) + for api_key in response["items"]: + api_key.should.have.key("value") + + response = client.get_api_keys(includeValues=False) + len(response["items"]).should.equal(2) + for api_key in response["items"]: + api_key.should_not.have.key("value") + + @mock_apigateway def test_create_api_key(): region_name = "us-west-2" client = boto3.client("apigateway", region_name=region_name) - apikey_value = "12345" + apikey_value = "01234567890123456789" apikey_name = "TESTKEY1" payload = {"value": apikey_value, "name": apikey_name} @@ -1855,7 +1935,7 @@ def test_create_api_headers(): region_name = "us-west-2" client = boto3.client("apigateway", region_name=region_name) - apikey_value = "12345" + apikey_value = "01234567890123456789" apikey_name = "TESTKEY1" payload = {"value": apikey_value, "name": apikey_name} @@ -1874,7 +1954,7 @@ def test_api_keys(): response = client.get_api_keys() len(response["items"]).should.equal(0) - apikey_value = "12345" + apikey_value = "01234567890123456789" apikey_name = "TESTKEY1" payload = { "value": apikey_value, @@ -1883,7 +1963,7 @@ def test_api_keys(): } response = client.create_api_key(**payload) apikey_id = response["id"] - apikey = client.get_api_key(apiKey=response["id"]) + apikey = client.get_api_key(apiKey=response["id"], includeValue=True) apikey["name"].should.equal(apikey_name) apikey["value"].should.equal(apikey_value) apikey["tags"]["tag1"].should.equal("test_tag1")