Add Athena: create_data_catalog, list_data_catalogs, get_data_catalog (#4854)

This commit is contained in:
Abdul 2022-02-14 20:11:39 +01:00 committed by GitHub
parent 13b9c0322c
commit e75dcf47b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 128 additions and 1 deletions

View File

@ -46,6 +46,19 @@ class WorkGroup(TaggableResourceMixin, BaseModel):
self.configuration = configuration
class DataCatalog(TaggableResourceMixin, BaseModel):
def __init__(
self, athena_backend, name, catalog_type, description, parameters, tags
):
self.region_name = athena_backend.region_name
super().__init__(self.region_name, "datacatalog/{}".format(name), tags)
self.athena_backend = athena_backend
self.name = name
self.type = catalog_type
self.description = description
self.parameters = parameters
class Execution(BaseModel):
def __init__(self, query, context, config, workgroup):
self.id = str(uuid4())
@ -76,6 +89,7 @@ class AthenaBackend(BaseBackend):
self.work_groups = {}
self.executions = {}
self.named_queries = {}
self.data_catalogs = {}
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
@ -142,5 +156,31 @@ class AthenaBackend(BaseBackend):
def get_named_query(self, query_id):
return self.named_queries[query_id] if query_id in self.named_queries else None
def list_data_catalogs(self):
return [
{"CatalogName": dc.name, "Type": dc.type,}
for dc in self.data_catalogs.values()
]
def get_data_catalog(self, name):
if name not in self.data_catalogs:
return None
dc = self.data_catalogs[name]
return {
"Name": dc.name,
"Description": dc.description,
"Type": dc.type,
"Parameters": dc.parameters,
}
def create_data_catalog(self, name, catalog_type, description, parameters, tags):
if name in self.data_catalogs:
return None
data_catalog = DataCatalog(
self, name, catalog_type, description, parameters, tags
)
self.data_catalogs[name] = data_catalog
return data_catalog
athena_backends = BackendDict(AthenaBackend, "athena")

View File

@ -114,3 +114,33 @@ class AthenaResponse(BaseResponse):
}
}
)
def list_data_catalogs(self):
return json.dumps(
{"DataCatalogsSummary": self.athena_backend.list_data_catalogs()}
)
def get_data_catalog(self):
name = self._get_param("Name")
return json.dumps({"DataCatalog": self.athena_backend.get_data_catalog(name)})
def create_data_catalog(self):
name = self._get_param("Name")
catalog_type = self._get_param("Type")
description = self._get_param("Description")
parameters = self._get_param("Parameters")
tags = self._get_param("Tags")
data_catalog = self.athena_backend.create_data_catalog(
name, catalog_type, description, parameters, tags
)
if not data_catalog:
return self.error("DataCatalog already exists", 400)
return json.dumps(
{
"CreateDataCatalogResponse": {
"ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a"
}
}
}
)

View File

@ -189,7 +189,6 @@ def test_get_named_query():
database = "target_db"
query_string = "SELECT * FROM tbl1"
description = "description of this query"
# craete named query
res_create = client.create_named_query(
Name=query_name,
@ -216,3 +215,61 @@ def create_basic_workgroup(client, name):
"ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/",}
},
)
@mock_athena
def test_create_data_catalog():
client = boto3.client("athena", region_name="us-east-1")
response = client.create_data_catalog(
Name="athena_datacatalog",
Type="GLUE",
Description="Test data catalog",
Parameters={"catalog-id": "AWS Test account ID"},
Tags=[],
)
try:
# The second time should throw an error
response = client.create_data_catalog(
Name="athena_datacatalog",
Type="GLUE",
Description="Test data catalog",
Parameters={"catalog-id": "AWS Test account ID"},
Tags=[],
)
except ClientError as err:
err.response["Error"]["Code"].should.equal("InvalidRequestException")
err.response["Error"]["Message"].should.equal("DataCatalog already exists")
else:
raise RuntimeError("Should have raised ResourceNotFoundException")
# Then test the work group appears in the work group list
response = client.list_data_catalogs()
response["DataCatalogsSummary"].should.have.length_of(1)
data_catalog = response["DataCatalogsSummary"][0]
data_catalog["CatalogName"].should.equal("athena_datacatalog")
data_catalog["Type"].should.equal("GLUE")
@mock_athena
def test_create_and_get_data_catalog():
client = boto3.client("athena", region_name="us-east-1")
client.create_data_catalog(
Name="athena_datacatalog",
Type="GLUE",
Description="Test data catalog",
Parameters={"catalog-id": "AWS Test account ID"},
Tags=[],
)
data_catalog = client.get_data_catalog(Name="athena_datacatalog")
data_catalog["DataCatalog"].should.equal(
{
"Name": "athena_datacatalog",
"Description": "Test data catalog",
"Type": "GLUE",
"Parameters": {"catalog-id": "AWS Test account ID"},
}
)