From a9293d62e319fbe89e7c99500cecedb47328b169 Mon Sep 17 00:00:00 2001 From: cm-iwata <38879253+cm-iwata@users.noreply.github.com> Date: Thu, 30 Dec 2021 19:52:57 +0900 Subject: [PATCH] Implement APIGateway update_base_path_mapping (#4730) --- moto/apigateway/models.py | 68 +++++++- moto/apigateway/responses.py | 12 ++ tests/test_apigateway/test_apigateway.py | 188 +++++++++++++++++++++++ 3 files changed, 267 insertions(+), 1 deletion(-) diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 12f9005d1..2e1ad4f3f 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -1092,6 +1092,13 @@ class Model(BaseModel, dict): class BasePathMapping(BaseModel, dict): + + # operations + OPERATION_REPLACE = "replace" + OPERATION_PATH = "path" + OPERATION_VALUE = "value" + OPERATION_OP = "op" + def __init__(self, domain_name, rest_api_id, **kwargs): super().__init__() self["domain_name"] = domain_name @@ -1100,10 +1107,23 @@ class BasePathMapping(BaseModel, dict): self["basePath"] = kwargs.get("basePath") else: self["basePath"] = "(none)" - if kwargs.get("stage"): self["stage"] = kwargs.get("stage") + def apply_patch_operations(self, patch_operations): + + for op in patch_operations: + path = op["path"] + value = op["value"] + operation = op["op"] + if operation == self.OPERATION_REPLACE: + if "/basePath" in path: + self["basePath"] = value + if "/restapiId" in path: + self["restApiId"] = value + if "/stage" in path: + self["stage"] = value + class APIGatewayBackend(BaseBackend): """ @@ -1816,5 +1836,51 @@ class APIGatewayBackend(BaseBackend): self.base_path_mappings[domain_name].pop(base_path) + def update_base_path_mapping(self, domain_name, base_path, patch_operations): + + if domain_name not in self.domain_names: + raise DomainNameNotFound() + + if base_path not in self.base_path_mappings[domain_name]: + raise BasePathNotFoundException() + + base_path_mapping = self.get_base_path_mapping(domain_name, base_path) + + rest_api_ids = [ + op["value"] for op in patch_operations if op["path"] == "/restapiId" + ] + if len(rest_api_ids) == 0: + modified_rest_api_id = base_path_mapping["restApiId"] + else: + modified_rest_api_id = rest_api_ids[-1] + + stages = [op["value"] for op in patch_operations if op["path"] == "/stage"] + if len(stages) == 0: + modified_stage = base_path_mapping.get("stage") + else: + modified_stage = stages[-1] + + base_paths = [ + op["value"] for op in patch_operations if op["path"] == "/basePath" + ] + if len(base_paths) == 0: + modified_base_path = base_path_mapping["basePath"] + else: + modified_base_path = base_paths[-1] + + rest_api = self.apis.get(modified_rest_api_id) + if rest_api is None: + raise InvalidRestApiIdForBasePathMappingException() + if modified_stage and rest_api.stages.get(modified_stage) is None: + raise InvalidStageException() + + base_path_mapping.apply_patch_operations(patch_operations) + + if base_path != modified_base_path: + self.base_path_mappings[domain_name].pop(base_path) + self.base_path_mappings[domain_name][modified_base_path] = base_path_mapping + + return base_path_mapping + apigateway_backends = BackendDict(APIGatewayBackend, "apigateway") diff --git a/moto/apigateway/responses.py b/moto/apigateway/responses.py index 11eb3ea1d..c162e8138 100644 --- a/moto/apigateway/responses.py +++ b/moto/apigateway/responses.py @@ -24,6 +24,8 @@ from .exceptions import ( NoIntegrationResponseDefined, NotFoundException, ConflictException, + InvalidRestApiIdForBasePathMappingException, + InvalidStageException, ) API_KEY_SOURCES = ["AUTHORIZER", "HEADER"] @@ -895,5 +897,15 @@ class APIGatewayResponse(BaseResponse): elif self.method == "DELETE": self.backend.delete_base_path_mapping(domain_name, base_path) return 202, {}, "" + elif self.method == "PATCH": + patch_operations = self._get_param("patchOperations") + base_path_mapping = self.backend.update_base_path_mapping( + domain_name, base_path, patch_operations + ) + return 200, {}, json.dumps(base_path_mapping) except NotFoundException as e: return self.error("NotFoundException", e.message, 404) + except InvalidRestApiIdForBasePathMappingException as e: + return self.error("BadRequestException", e.message) + except InvalidStageException as e: + return self.error("BadRequestException", e.message) diff --git a/tests/test_apigateway/test_apigateway.py b/tests/test_apigateway/test_apigateway.py index 91f6d90ca..78c53db1f 100644 --- a/tests/test_apigateway/test_apigateway.py +++ b/tests/test_apigateway/test_apigateway.py @@ -2652,3 +2652,191 @@ def test_delete_base_path_mapping_with_unknown_base_path(): ) ex.value.response["Error"]["Code"].should.equal("NotFoundException") ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(404) + + +@mock_apigateway +def test_update_path_mapping(): + client = boto3.client("apigateway", region_name="us-west-2") + domain_name = "testDomain" + test_certificate_name = "test.certificate" + client.create_domain_name( + domainName=domain_name, certificateName=test_certificate_name + ) + + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + + stage_name = "dev" + + client.create_base_path_mapping(domainName=domain_name, restApiId=api_id) + + response = client.create_rest_api( + name="new_my_api", description="this is new my api" + ) + new_api_id = response["id"] + create_method_integration(client, new_api_id) + client.create_deployment( + restApiId=new_api_id, stageName=stage_name, description="1.0.1" + ) + + base_path = "v1" + patch_operations = [ + {"op": "replace", "path": "/stage", "value": stage_name}, + {"op": "replace", "path": "/basePath", "value": base_path}, + {"op": "replace", "path": "/restapiId", "value": new_api_id}, + ] + response = client.update_base_path_mapping( + domainName=domain_name, basePath="(none)", patchOperations=patch_operations + ) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["basePath"].should.equal(base_path) + response["restApiId"].should.equal(new_api_id) + response["stage"].should.equal(stage_name) + + +@mock_apigateway +def test_update_path_mapping_with_unknown_domain(): + + client = boto3.client("apigateway", region_name="us-west-2") + with pytest.raises(ClientError) as ex: + client.update_base_path_mapping( + domainName="unknown-domain", basePath="(none)", patchOperations=[] + ) + + ex.value.response["Error"]["Message"].should.equal( + "Invalid domain name identifier specified" + ) + ex.value.response["Error"]["Code"].should.equal("NotFoundException") + ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(404) + + +@mock_apigateway +def test_update_path_mapping_with_unknown_base_path(): + client = boto3.client("apigateway", region_name="us-west-2") + domain_name = "testDomain" + test_certificate_name = "test.certificate" + client.create_domain_name( + domainName=domain_name, certificateName=test_certificate_name + ) + + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + client.create_base_path_mapping( + domainName=domain_name, restApiId=api_id, basePath="v1" + ) + + with pytest.raises(ClientError) as ex: + client.update_base_path_mapping( + domainName=domain_name, basePath="unknown", patchOperations=[] + ) + + ex.value.response["Error"]["Message"].should.equal( + "Invalid base path mapping identifier specified" + ) + ex.value.response["Error"]["Code"].should.equal("NotFoundException") + ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(404) + + +@mock_apigateway +def test_update_path_mapping_to_same_base_path(): + client = boto3.client("apigateway", region_name="us-west-2") + domain_name = "testDomain" + test_certificate_name = "test.certificate" + client.create_domain_name( + domainName=domain_name, certificateName=test_certificate_name + ) + + response = client.create_rest_api(name="my_api", description="this is my api") + api_id_1 = response["id"] + response = client.create_rest_api(name="my_api", description="this is my api") + api_id_2 = response["id"] + + client.create_base_path_mapping( + domainName=domain_name, restApiId=api_id_1, basePath="v1" + ) + client.create_base_path_mapping( + domainName=domain_name, restApiId=api_id_2, basePath="v2" + ) + + response = client.get_base_path_mappings(domainName=domain_name) + items = response["items"] + len(items).should.equal(2) + + patch_operations = [ + {"op": "replace", "path": "/basePath", "value": "v2"}, + ] + response = client.update_base_path_mapping( + domainName=domain_name, basePath="v1", patchOperations=patch_operations + ) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["basePath"].should.equal("v2") + response["restApiId"].should.equal(api_id_1) + + response = client.get_base_path_mappings(domainName=domain_name) + items = response["items"] + len(items).should.equal(1) + items[0]["basePath"].should.equal("v2") + items[0]["restApiId"].should.equal(api_id_1) + items[0].should_not.have.key("stage") + + +@mock_apigateway +def test_update_path_mapping_with_unknown_api(): + client = boto3.client("apigateway", region_name="us-west-2") + domain_name = "testDomain" + test_certificate_name = "test.certificate" + client.create_domain_name( + domainName=domain_name, certificateName=test_certificate_name + ) + + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + base_path = "v1" + client.create_base_path_mapping( + domainName=domain_name, restApiId=api_id, basePath=base_path + ) + + with pytest.raises(ClientError) as ex: + client.update_base_path_mapping( + domainName=domain_name, + basePath=base_path, + patchOperations=[ + {"op": "replace", "path": "/restapiId", "value": "unknown"}, + ], + ) + + ex.value.response["Error"]["Message"].should.equal( + "Invalid REST API identifier specified" + ) + ex.value.response["Error"]["Code"].should.equal("BadRequestException") + ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + + +@mock_apigateway +def test_update_path_mapping_with_unknown_stage(): + client = boto3.client("apigateway", region_name="us-west-2") + domain_name = "testDomain" + test_certificate_name = "test.certificate" + client.create_domain_name( + domainName=domain_name, certificateName=test_certificate_name + ) + + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + base_path = "v1" + client.create_base_path_mapping( + domainName=domain_name, restApiId=api_id, basePath=base_path + ) + + with pytest.raises(ClientError) as ex: + client.update_base_path_mapping( + domainName=domain_name, + basePath=base_path, + patchOperations=[{"op": "replace", "path": "/stage", "value": "unknown"},], + ) + + ex.value.response["Error"]["Message"].should.equal( + "Invalid stage identifier specified" + ) + ex.value.response["Error"]["Code"].should.equal("BadRequestException") + ex.value.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)