Add validation to API Key min length value (#3652)

* api gateway - add api key minimum length validation

* api gateway - support includeValue query parameter on api key apis

* [apigateway] code refactoring

* Cleanup - remove duplicate utility methods

* APIGateway - Dont send headers with error messsage

Co-authored-by: Jovan Zivanov <j.zivanov@levi9.com>
This commit is contained in:
Bert Blommers 2021-02-15 09:39:35 +00:00 committed by GitHub
parent d7b8419791
commit 2f50f9cb24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 152 additions and 43 deletions

View File

@ -182,3 +182,12 @@ class ModelNotFound(RESTError):
super(ModelNotFound, self).__init__(
"NotFoundException", "Invalid Model Name specified"
)
class ApiKeyValueMinLength(RESTError):
code = 400
def __init__(self):
super(ApiKeyValueMinLength, self).__init__(
"BadRequestException", "API Key value should be at least 20 characters"
)

View File

@ -4,6 +4,8 @@ from __future__ import unicode_literals
import random
import string
import re
from copy import copy
import requests
import time
@ -40,6 +42,7 @@ from .exceptions import (
InvalidModelName,
RestAPINotFound,
ModelNotFound,
ApiKeyValueMinLength,
)
from ..core.models import responses_mock
@ -1026,18 +1029,37 @@ class APIGatewayBackend(BaseBackend):
def create_apikey(self, payload):
if payload.get("value") is not None:
for api_key in self.get_apikeys():
if len(payload.get("value", [])) < 20:
raise ApiKeyValueMinLength()
for api_key in self.get_apikeys(include_values=True):
if api_key.get("value") == payload["value"]:
raise ApiKeyAlreadyExists()
key = ApiKey(**payload)
self.keys[key["id"]] = key
return key
def get_apikeys(self):
return list(self.keys.values())
def get_apikeys(self, include_values=False):
api_keys = list(self.keys.values())
def get_apikey(self, api_key_id):
return self.keys[api_key_id]
if not include_values:
keys = []
for api_key in list(self.keys.values()):
new_key = copy(api_key)
del new_key["value"]
keys.append(new_key)
api_keys = keys
return api_keys
def get_apikey(self, api_key_id, include_value=False):
api_key = self.keys[api_key_id]
if not include_value:
new_key = copy(api_key)
del new_key["value"]
api_key = new_key
return api_key
def update_apikey(self, api_key_id, patch_operations):
key = self.keys[api_key_id]

View File

@ -18,6 +18,7 @@ from .exceptions import (
InvalidModelName,
RestAPINotFound,
ModelNotFound,
ApiKeyValueMinLength,
)
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
@ -33,17 +34,6 @@ class APIGatewayResponse(BaseResponse):
json.dumps({"__type": type_, "message": message}),
)
def _get_param(self, key):
return json.loads(self.body).get(key) if self.body else None
def _get_param_with_default_value(self, key, default):
jsonbody = json.loads(self.body)
if key in jsonbody:
return jsonbody.get(key)
else:
return default
@property
def backend(self):
return apigateway_backends[self.region]
@ -197,18 +187,16 @@ class APIGatewayResponse(BaseResponse):
name = self._get_param("name")
authorizer_type = self._get_param("type")
provider_arns = self._get_param_with_default_value("providerARNs", None)
auth_type = self._get_param_with_default_value("authType", None)
authorizer_uri = self._get_param_with_default_value("authorizerUri", None)
authorizer_credentials = self._get_param_with_default_value(
"authorizerCredentials", None
provider_arns = self._get_param("providerARNs")
auth_type = self._get_param("authType")
authorizer_uri = self._get_param("authorizerUri")
authorizer_credentials = self._get_param("authorizerCredentials")
identity_source = self._get_param("identitySource")
identiy_validation_expression = self._get_param(
"identityValidationExpression"
)
identity_source = self._get_param_with_default_value("identitySource", None)
identiy_validation_expression = self._get_param_with_default_value(
"identityValidationExpression", None
)
authorizer_result_ttl = self._get_param_with_default_value(
"authorizerResultTtlInSeconds", 300
authorizer_result_ttl = self._get_param(
"authorizerResultTtlInSeconds", if_none=300
)
# Param validation
@ -278,14 +266,10 @@ class APIGatewayResponse(BaseResponse):
if self.method == "POST":
stage_name = self._get_param("stageName")
deployment_id = self._get_param("deploymentId")
stage_variables = self._get_param_with_default_value("variables", {})
description = self._get_param_with_default_value("description", "")
cacheClusterEnabled = self._get_param_with_default_value(
"cacheClusterEnabled", False
)
cacheClusterSize = self._get_param_with_default_value(
"cacheClusterSize", None
)
stage_variables = self._get_param("variables", if_none={})
description = self._get_param("description", if_none="")
cacheClusterEnabled = self._get_param("cacheClusterEnabled", if_none=False)
cacheClusterSize = self._get_param("cacheClusterSize")
stage_response = self.backend.create_stage(
function_id,
@ -417,8 +401,8 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps({"item": deployments})
elif self.method == "POST":
name = self._get_param("stageName")
description = self._get_param_with_default_value("description", "")
stage_variables = self._get_param_with_default_value("variables", {})
description = self._get_param("description", if_none="")
stage_variables = self._get_param("variables", if_none={})
deployment = self.backend.create_deployment(
function_id, name, description, stage_variables
)
@ -454,9 +438,20 @@ class APIGatewayResponse(BaseResponse):
error.message, error.error_type
),
)
except ApiKeyValueMinLength as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
return 201, {}, json.dumps(apikey_response)
elif self.method == "GET":
apikeys_response = self.backend.get_apikeys()
include_values = self._get_bool_param("includeValues")
apikeys_response = self.backend.get_apikeys(include_values=include_values)
return 200, {}, json.dumps({"item": apikeys_response})
def apikey_individual(self, request, full_url, headers):
@ -467,7 +462,10 @@ class APIGatewayResponse(BaseResponse):
status_code = 200
if self.method == "GET":
apikey_response = self.backend.get_apikey(apikey)
include_value = self._get_bool_param("includeValue")
apikey_response = self.backend.get_apikey(
apikey, include_value=include_value
)
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
apikey_response = self.backend.update_apikey(apikey, patch_operations)

View File

@ -1828,12 +1828,92 @@ def test_http_proxying_integration():
requests.get(deploy_url).content.should.equal(b"a fake response")
@mock_apigateway
def test_api_key_value_min_length():
region_name = "us-east-1"
client = boto3.client("apigateway", region_name=region_name)
apikey_value = "12345"
apikey_name = "TESTKEY1"
payload = {"value": apikey_value, "name": apikey_name}
with pytest.raises(ClientError) as e:
client.create_api_key(**payload)
ex = e.value
ex.operation_name.should.equal("CreateApiKey")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("BadRequestException")
ex.response["Error"]["Message"].should.equal(
"API Key value should be at least 20 characters"
)
@mock_apigateway
def test_get_api_key_include_value():
region_name = "us-west-2"
client = boto3.client("apigateway", region_name=region_name)
apikey_value = "01234567890123456789"
apikey_name = "TESTKEY1"
payload = {"value": apikey_value, "name": apikey_name}
response = client.create_api_key(**payload)
api_key_id_one = response["id"]
response = client.get_api_key(apiKey=api_key_id_one, includeValue=True)
response.should.have.key("value")
response = client.get_api_key(apiKey=api_key_id_one)
response.should_not.have.key("value")
response = client.get_api_key(apiKey=api_key_id_one, includeValue=True)
response.should.have.key("value")
response = client.get_api_key(apiKey=api_key_id_one, includeValue=False)
response.should_not.have.key("value")
response = client.get_api_key(apiKey=api_key_id_one, includeValue=True)
response.should.have.key("value")
@mock_apigateway
def test_get_api_keys_include_values():
region_name = "us-west-2"
client = boto3.client("apigateway", region_name=region_name)
apikey_value = "01234567890123456789"
apikey_name = "TESTKEY1"
payload = {"value": apikey_value, "name": apikey_name}
apikey_value2 = "01234567890123456789123"
apikey_name2 = "TESTKEY1"
payload2 = {"value": apikey_value2, "name": apikey_name2}
client.create_api_key(**payload)
client.create_api_key(**payload2)
response = client.get_api_keys()
len(response["items"]).should.equal(2)
for api_key in response["items"]:
api_key.should_not.have.key("value")
response = client.get_api_keys(includeValues=True)
len(response["items"]).should.equal(2)
for api_key in response["items"]:
api_key.should.have.key("value")
response = client.get_api_keys(includeValues=False)
len(response["items"]).should.equal(2)
for api_key in response["items"]:
api_key.should_not.have.key("value")
@mock_apigateway
def test_create_api_key():
region_name = "us-west-2"
client = boto3.client("apigateway", region_name=region_name)
apikey_value = "12345"
apikey_value = "01234567890123456789"
apikey_name = "TESTKEY1"
payload = {"value": apikey_value, "name": apikey_name}
@ -1855,7 +1935,7 @@ def test_create_api_headers():
region_name = "us-west-2"
client = boto3.client("apigateway", region_name=region_name)
apikey_value = "12345"
apikey_value = "01234567890123456789"
apikey_name = "TESTKEY1"
payload = {"value": apikey_value, "name": apikey_name}
@ -1874,7 +1954,7 @@ def test_api_keys():
response = client.get_api_keys()
len(response["items"]).should.equal(0)
apikey_value = "12345"
apikey_value = "01234567890123456789"
apikey_name = "TESTKEY1"
payload = {
"value": apikey_value,
@ -1883,7 +1963,7 @@ def test_api_keys():
}
response = client.create_api_key(**payload)
apikey_id = response["id"]
apikey = client.get_api_key(apiKey=response["id"])
apikey = client.get_api_key(apiKey=response["id"], includeValue=True)
apikey["name"].should.equal(apikey_name)
apikey["value"].should.equal(apikey_value)
apikey["tags"]["tag1"].should.equal("test_tag1")