diff --git a/moto/apigateway/exceptions.py b/moto/apigateway/exceptions.py index c9c90cea5..8f6d21aa0 100644 --- a/moto/apigateway/exceptions.py +++ b/moto/apigateway/exceptions.py @@ -137,3 +137,39 @@ class DomainNameNotFound(RESTError): super(DomainNameNotFound, self).__init__( "NotFoundException", "Invalid Domain Name specified" ) + + +class InvalidRestApiId(BadRequestException): + code = 404 + + def __init__(self): + super(InvalidRestApiId, self).__init__( + "BadRequestException", "No Rest API Id specified" + ) + + +class InvalidModelName(BadRequestException): + code = 404 + + def __init__(self): + super(InvalidModelName, self).__init__( + "BadRequestException", "No Model Name specified" + ) + + +class RestAPINotFound(RESTError): + code = 404 + + def __init__(self): + super(RestAPINotFound, self).__init__( + "NotFoundException", "Invalid Rest API Id specified" + ) + + +class ModelNotFound(RESTError): + code = 404 + + def __init__(self): + super(ModelNotFound, self).__init__( + "NotFoundException", "Invalid Model Name specified" + ) diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 16462e278..5ce95742e 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -36,6 +36,10 @@ from .exceptions import ( ApiKeyAlreadyExists, DomainNameNotFound, InvalidDomainName, + InvalidRestApiId, + InvalidModelName, + RestAPINotFound, + ModelNotFound ) STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}" @@ -466,6 +470,7 @@ class RestAPI(BaseModel): self.authorizers = {} self.stages = {} self.resources = {} + self.models = {} self.add_child("/") # Add default child def __repr__(self): @@ -494,6 +499,27 @@ class RestAPI(BaseModel): self.resources[child_id] = child return child + def add_model(self, + name, + description=None, + schema=None, + content_type=None, + cli_input_json=None, + generate_cli_skeleton=None): + model_id = create_id() + new_model = Model( + id=model_id, + name=name, + description=description, + schema=schema, + content_type=content_type, + cli_input_json=cli_input_json, + generate_cli_skeleton=generate_cli_skeleton) + + self.models[name] = new_model + return new_model + + def get_resource_for_path(self, path_after_stage_name): for resource in self.resources.values(): if resource.get_path() == path_after_stage_name: @@ -645,6 +671,24 @@ class DomainName(BaseModel, dict): self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton") +class Model(BaseModel,dict): + def __init__(self, id, name, **kwargs): + super(Model, self).__init__() + self["id"] = id + self["name"] = name + if kwargs.get("description"): + self["description"] = kwargs.get("description") + if kwargs.get("schema"): + self["schema"] = kwargs.get("schema") + if kwargs.get("content_type"): + self["contentType"] = kwargs.get("content_type") + if kwargs.get("cli_input_json"): + self["cliInputJson"] = kwargs.get("cli_input_json") + if kwargs.get("generate_cli_skeleton"): + self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton") + + + class APIGatewayBackend(BaseBackend): def __init__(self, region_name): super(APIGatewayBackend, self).__init__() @@ -653,6 +697,7 @@ class APIGatewayBackend(BaseBackend): self.usage_plans = {} self.usage_plan_keys = {} self.domain_names = {} + self.models = {} self.region_name = region_name def reset(self): @@ -682,7 +727,9 @@ class APIGatewayBackend(BaseBackend): return rest_api def get_rest_api(self, function_id): - rest_api = self.apis[function_id] + rest_api = self.apis.get(function_id) + if rest_api is None: + raise RestAPINotFound() return rest_api def list_apis(self): @@ -1085,6 +1132,47 @@ class APIGatewayBackend(BaseBackend): else: return self.domain_names[domain_name] + def create_model(self, + rest_api_id, + name, + content_type, + description=None, + schema=None, + cli_input_json=None, + generate_cli_skeleton=None): + + if not rest_api_id: + raise InvalidRestApiId + if not name: + raise InvalidModelName + + api = self.get_rest_api(rest_api_id) + new_model = api.add_model( + name=name, + description=description, + schema=schema, + content_type=content_type, + cli_input_json=cli_input_json, + generate_cli_skeleton=generate_cli_skeleton) + + return new_model + + def get_models(self, rest_api_id): + if not rest_api_id: + raise InvalidRestApiId + api = self.get_rest_api(rest_api_id) + models = api.models.values() + return list(models) + + def get_model(self, rest_api_id, model_name): + if not rest_api_id: + raise InvalidRestApiId + api = self.get_rest_api(rest_api_id) + model = api.models.get(model_name) + if model is None: + raise ModelNotFound + return model + apigateway_backends = {} for region_name in Session().get_available_regions("apigateway"): diff --git a/moto/apigateway/responses.py b/moto/apigateway/responses.py index e4723f0d4..c18b7f6c4 100644 --- a/moto/apigateway/responses.py +++ b/moto/apigateway/responses.py @@ -13,6 +13,10 @@ from .exceptions import ( ApiKeyAlreadyExists, DomainNameNotFound, InvalidDomainName, + InvalidRestApiId, + InvalidModelName, + RestAPINotFound, + ModelNotFound ) API_KEY_SOURCES = ["AUTHORIZER", "HEADER"] @@ -595,3 +599,67 @@ class APIGatewayResponse(BaseResponse): error.message, error.error_type ), ) + + def models(self,request, full_url, headers): + self.setup_class(request, full_url, headers) + rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0] + + try: + if self.method == "GET": + models = self.backend.get_models( + rest_api_id + ) + return 200, {}, json.dumps({"item": models}) + + elif self.method == "POST": + name = self._get_param("name") + description = self._get_param("description") + schema = self._get_param("schema") + content_type = self._get_param("contentType") + cli_input_json = self._get_param("cliInputJson") + generate_cli_skeleton = self._get_param( + "generateCliSkeleton" + ) + model = self.backend.create_model( + rest_api_id, + name, + content_type, + description, + schema, + cli_input_json, + generate_cli_skeleton + ) + + return 200, {}, json.dumps(model) + + except (InvalidRestApiId, InvalidModelName,RestAPINotFound) as error: + return ( + error.code, + {}, + '{{"message":"{0}","code":"{1}"}}'.format( + error.message, error.error_type + ), + ) + + def model_induvidual(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + url_path_parts = self.path.split("/") + rest_api_id = url_path_parts[2] + model_name = url_path_parts[4] + model_info = {} + try: + if self.method == "GET": + model_info = self.backend.get_model( + rest_api_id, + model_name + ) + return 200, {}, json.dumps(model_info) + except (ModelNotFound, RestAPINotFound, InvalidRestApiId, + InvalidModelName) as error: + return ( + error.code, + {}, + '{{"message":"{0}","code":"{1}"}}'.format( + error.message, error.error_type + ), + ) \ No newline at end of file diff --git a/moto/apigateway/urls.py b/moto/apigateway/urls.py index 6c3b7f6bb..751d8ae65 100644 --- a/moto/apigateway/urls.py +++ b/moto/apigateway/urls.py @@ -22,6 +22,8 @@ url_paths = { "{0}/apikeys/(?P[^/]+)": APIGatewayResponse().apikey_individual, "{0}/usageplans$": APIGatewayResponse().usage_plans, "{0}/domainnames$": APIGatewayResponse().domain_names, + "{0}/restapis/(?P[^/]+)/models": APIGatewayResponse().models, + "{0}/restapis/(?P[^/]+)/models/(?P[^/]+)/?$": APIGatewayResponse().model_induvidual, "{0}/domainnames/(?P[^/]+)/?$": APIGatewayResponse().domain_name_induvidual, "{0}/usageplans/(?P[^/]+)/?$": APIGatewayResponse().usage_plan_individual, "{0}/usageplans/(?P[^/]+)/keys$": APIGatewayResponse().usage_plan_keys, diff --git a/tests/test_apigateway/test_apigateway.py b/tests/test_apigateway/test_apigateway.py index a1a380974..3a6b75104 100644 --- a/tests/test_apigateway/test_apigateway.py +++ b/tests/test_apigateway/test_apigateway.py @@ -1547,6 +1547,148 @@ def test_get_domain_name(): result["domainNameStatus"].should.equal("AVAILABLE") +@mock_apigateway +def test_create_model(): + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", + description="this is my api" + ) + rest_api_id = response["id"] + dummy_rest_api_id = 'a12b3c4d' + model_name = "testModel" + description = "test model" + content_type = 'application/json' + # success case with valid params + response = client.create_model( + restApiId=rest_api_id, + name=model_name, + description=description, + contentType=content_type + ) + response["name"].should.equal(model_name) + response["description"].should.equal(description) + + # with an invalid rest_api_id it should throw NotFoundException + with assert_raises(ClientError) as ex: + client.create_model( + restApiId=dummy_rest_api_id, + name=model_name, + description=description, + contentType=content_type + ) + ex.exception.response["Error"]["Message"].should.equal( + "Invalid Rest API Id specified" + ) + ex.exception.response["Error"]["Code"].should.equal( + "NotFoundException" + ) + + with assert_raises(ClientError) as ex: + client.create_model( + restApiId=rest_api_id, + name="", + description=description, + contentType=content_type + ) + + ex.exception.response["Error"]["Message"].should.equal( + "No Model Name specified" + ) + ex.exception.response["Error"]["Code"].should.equal( + "BadRequestException" + ) + + +@mock_apigateway +def test_get_api_models(): + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api( + name="my_api", + description="this is my api" + ) + rest_api_id = response["id"] + model_name = "testModel" + description = "test model" + content_type = 'application/json' + # when no models are present + result = client.get_models( + restApiId=rest_api_id + ) + result["items"].should.equal([]) + # add a model + client.create_model( + restApiId=rest_api_id, + name=model_name, + description=description, + contentType=content_type + ) + # get models after adding + result = client.get_models( + restApiId=rest_api_id + ) + result["items"][0]["name"] = model_name + result["items"][0]["description"] = description + + +@mock_apigateway +def test_get_model_by_name(): + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api( + name="my_api", + description="this is my api" + ) + rest_api_id = response["id"] + dummy_rest_api_id = 'a12b3c4d' + model_name = "testModel" + description = "test model" + content_type = 'application/json' + # add a model + client.create_model( + restApiId=rest_api_id, + name=model_name, + description=description, + contentType=content_type + ) + # get models after adding + result = client.get_model( + restApiId=rest_api_id, modelName=model_name + ) + result["name"] = model_name + result["description"] = description + + with assert_raises(ClientError) as ex: + client.get_model( + restApiId=dummy_rest_api_id, modelName=model_name + ) + ex.exception.response["Error"]["Message"].should.equal( + "Invalid Rest API Id specified" + ) + ex.exception.response["Error"]["Code"].should.equal( + "NotFoundException" + ) + + +@mock_apigateway +def test_get_model_with_invalid_name(): + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api( + name="my_api", + description="this is my api" + ) + rest_api_id = response["id"] + # test with an invalid model name + with assert_raises(ClientError) as ex: + client.get_model( + restApiId=rest_api_id, modelName="fake" + ) + ex.exception.response["Error"]["Message"].should.equal( + "Invalid Model Name specified" + ) + ex.exception.response["Error"]["Code"].should.equal( + "NotFoundException" + ) + + @mock_apigateway def test_http_proxying_integration(): responses.add(