Implement APIGateway update_base_path_mapping (#4730)

This commit is contained in:
cm-iwata 2021-12-30 19:52:57 +09:00 committed by GitHub
parent 8def040f8d
commit a9293d62e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 267 additions and 1 deletions

View File

@ -1092,6 +1092,13 @@ class Model(BaseModel, dict):
class BasePathMapping(BaseModel, dict):
# operations
OPERATION_REPLACE = "replace"
OPERATION_PATH = "path"
OPERATION_VALUE = "value"
OPERATION_OP = "op"
def __init__(self, domain_name, rest_api_id, **kwargs):
super().__init__()
self["domain_name"] = domain_name
@ -1100,10 +1107,23 @@ class BasePathMapping(BaseModel, dict):
self["basePath"] = kwargs.get("basePath")
else:
self["basePath"] = "(none)"
if kwargs.get("stage"):
self["stage"] = kwargs.get("stage")
def apply_patch_operations(self, patch_operations):
for op in patch_operations:
path = op["path"]
value = op["value"]
operation = op["op"]
if operation == self.OPERATION_REPLACE:
if "/basePath" in path:
self["basePath"] = value
if "/restapiId" in path:
self["restApiId"] = value
if "/stage" in path:
self["stage"] = value
class APIGatewayBackend(BaseBackend):
"""
@ -1816,5 +1836,51 @@ class APIGatewayBackend(BaseBackend):
self.base_path_mappings[domain_name].pop(base_path)
def update_base_path_mapping(self, domain_name, base_path, patch_operations):
if domain_name not in self.domain_names:
raise DomainNameNotFound()
if base_path not in self.base_path_mappings[domain_name]:
raise BasePathNotFoundException()
base_path_mapping = self.get_base_path_mapping(domain_name, base_path)
rest_api_ids = [
op["value"] for op in patch_operations if op["path"] == "/restapiId"
]
if len(rest_api_ids) == 0:
modified_rest_api_id = base_path_mapping["restApiId"]
else:
modified_rest_api_id = rest_api_ids[-1]
stages = [op["value"] for op in patch_operations if op["path"] == "/stage"]
if len(stages) == 0:
modified_stage = base_path_mapping.get("stage")
else:
modified_stage = stages[-1]
base_paths = [
op["value"] for op in patch_operations if op["path"] == "/basePath"
]
if len(base_paths) == 0:
modified_base_path = base_path_mapping["basePath"]
else:
modified_base_path = base_paths[-1]
rest_api = self.apis.get(modified_rest_api_id)
if rest_api is None:
raise InvalidRestApiIdForBasePathMappingException()
if modified_stage and rest_api.stages.get(modified_stage) is None:
raise InvalidStageException()
base_path_mapping.apply_patch_operations(patch_operations)
if base_path != modified_base_path:
self.base_path_mappings[domain_name].pop(base_path)
self.base_path_mappings[domain_name][modified_base_path] = base_path_mapping
return base_path_mapping
apigateway_backends = BackendDict(APIGatewayBackend, "apigateway")

View File

@ -24,6 +24,8 @@ from .exceptions import (
NoIntegrationResponseDefined,
NotFoundException,
ConflictException,
InvalidRestApiIdForBasePathMappingException,
InvalidStageException,
)
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
@ -895,5 +897,15 @@ class APIGatewayResponse(BaseResponse):
elif self.method == "DELETE":
self.backend.delete_base_path_mapping(domain_name, base_path)
return 202, {}, ""
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
base_path_mapping = self.backend.update_base_path_mapping(
domain_name, base_path, patch_operations
)
return 200, {}, json.dumps(base_path_mapping)
except NotFoundException as e:
return self.error("NotFoundException", e.message, 404)
except InvalidRestApiIdForBasePathMappingException as e:
return self.error("BadRequestException", e.message)
except InvalidStageException as e:
return self.error("BadRequestException", e.message)

View File

@ -2652,3 +2652,191 @@ def test_delete_base_path_mapping_with_unknown_base_path():
)
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(404)
@mock_apigateway
def test_update_path_mapping():
client = boto3.client("apigateway", region_name="us-west-2")
domain_name = "testDomain"
test_certificate_name = "test.certificate"
client.create_domain_name(
domainName=domain_name, certificateName=test_certificate_name
)
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
stage_name = "dev"
client.create_base_path_mapping(domainName=domain_name, restApiId=api_id)
response = client.create_rest_api(
name="new_my_api", description="this is new my api"
)
new_api_id = response["id"]
create_method_integration(client, new_api_id)
client.create_deployment(
restApiId=new_api_id, stageName=stage_name, description="1.0.1"
)
base_path = "v1"
patch_operations = [
{"op": "replace", "path": "/stage", "value": stage_name},
{"op": "replace", "path": "/basePath", "value": base_path},
{"op": "replace", "path": "/restapiId", "value": new_api_id},
]
response = client.update_base_path_mapping(
domainName=domain_name, basePath="(none)", patchOperations=patch_operations
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response["basePath"].should.equal(base_path)
response["restApiId"].should.equal(new_api_id)
response["stage"].should.equal(stage_name)
@mock_apigateway
def test_update_path_mapping_with_unknown_domain():
client = boto3.client("apigateway", region_name="us-west-2")
with pytest.raises(ClientError) as ex:
client.update_base_path_mapping(
domainName="unknown-domain", basePath="(none)", patchOperations=[]
)
ex.value.response["Error"]["Message"].should.equal(
"Invalid domain name identifier specified"
)
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(404)
@mock_apigateway
def test_update_path_mapping_with_unknown_base_path():
client = boto3.client("apigateway", region_name="us-west-2")
domain_name = "testDomain"
test_certificate_name = "test.certificate"
client.create_domain_name(
domainName=domain_name, certificateName=test_certificate_name
)
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
client.create_base_path_mapping(
domainName=domain_name, restApiId=api_id, basePath="v1"
)
with pytest.raises(ClientError) as ex:
client.update_base_path_mapping(
domainName=domain_name, basePath="unknown", patchOperations=[]
)
ex.value.response["Error"]["Message"].should.equal(
"Invalid base path mapping identifier specified"
)
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(404)
@mock_apigateway
def test_update_path_mapping_to_same_base_path():
client = boto3.client("apigateway", region_name="us-west-2")
domain_name = "testDomain"
test_certificate_name = "test.certificate"
client.create_domain_name(
domainName=domain_name, certificateName=test_certificate_name
)
response = client.create_rest_api(name="my_api", description="this is my api")
api_id_1 = response["id"]
response = client.create_rest_api(name="my_api", description="this is my api")
api_id_2 = response["id"]
client.create_base_path_mapping(
domainName=domain_name, restApiId=api_id_1, basePath="v1"
)
client.create_base_path_mapping(
domainName=domain_name, restApiId=api_id_2, basePath="v2"
)
response = client.get_base_path_mappings(domainName=domain_name)
items = response["items"]
len(items).should.equal(2)
patch_operations = [
{"op": "replace", "path": "/basePath", "value": "v2"},
]
response = client.update_base_path_mapping(
domainName=domain_name, basePath="v1", patchOperations=patch_operations
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response["basePath"].should.equal("v2")
response["restApiId"].should.equal(api_id_1)
response = client.get_base_path_mappings(domainName=domain_name)
items = response["items"]
len(items).should.equal(1)
items[0]["basePath"].should.equal("v2")
items[0]["restApiId"].should.equal(api_id_1)
items[0].should_not.have.key("stage")
@mock_apigateway
def test_update_path_mapping_with_unknown_api():
client = boto3.client("apigateway", region_name="us-west-2")
domain_name = "testDomain"
test_certificate_name = "test.certificate"
client.create_domain_name(
domainName=domain_name, certificateName=test_certificate_name
)
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
base_path = "v1"
client.create_base_path_mapping(
domainName=domain_name, restApiId=api_id, basePath=base_path
)
with pytest.raises(ClientError) as ex:
client.update_base_path_mapping(
domainName=domain_name,
basePath=base_path,
patchOperations=[
{"op": "replace", "path": "/restapiId", "value": "unknown"},
],
)
ex.value.response["Error"]["Message"].should.equal(
"Invalid REST API identifier specified"
)
ex.value.response["Error"]["Code"].should.equal("BadRequestException")
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
@mock_apigateway
def test_update_path_mapping_with_unknown_stage():
client = boto3.client("apigateway", region_name="us-west-2")
domain_name = "testDomain"
test_certificate_name = "test.certificate"
client.create_domain_name(
domainName=domain_name, certificateName=test_certificate_name
)
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
base_path = "v1"
client.create_base_path_mapping(
domainName=domain_name, restApiId=api_id, basePath=base_path
)
with pytest.raises(ClientError) as ex:
client.update_base_path_mapping(
domainName=domain_name,
basePath=base_path,
patchOperations=[{"op": "replace", "path": "/stage", "value": "unknown"},],
)
ex.value.response["Error"]["Message"].should.equal(
"Invalid stage identifier specified"
)
ex.value.response["Error"]["Code"].should.equal("BadRequestException")
ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)