Implement APIGateway update_base_path_mapping (#4730)
This commit is contained in:
parent
8def040f8d
commit
a9293d62e3
@ -1092,6 +1092,13 @@ class Model(BaseModel, dict):
|
||||
|
||||
|
||||
class BasePathMapping(BaseModel, dict):
|
||||
|
||||
# operations
|
||||
OPERATION_REPLACE = "replace"
|
||||
OPERATION_PATH = "path"
|
||||
OPERATION_VALUE = "value"
|
||||
OPERATION_OP = "op"
|
||||
|
||||
def __init__(self, domain_name, rest_api_id, **kwargs):
|
||||
super().__init__()
|
||||
self["domain_name"] = domain_name
|
||||
@ -1100,10 +1107,23 @@ class BasePathMapping(BaseModel, dict):
|
||||
self["basePath"] = kwargs.get("basePath")
|
||||
else:
|
||||
self["basePath"] = "(none)"
|
||||
|
||||
if kwargs.get("stage"):
|
||||
self["stage"] = kwargs.get("stage")
|
||||
|
||||
def apply_patch_operations(self, patch_operations):
|
||||
|
||||
for op in patch_operations:
|
||||
path = op["path"]
|
||||
value = op["value"]
|
||||
operation = op["op"]
|
||||
if operation == self.OPERATION_REPLACE:
|
||||
if "/basePath" in path:
|
||||
self["basePath"] = value
|
||||
if "/restapiId" in path:
|
||||
self["restApiId"] = value
|
||||
if "/stage" in path:
|
||||
self["stage"] = value
|
||||
|
||||
|
||||
class APIGatewayBackend(BaseBackend):
|
||||
"""
|
||||
@ -1816,5 +1836,51 @@ class APIGatewayBackend(BaseBackend):
|
||||
|
||||
self.base_path_mappings[domain_name].pop(base_path)
|
||||
|
||||
def update_base_path_mapping(self, domain_name, base_path, patch_operations):
|
||||
|
||||
if domain_name not in self.domain_names:
|
||||
raise DomainNameNotFound()
|
||||
|
||||
if base_path not in self.base_path_mappings[domain_name]:
|
||||
raise BasePathNotFoundException()
|
||||
|
||||
base_path_mapping = self.get_base_path_mapping(domain_name, base_path)
|
||||
|
||||
rest_api_ids = [
|
||||
op["value"] for op in patch_operations if op["path"] == "/restapiId"
|
||||
]
|
||||
if len(rest_api_ids) == 0:
|
||||
modified_rest_api_id = base_path_mapping["restApiId"]
|
||||
else:
|
||||
modified_rest_api_id = rest_api_ids[-1]
|
||||
|
||||
stages = [op["value"] for op in patch_operations if op["path"] == "/stage"]
|
||||
if len(stages) == 0:
|
||||
modified_stage = base_path_mapping.get("stage")
|
||||
else:
|
||||
modified_stage = stages[-1]
|
||||
|
||||
base_paths = [
|
||||
op["value"] for op in patch_operations if op["path"] == "/basePath"
|
||||
]
|
||||
if len(base_paths) == 0:
|
||||
modified_base_path = base_path_mapping["basePath"]
|
||||
else:
|
||||
modified_base_path = base_paths[-1]
|
||||
|
||||
rest_api = self.apis.get(modified_rest_api_id)
|
||||
if rest_api is None:
|
||||
raise InvalidRestApiIdForBasePathMappingException()
|
||||
if modified_stage and rest_api.stages.get(modified_stage) is None:
|
||||
raise InvalidStageException()
|
||||
|
||||
base_path_mapping.apply_patch_operations(patch_operations)
|
||||
|
||||
if base_path != modified_base_path:
|
||||
self.base_path_mappings[domain_name].pop(base_path)
|
||||
self.base_path_mappings[domain_name][modified_base_path] = base_path_mapping
|
||||
|
||||
return base_path_mapping
|
||||
|
||||
|
||||
apigateway_backends = BackendDict(APIGatewayBackend, "apigateway")
|
||||
|
@ -24,6 +24,8 @@ from .exceptions import (
|
||||
NoIntegrationResponseDefined,
|
||||
NotFoundException,
|
||||
ConflictException,
|
||||
InvalidRestApiIdForBasePathMappingException,
|
||||
InvalidStageException,
|
||||
)
|
||||
|
||||
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
|
||||
@ -895,5 +897,15 @@ class APIGatewayResponse(BaseResponse):
|
||||
elif self.method == "DELETE":
|
||||
self.backend.delete_base_path_mapping(domain_name, base_path)
|
||||
return 202, {}, ""
|
||||
elif self.method == "PATCH":
|
||||
patch_operations = self._get_param("patchOperations")
|
||||
base_path_mapping = self.backend.update_base_path_mapping(
|
||||
domain_name, base_path, patch_operations
|
||||
)
|
||||
return 200, {}, json.dumps(base_path_mapping)
|
||||
except NotFoundException as e:
|
||||
return self.error("NotFoundException", e.message, 404)
|
||||
except InvalidRestApiIdForBasePathMappingException as e:
|
||||
return self.error("BadRequestException", e.message)
|
||||
except InvalidStageException as e:
|
||||
return self.error("BadRequestException", e.message)
|
||||
|
@ -2652,3 +2652,191 @@ def test_delete_base_path_mapping_with_unknown_base_path():
|
||||
)
|
||||
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(404)
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_update_path_mapping():
|
||||
client = boto3.client("apigateway", region_name="us-west-2")
|
||||
domain_name = "testDomain"
|
||||
test_certificate_name = "test.certificate"
|
||||
client.create_domain_name(
|
||||
domainName=domain_name, certificateName=test_certificate_name
|
||||
)
|
||||
|
||||
response = client.create_rest_api(name="my_api", description="this is my api")
|
||||
api_id = response["id"]
|
||||
|
||||
stage_name = "dev"
|
||||
|
||||
client.create_base_path_mapping(domainName=domain_name, restApiId=api_id)
|
||||
|
||||
response = client.create_rest_api(
|
||||
name="new_my_api", description="this is new my api"
|
||||
)
|
||||
new_api_id = response["id"]
|
||||
create_method_integration(client, new_api_id)
|
||||
client.create_deployment(
|
||||
restApiId=new_api_id, stageName=stage_name, description="1.0.1"
|
||||
)
|
||||
|
||||
base_path = "v1"
|
||||
patch_operations = [
|
||||
{"op": "replace", "path": "/stage", "value": stage_name},
|
||||
{"op": "replace", "path": "/basePath", "value": base_path},
|
||||
{"op": "replace", "path": "/restapiId", "value": new_api_id},
|
||||
]
|
||||
response = client.update_base_path_mapping(
|
||||
domainName=domain_name, basePath="(none)", patchOperations=patch_operations
|
||||
)
|
||||
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
|
||||
response["basePath"].should.equal(base_path)
|
||||
response["restApiId"].should.equal(new_api_id)
|
||||
response["stage"].should.equal(stage_name)
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_update_path_mapping_with_unknown_domain():
|
||||
|
||||
client = boto3.client("apigateway", region_name="us-west-2")
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.update_base_path_mapping(
|
||||
domainName="unknown-domain", basePath="(none)", patchOperations=[]
|
||||
)
|
||||
|
||||
ex.value.response["Error"]["Message"].should.equal(
|
||||
"Invalid domain name identifier specified"
|
||||
)
|
||||
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(404)
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_update_path_mapping_with_unknown_base_path():
|
||||
client = boto3.client("apigateway", region_name="us-west-2")
|
||||
domain_name = "testDomain"
|
||||
test_certificate_name = "test.certificate"
|
||||
client.create_domain_name(
|
||||
domainName=domain_name, certificateName=test_certificate_name
|
||||
)
|
||||
|
||||
response = client.create_rest_api(name="my_api", description="this is my api")
|
||||
api_id = response["id"]
|
||||
client.create_base_path_mapping(
|
||||
domainName=domain_name, restApiId=api_id, basePath="v1"
|
||||
)
|
||||
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.update_base_path_mapping(
|
||||
domainName=domain_name, basePath="unknown", patchOperations=[]
|
||||
)
|
||||
|
||||
ex.value.response["Error"]["Message"].should.equal(
|
||||
"Invalid base path mapping identifier specified"
|
||||
)
|
||||
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(404)
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_update_path_mapping_to_same_base_path():
|
||||
client = boto3.client("apigateway", region_name="us-west-2")
|
||||
domain_name = "testDomain"
|
||||
test_certificate_name = "test.certificate"
|
||||
client.create_domain_name(
|
||||
domainName=domain_name, certificateName=test_certificate_name
|
||||
)
|
||||
|
||||
response = client.create_rest_api(name="my_api", description="this is my api")
|
||||
api_id_1 = response["id"]
|
||||
response = client.create_rest_api(name="my_api", description="this is my api")
|
||||
api_id_2 = response["id"]
|
||||
|
||||
client.create_base_path_mapping(
|
||||
domainName=domain_name, restApiId=api_id_1, basePath="v1"
|
||||
)
|
||||
client.create_base_path_mapping(
|
||||
domainName=domain_name, restApiId=api_id_2, basePath="v2"
|
||||
)
|
||||
|
||||
response = client.get_base_path_mappings(domainName=domain_name)
|
||||
items = response["items"]
|
||||
len(items).should.equal(2)
|
||||
|
||||
patch_operations = [
|
||||
{"op": "replace", "path": "/basePath", "value": "v2"},
|
||||
]
|
||||
response = client.update_base_path_mapping(
|
||||
domainName=domain_name, basePath="v1", patchOperations=patch_operations
|
||||
)
|
||||
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
|
||||
response["basePath"].should.equal("v2")
|
||||
response["restApiId"].should.equal(api_id_1)
|
||||
|
||||
response = client.get_base_path_mappings(domainName=domain_name)
|
||||
items = response["items"]
|
||||
len(items).should.equal(1)
|
||||
items[0]["basePath"].should.equal("v2")
|
||||
items[0]["restApiId"].should.equal(api_id_1)
|
||||
items[0].should_not.have.key("stage")
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_update_path_mapping_with_unknown_api():
|
||||
client = boto3.client("apigateway", region_name="us-west-2")
|
||||
domain_name = "testDomain"
|
||||
test_certificate_name = "test.certificate"
|
||||
client.create_domain_name(
|
||||
domainName=domain_name, certificateName=test_certificate_name
|
||||
)
|
||||
|
||||
response = client.create_rest_api(name="my_api", description="this is my api")
|
||||
api_id = response["id"]
|
||||
base_path = "v1"
|
||||
client.create_base_path_mapping(
|
||||
domainName=domain_name, restApiId=api_id, basePath=base_path
|
||||
)
|
||||
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.update_base_path_mapping(
|
||||
domainName=domain_name,
|
||||
basePath=base_path,
|
||||
patchOperations=[
|
||||
{"op": "replace", "path": "/restapiId", "value": "unknown"},
|
||||
],
|
||||
)
|
||||
|
||||
ex.value.response["Error"]["Message"].should.equal(
|
||||
"Invalid REST API identifier specified"
|
||||
)
|
||||
ex.value.response["Error"]["Code"].should.equal("BadRequestException")
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
|
||||
|
||||
|
||||
@mock_apigateway
|
||||
def test_update_path_mapping_with_unknown_stage():
|
||||
client = boto3.client("apigateway", region_name="us-west-2")
|
||||
domain_name = "testDomain"
|
||||
test_certificate_name = "test.certificate"
|
||||
client.create_domain_name(
|
||||
domainName=domain_name, certificateName=test_certificate_name
|
||||
)
|
||||
|
||||
response = client.create_rest_api(name="my_api", description="this is my api")
|
||||
api_id = response["id"]
|
||||
base_path = "v1"
|
||||
client.create_base_path_mapping(
|
||||
domainName=domain_name, restApiId=api_id, basePath=base_path
|
||||
)
|
||||
|
||||
with pytest.raises(ClientError) as ex:
|
||||
client.update_base_path_mapping(
|
||||
domainName=domain_name,
|
||||
basePath=base_path,
|
||||
patchOperations=[{"op": "replace", "path": "/stage", "value": "unknown"},],
|
||||
)
|
||||
|
||||
ex.value.response["Error"]["Message"].should.equal(
|
||||
"Invalid stage identifier specified"
|
||||
)
|
||||
ex.value.response["Error"]["Code"].should.equal("BadRequestException")
|
||||
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
|
||||
|
Loading…
Reference in New Issue
Block a user