From e75dcf47b767e7be3ddce51292f930e8ed812710 Mon Sep 17 00:00:00 2001 From: Abdul <45050211+amoeedm@users.noreply.github.com> Date: Mon, 14 Feb 2022 20:11:39 +0100 Subject: [PATCH] Add Athena: create_data_catalog, list_data_catalogs, get_data_catalog (#4854) --- moto/athena/models.py | 40 ++++++++++++++++++++++ moto/athena/responses.py | 30 ++++++++++++++++ tests/test_athena/test_athena.py | 59 +++++++++++++++++++++++++++++++- 3 files changed, 128 insertions(+), 1 deletion(-) diff --git a/moto/athena/models.py b/moto/athena/models.py index e083683a7..fac88d6b0 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -46,6 +46,19 @@ class WorkGroup(TaggableResourceMixin, BaseModel): self.configuration = configuration +class DataCatalog(TaggableResourceMixin, BaseModel): + def __init__( + self, athena_backend, name, catalog_type, description, parameters, tags + ): + self.region_name = athena_backend.region_name + super().__init__(self.region_name, "datacatalog/{}".format(name), tags) + self.athena_backend = athena_backend + self.name = name + self.type = catalog_type + self.description = description + self.parameters = parameters + + class Execution(BaseModel): def __init__(self, query, context, config, workgroup): self.id = str(uuid4()) @@ -76,6 +89,7 @@ class AthenaBackend(BaseBackend): self.work_groups = {} self.executions = {} self.named_queries = {} + self.data_catalogs = {} @staticmethod def default_vpc_endpoint_service(service_region, zones): @@ -142,5 +156,31 @@ class AthenaBackend(BaseBackend): def get_named_query(self, query_id): return self.named_queries[query_id] if query_id in self.named_queries else None + def list_data_catalogs(self): + return [ + {"CatalogName": dc.name, "Type": dc.type,} + for dc in self.data_catalogs.values() + ] + + def get_data_catalog(self, name): + if name not in self.data_catalogs: + return None + dc = self.data_catalogs[name] + return { + "Name": dc.name, + "Description": dc.description, + "Type": dc.type, + "Parameters": dc.parameters, + } + + def create_data_catalog(self, name, catalog_type, description, parameters, tags): + if name in self.data_catalogs: + return None + data_catalog = DataCatalog( + self, name, catalog_type, description, parameters, tags + ) + self.data_catalogs[name] = data_catalog + return data_catalog + athena_backends = BackendDict(AthenaBackend, "athena") diff --git a/moto/athena/responses.py b/moto/athena/responses.py index d4a85bb4a..218e55471 100644 --- a/moto/athena/responses.py +++ b/moto/athena/responses.py @@ -114,3 +114,33 @@ class AthenaResponse(BaseResponse): } } ) + + def list_data_catalogs(self): + return json.dumps( + {"DataCatalogsSummary": self.athena_backend.list_data_catalogs()} + ) + + def get_data_catalog(self): + name = self._get_param("Name") + return json.dumps({"DataCatalog": self.athena_backend.get_data_catalog(name)}) + + def create_data_catalog(self): + name = self._get_param("Name") + catalog_type = self._get_param("Type") + description = self._get_param("Description") + parameters = self._get_param("Parameters") + tags = self._get_param("Tags") + data_catalog = self.athena_backend.create_data_catalog( + name, catalog_type, description, parameters, tags + ) + if not data_catalog: + return self.error("DataCatalog already exists", 400) + return json.dumps( + { + "CreateDataCatalogResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } + } + } + ) diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index dfa63375c..fd496b3ba 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -189,7 +189,6 @@ def test_get_named_query(): database = "target_db" query_string = "SELECT * FROM tbl1" description = "description of this query" - # craete named query res_create = client.create_named_query( Name=query_name, @@ -216,3 +215,61 @@ def create_basic_workgroup(client, name): "ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/",} }, ) + + +@mock_athena +def test_create_data_catalog(): + client = boto3.client("athena", region_name="us-east-1") + response = client.create_data_catalog( + Name="athena_datacatalog", + Type="GLUE", + Description="Test data catalog", + Parameters={"catalog-id": "AWS Test account ID"}, + Tags=[], + ) + + try: + # The second time should throw an error + response = client.create_data_catalog( + Name="athena_datacatalog", + Type="GLUE", + Description="Test data catalog", + Parameters={"catalog-id": "AWS Test account ID"}, + Tags=[], + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidRequestException") + err.response["Error"]["Message"].should.equal("DataCatalog already exists") + else: + raise RuntimeError("Should have raised ResourceNotFoundException") + + # Then test the work group appears in the work group list + response = client.list_data_catalogs() + + response["DataCatalogsSummary"].should.have.length_of(1) + data_catalog = response["DataCatalogsSummary"][0] + data_catalog["CatalogName"].should.equal("athena_datacatalog") + data_catalog["Type"].should.equal("GLUE") + + +@mock_athena +def test_create_and_get_data_catalog(): + client = boto3.client("athena", region_name="us-east-1") + + client.create_data_catalog( + Name="athena_datacatalog", + Type="GLUE", + Description="Test data catalog", + Parameters={"catalog-id": "AWS Test account ID"}, + Tags=[], + ) + + data_catalog = client.get_data_catalog(Name="athena_datacatalog") + data_catalog["DataCatalog"].should.equal( + { + "Name": "athena_datacatalog", + "Description": "Test data catalog", + "Type": "GLUE", + "Parameters": {"catalog-id": "AWS Test account ID"}, + } + )