Add Athena: create_data_catalog, list_data_catalogs, get_data_catalog (#4854)
This commit is contained in:
parent
13b9c0322c
commit
e75dcf47b7
@ -46,6 +46,19 @@ class WorkGroup(TaggableResourceMixin, BaseModel):
|
||||
self.configuration = configuration
|
||||
|
||||
|
||||
class DataCatalog(TaggableResourceMixin, BaseModel):
|
||||
def __init__(
|
||||
self, athena_backend, name, catalog_type, description, parameters, tags
|
||||
):
|
||||
self.region_name = athena_backend.region_name
|
||||
super().__init__(self.region_name, "datacatalog/{}".format(name), tags)
|
||||
self.athena_backend = athena_backend
|
||||
self.name = name
|
||||
self.type = catalog_type
|
||||
self.description = description
|
||||
self.parameters = parameters
|
||||
|
||||
|
||||
class Execution(BaseModel):
|
||||
def __init__(self, query, context, config, workgroup):
|
||||
self.id = str(uuid4())
|
||||
@ -76,6 +89,7 @@ class AthenaBackend(BaseBackend):
|
||||
self.work_groups = {}
|
||||
self.executions = {}
|
||||
self.named_queries = {}
|
||||
self.data_catalogs = {}
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(service_region, zones):
|
||||
@ -142,5 +156,31 @@ class AthenaBackend(BaseBackend):
|
||||
def get_named_query(self, query_id):
|
||||
return self.named_queries[query_id] if query_id in self.named_queries else None
|
||||
|
||||
def list_data_catalogs(self):
|
||||
return [
|
||||
{"CatalogName": dc.name, "Type": dc.type,}
|
||||
for dc in self.data_catalogs.values()
|
||||
]
|
||||
|
||||
def get_data_catalog(self, name):
|
||||
if name not in self.data_catalogs:
|
||||
return None
|
||||
dc = self.data_catalogs[name]
|
||||
return {
|
||||
"Name": dc.name,
|
||||
"Description": dc.description,
|
||||
"Type": dc.type,
|
||||
"Parameters": dc.parameters,
|
||||
}
|
||||
|
||||
def create_data_catalog(self, name, catalog_type, description, parameters, tags):
|
||||
if name in self.data_catalogs:
|
||||
return None
|
||||
data_catalog = DataCatalog(
|
||||
self, name, catalog_type, description, parameters, tags
|
||||
)
|
||||
self.data_catalogs[name] = data_catalog
|
||||
return data_catalog
|
||||
|
||||
|
||||
athena_backends = BackendDict(AthenaBackend, "athena")
|
||||
|
@ -114,3 +114,33 @@ class AthenaResponse(BaseResponse):
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def list_data_catalogs(self):
|
||||
return json.dumps(
|
||||
{"DataCatalogsSummary": self.athena_backend.list_data_catalogs()}
|
||||
)
|
||||
|
||||
def get_data_catalog(self):
|
||||
name = self._get_param("Name")
|
||||
return json.dumps({"DataCatalog": self.athena_backend.get_data_catalog(name)})
|
||||
|
||||
def create_data_catalog(self):
|
||||
name = self._get_param("Name")
|
||||
catalog_type = self._get_param("Type")
|
||||
description = self._get_param("Description")
|
||||
parameters = self._get_param("Parameters")
|
||||
tags = self._get_param("Tags")
|
||||
data_catalog = self.athena_backend.create_data_catalog(
|
||||
name, catalog_type, description, parameters, tags
|
||||
)
|
||||
if not data_catalog:
|
||||
return self.error("DataCatalog already exists", 400)
|
||||
return json.dumps(
|
||||
{
|
||||
"CreateDataCatalogResponse": {
|
||||
"ResponseMetadata": {
|
||||
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a"
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
@ -189,7 +189,6 @@ def test_get_named_query():
|
||||
database = "target_db"
|
||||
query_string = "SELECT * FROM tbl1"
|
||||
description = "description of this query"
|
||||
|
||||
# craete named query
|
||||
res_create = client.create_named_query(
|
||||
Name=query_name,
|
||||
@ -216,3 +215,61 @@ def create_basic_workgroup(client, name):
|
||||
"ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/",}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@mock_athena
|
||||
def test_create_data_catalog():
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
response = client.create_data_catalog(
|
||||
Name="athena_datacatalog",
|
||||
Type="GLUE",
|
||||
Description="Test data catalog",
|
||||
Parameters={"catalog-id": "AWS Test account ID"},
|
||||
Tags=[],
|
||||
)
|
||||
|
||||
try:
|
||||
# The second time should throw an error
|
||||
response = client.create_data_catalog(
|
||||
Name="athena_datacatalog",
|
||||
Type="GLUE",
|
||||
Description="Test data catalog",
|
||||
Parameters={"catalog-id": "AWS Test account ID"},
|
||||
Tags=[],
|
||||
)
|
||||
except ClientError as err:
|
||||
err.response["Error"]["Code"].should.equal("InvalidRequestException")
|
||||
err.response["Error"]["Message"].should.equal("DataCatalog already exists")
|
||||
else:
|
||||
raise RuntimeError("Should have raised ResourceNotFoundException")
|
||||
|
||||
# Then test the work group appears in the work group list
|
||||
response = client.list_data_catalogs()
|
||||
|
||||
response["DataCatalogsSummary"].should.have.length_of(1)
|
||||
data_catalog = response["DataCatalogsSummary"][0]
|
||||
data_catalog["CatalogName"].should.equal("athena_datacatalog")
|
||||
data_catalog["Type"].should.equal("GLUE")
|
||||
|
||||
|
||||
@mock_athena
|
||||
def test_create_and_get_data_catalog():
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
|
||||
client.create_data_catalog(
|
||||
Name="athena_datacatalog",
|
||||
Type="GLUE",
|
||||
Description="Test data catalog",
|
||||
Parameters={"catalog-id": "AWS Test account ID"},
|
||||
Tags=[],
|
||||
)
|
||||
|
||||
data_catalog = client.get_data_catalog(Name="athena_datacatalog")
|
||||
data_catalog["DataCatalog"].should.equal(
|
||||
{
|
||||
"Name": "athena_datacatalog",
|
||||
"Description": "Test data catalog",
|
||||
"Type": "GLUE",
|
||||
"Parameters": {"catalog-id": "AWS Test account ID"},
|
||||
}
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user