Add validation to API Key min length value (#3652)
* api gateway - add api key minimum length validation * api gateway - support includeValue query parameter on api key apis * [apigateway] code refactoring * Cleanup - remove duplicate utility methods * APIGateway - Dont send headers with error messsage Co-authored-by: Jovan Zivanov <j.zivanov@levi9.com>
This commit is contained in:
parent
d7b8419791
commit
2f50f9cb24
@ -182,3 +182,12 @@ class ModelNotFound(RESTError):
|
||||
super(ModelNotFound, self).__init__(
|
||||
"NotFoundException", "Invalid Model Name specified"
|
||||
)
|
||||
|
||||
|
||||
class ApiKeyValueMinLength(RESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self):
|
||||
super(ApiKeyValueMinLength, self).__init__(
|
||||
"BadRequestException", "API Key value should be at least 20 characters"
|
||||
)
|
||||
|
@ -4,6 +4,8 @@ from __future__ import unicode_literals
|
||||
import random
|
||||
import string
|
||||
import re
|
||||
from copy import copy
|
||||
|
||||
import requests
|
||||
import time
|
||||
|
||||
@ -40,6 +42,7 @@ from .exceptions import (
|
||||
InvalidModelName,
|
||||
RestAPINotFound,
|
||||
ModelNotFound,
|
||||
ApiKeyValueMinLength,
|
||||
)
|
||||
from ..core.models import responses_mock
|
||||
|
||||
@ -1026,18 +1029,37 @@ class APIGatewayBackend(BaseBackend):
|
||||
|
||||
def create_apikey(self, payload):
|
||||
if payload.get("value") is not None:
|
||||
for api_key in self.get_apikeys():
|
||||
if len(payload.get("value", [])) < 20:
|
||||
raise ApiKeyValueMinLength()
|
||||
for api_key in self.get_apikeys(include_values=True):
|
||||
if api_key.get("value") == payload["value"]:
|
||||
raise ApiKeyAlreadyExists()
|
||||
key = ApiKey(**payload)
|
||||
self.keys[key["id"]] = key
|
||||
return key
|
||||
|
||||
def get_apikeys(self):
|
||||
return list(self.keys.values())
|
||||
def get_apikeys(self, include_values=False):
|
||||
api_keys = list(self.keys.values())
|
||||
|
||||
def get_apikey(self, api_key_id):
|
||||
return self.keys[api_key_id]
|
||||
if not include_values:
|
||||
keys = []
|
||||
for api_key in list(self.keys.values()):
|
||||
new_key = copy(api_key)
|
||||
del new_key["value"]
|
||||
keys.append(new_key)
|
||||
api_keys = keys
|
||||
|
||||
return api_keys
|
||||
|
||||
def get_apikey(self, api_key_id, include_value=False):
|
||||
api_key = self.keys[api_key_id]
|
||||
|
||||
if not include_value:
|
||||
new_key = copy(api_key)
|
||||
del new_key["value"]
|
||||
api_key = new_key
|
||||
|
||||
return api_key
|
||||
|
||||
def update_apikey(self, api_key_id, patch_operations):
|
||||
key = self.keys[api_key_id]
|
||||
|
@ -18,6 +18,7 @@ from .exceptions import (
|
||||
InvalidModelName,
|
||||
RestAPINotFound,
|
||||
ModelNotFound,
|
||||
ApiKeyValueMinLength,
|
||||
)
|
||||
|
||||
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
|
||||
@ -33,17 +34,6 @@ class APIGatewayResponse(BaseResponse):
|
||||
json.dumps({"__type": type_, "message": message}),
|
||||
)
|
||||
|
||||
def _get_param(self, key):
|
||||
return json.loads(self.body).get(key) if self.body else None
|
||||
|
||||
def _get_param_with_default_value(self, key, default):
|
||||
jsonbody = json.loads(self.body)
|
||||
|
||||
if key in jsonbody:
|
||||
return jsonbody.get(key)
|
||||
else:
|
||||
return default
|
||||
|
||||
@property
|
||||
def backend(self):
|
||||
return apigateway_backends[self.region]
|
||||
@ -197,18 +187,16 @@ class APIGatewayResponse(BaseResponse):
|
||||
name = self._get_param("name")
|
||||
authorizer_type = self._get_param("type")
|
||||
|
||||
provider_arns = self._get_param_with_default_value("providerARNs", None)
|
||||
auth_type = self._get_param_with_default_value("authType", None)
|
||||
authorizer_uri = self._get_param_with_default_value("authorizerUri", None)
|
||||
authorizer_credentials = self._get_param_with_default_value(
|
||||
"authorizerCredentials", None
|
||||
provider_arns = self._get_param("providerARNs")
|
||||
auth_type = self._get_param("authType")
|
||||
authorizer_uri = self._get_param("authorizerUri")
|
||||
authorizer_credentials = self._get_param("authorizerCredentials")
|
||||
identity_source = self._get_param("identitySource")
|
||||
identiy_validation_expression = self._get_param(
|
||||
"identityValidationExpression"
|
||||
)
|
||||
identity_source = self._get_param_with_default_value("identitySource", None)
|
||||
identiy_validation_expression = self._get_param_with_default_value(
|
||||
"identityValidationExpression", None
|
||||
)
|
||||
authorizer_result_ttl = self._get_param_with_default_value(
|
||||
"authorizerResultTtlInSeconds", 300
|
||||
authorizer_result_ttl = self._get_param(
|
||||
"authorizerResultTtlInSeconds", if_none=300
|
||||
)
|
||||
|
||||
# Param validation
|
||||
@ -278,14 +266,10 @@ class APIGatewayResponse(BaseResponse):
|
||||
if self.method == "POST":
|
||||
stage_name = self._get_param("stageName")
|
||||
deployment_id = self._get_param("deploymentId")
|
||||
stage_variables = self._get_param_with_default_value("variables", {})
|
||||
description = self._get_param_with_default_value("description", "")
|
||||
cacheClusterEnabled = self._get_param_with_default_value(
|
||||
"cacheClusterEnabled", False
|
||||
)
|
||||
cacheClusterSize = self._get_param_with_default_value(
|
||||
"cacheClusterSize", None
|
||||
)
|
||||
stage_variables = self._get_param("variables", if_none={})
|
||||
description = self._get_param("description", if_none="")
|
||||
cacheClusterEnabled = self._get_param("cacheClusterEnabled", if_none=False)
|
||||
cacheClusterSize = self._get_param("cacheClusterSize")
|
||||
|
||||
stage_response = self.backend.create_stage(
|
||||
function_id,
|
||||
@ -417,8 +401,8 @@ class APIGatewayResponse(BaseResponse):
|
||||
return 200, {}, json.dumps({"item": deployments})
|
||||
elif self.method == "POST":
|
||||
name = self._get_param("stageName")
|
||||
description = self._get_param_with_default_value("description", "")
|
||||
stage_variables = self._get_param_with_default_value("variables", {})
|
||||
description = self._get_param("description", if_none="")
|
||||
stage_variables = self._get_param("variables", if_none={})
|
||||
deployment = self.backend.create_deployment(
|
||||
function_id, name, description, stage_variables
|
||||
)
|
||||
@ -454,9 +438,20 @@ class APIGatewayResponse(BaseResponse):
|
||||
error.message, error.error_type
|
||||
),
|
||||
)
|
||||
|
||||
except ApiKeyValueMinLength as error:
|
||||
return (
|
||||
error.code,
|
||||
{},
|
||||
'{{"message":"{0}","code":"{1}"}}'.format(
|
||||
error.message, error.error_type
|
||||
),
|
||||
)
|
||||
return 201, {}, json.dumps(apikey_response)
|
||||
|
||||
elif self.method == "GET":
|
||||
apikeys_response = self.backend.get_apikeys()
|
||||
include_values = self._get_bool_param("includeValues")
|
||||
apikeys_response = self.backend.get_apikeys(include_values=include_values)
|
||||
return 200, {}, json.dumps({"item": apikeys_response})
|
||||
|
||||
def apikey_individual(self, request, full_url, headers):
|
||||
@ -467,7 +462,10 @@ class APIGatewayResponse(BaseResponse):
|
||||
|
||||
status_code = 200
|
||||
if self.method == "GET":
|
||||
apikey_response = self.backend.get_apikey(apikey)
|
||||
include_value = self._get_bool_param("includeValue")
|
||||
apikey_response = self.backend.get_apikey(
|
||||
apikey, include_value=include_value
|
||||
)
|
||||
elif self.method == "PATCH":
|
||||
patch_operations = self._get_param("patchOperations")
|
||||
apikey_response = self.backend.update_apikey(apikey, patch_operations)
|
||||
|
@ -1828,12 +1828,92 @@ def test_http_proxying_integration():
|
||||
requests.get(deploy_url).content.should.equal(b"a fake response")
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_api_key_value_min_length():
|
||||
region_name = "us-east-1"
|
||||
client = boto3.client("apigateway", region_name=region_name)
|
||||
|
||||
apikey_value = "12345"
|
||||
apikey_name = "TESTKEY1"
|
||||
payload = {"value": apikey_value, "name": apikey_name}
|
||||
|
||||
with pytest.raises(ClientError) as e:
|
||||
client.create_api_key(**payload)
|
||||
ex = e.value
|
||||
ex.operation_name.should.equal("CreateApiKey")
|
||||
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
|
||||
ex.response["Error"]["Code"].should.contain("BadRequestException")
|
||||
ex.response["Error"]["Message"].should.equal(
|
||||
"API Key value should be at least 20 characters"
|
||||
)
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_get_api_key_include_value():
|
||||
region_name = "us-west-2"
|
||||
client = boto3.client("apigateway", region_name=region_name)
|
||||
|
||||
apikey_value = "01234567890123456789"
|
||||
apikey_name = "TESTKEY1"
|
||||
payload = {"value": apikey_value, "name": apikey_name}
|
||||
|
||||
response = client.create_api_key(**payload)
|
||||
api_key_id_one = response["id"]
|
||||
|
||||
response = client.get_api_key(apiKey=api_key_id_one, includeValue=True)
|
||||
response.should.have.key("value")
|
||||
|
||||
response = client.get_api_key(apiKey=api_key_id_one)
|
||||
response.should_not.have.key("value")
|
||||
|
||||
response = client.get_api_key(apiKey=api_key_id_one, includeValue=True)
|
||||
response.should.have.key("value")
|
||||
|
||||
response = client.get_api_key(apiKey=api_key_id_one, includeValue=False)
|
||||
response.should_not.have.key("value")
|
||||
|
||||
response = client.get_api_key(apiKey=api_key_id_one, includeValue=True)
|
||||
response.should.have.key("value")
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_get_api_keys_include_values():
|
||||
region_name = "us-west-2"
|
||||
client = boto3.client("apigateway", region_name=region_name)
|
||||
|
||||
apikey_value = "01234567890123456789"
|
||||
apikey_name = "TESTKEY1"
|
||||
payload = {"value": apikey_value, "name": apikey_name}
|
||||
|
||||
apikey_value2 = "01234567890123456789123"
|
||||
apikey_name2 = "TESTKEY1"
|
||||
payload2 = {"value": apikey_value2, "name": apikey_name2}
|
||||
|
||||
client.create_api_key(**payload)
|
||||
client.create_api_key(**payload2)
|
||||
|
||||
response = client.get_api_keys()
|
||||
len(response["items"]).should.equal(2)
|
||||
for api_key in response["items"]:
|
||||
api_key.should_not.have.key("value")
|
||||
|
||||
response = client.get_api_keys(includeValues=True)
|
||||
len(response["items"]).should.equal(2)
|
||||
for api_key in response["items"]:
|
||||
api_key.should.have.key("value")
|
||||
|
||||
response = client.get_api_keys(includeValues=False)
|
||||
len(response["items"]).should.equal(2)
|
||||
for api_key in response["items"]:
|
||||
api_key.should_not.have.key("value")
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_create_api_key():
|
||||
region_name = "us-west-2"
|
||||
client = boto3.client("apigateway", region_name=region_name)
|
||||
|
||||
apikey_value = "12345"
|
||||
apikey_value = "01234567890123456789"
|
||||
apikey_name = "TESTKEY1"
|
||||
payload = {"value": apikey_value, "name": apikey_name}
|
||||
|
||||
@ -1855,7 +1935,7 @@ def test_create_api_headers():
|
||||
region_name = "us-west-2"
|
||||
client = boto3.client("apigateway", region_name=region_name)
|
||||
|
||||
apikey_value = "12345"
|
||||
apikey_value = "01234567890123456789"
|
||||
apikey_name = "TESTKEY1"
|
||||
payload = {"value": apikey_value, "name": apikey_name}
|
||||
|
||||
@ -1874,7 +1954,7 @@ def test_api_keys():
|
||||
response = client.get_api_keys()
|
||||
len(response["items"]).should.equal(0)
|
||||
|
||||
apikey_value = "12345"
|
||||
apikey_value = "01234567890123456789"
|
||||
apikey_name = "TESTKEY1"
|
||||
payload = {
|
||||
"value": apikey_value,
|
||||
@ -1883,7 +1963,7 @@ def test_api_keys():
|
||||
}
|
||||
response = client.create_api_key(**payload)
|
||||
apikey_id = response["id"]
|
||||
apikey = client.get_api_key(apiKey=response["id"])
|
||||
apikey = client.get_api_key(apiKey=response["id"], includeValue=True)
|
||||
apikey["name"].should.equal(apikey_name)
|
||||
apikey["value"].should.equal(apikey_value)
|
||||
apikey["tags"]["tag1"].should.equal("test_tag1")
|
||||
|
Loading…
Reference in New Issue
Block a user