Add Athena: create_data_catalog, list_data_catalogs, get_data_catalog (#4854)
This commit is contained in:
parent
13b9c0322c
commit
e75dcf47b7
@ -46,6 +46,19 @@ class WorkGroup(TaggableResourceMixin, BaseModel):
|
|||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
|
|
||||||
|
|
||||||
|
class DataCatalog(TaggableResourceMixin, BaseModel):
|
||||||
|
def __init__(
|
||||||
|
self, athena_backend, name, catalog_type, description, parameters, tags
|
||||||
|
):
|
||||||
|
self.region_name = athena_backend.region_name
|
||||||
|
super().__init__(self.region_name, "datacatalog/{}".format(name), tags)
|
||||||
|
self.athena_backend = athena_backend
|
||||||
|
self.name = name
|
||||||
|
self.type = catalog_type
|
||||||
|
self.description = description
|
||||||
|
self.parameters = parameters
|
||||||
|
|
||||||
|
|
||||||
class Execution(BaseModel):
|
class Execution(BaseModel):
|
||||||
def __init__(self, query, context, config, workgroup):
|
def __init__(self, query, context, config, workgroup):
|
||||||
self.id = str(uuid4())
|
self.id = str(uuid4())
|
||||||
@ -76,6 +89,7 @@ class AthenaBackend(BaseBackend):
|
|||||||
self.work_groups = {}
|
self.work_groups = {}
|
||||||
self.executions = {}
|
self.executions = {}
|
||||||
self.named_queries = {}
|
self.named_queries = {}
|
||||||
|
self.data_catalogs = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_vpc_endpoint_service(service_region, zones):
|
def default_vpc_endpoint_service(service_region, zones):
|
||||||
@ -142,5 +156,31 @@ class AthenaBackend(BaseBackend):
|
|||||||
def get_named_query(self, query_id):
|
def get_named_query(self, query_id):
|
||||||
return self.named_queries[query_id] if query_id in self.named_queries else None
|
return self.named_queries[query_id] if query_id in self.named_queries else None
|
||||||
|
|
||||||
|
def list_data_catalogs(self):
|
||||||
|
return [
|
||||||
|
{"CatalogName": dc.name, "Type": dc.type,}
|
||||||
|
for dc in self.data_catalogs.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_data_catalog(self, name):
|
||||||
|
if name not in self.data_catalogs:
|
||||||
|
return None
|
||||||
|
dc = self.data_catalogs[name]
|
||||||
|
return {
|
||||||
|
"Name": dc.name,
|
||||||
|
"Description": dc.description,
|
||||||
|
"Type": dc.type,
|
||||||
|
"Parameters": dc.parameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_data_catalog(self, name, catalog_type, description, parameters, tags):
|
||||||
|
if name in self.data_catalogs:
|
||||||
|
return None
|
||||||
|
data_catalog = DataCatalog(
|
||||||
|
self, name, catalog_type, description, parameters, tags
|
||||||
|
)
|
||||||
|
self.data_catalogs[name] = data_catalog
|
||||||
|
return data_catalog
|
||||||
|
|
||||||
|
|
||||||
athena_backends = BackendDict(AthenaBackend, "athena")
|
athena_backends = BackendDict(AthenaBackend, "athena")
|
||||||
|
@ -114,3 +114,33 @@ class AthenaResponse(BaseResponse):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def list_data_catalogs(self):
|
||||||
|
return json.dumps(
|
||||||
|
{"DataCatalogsSummary": self.athena_backend.list_data_catalogs()}
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_data_catalog(self):
|
||||||
|
name = self._get_param("Name")
|
||||||
|
return json.dumps({"DataCatalog": self.athena_backend.get_data_catalog(name)})
|
||||||
|
|
||||||
|
def create_data_catalog(self):
|
||||||
|
name = self._get_param("Name")
|
||||||
|
catalog_type = self._get_param("Type")
|
||||||
|
description = self._get_param("Description")
|
||||||
|
parameters = self._get_param("Parameters")
|
||||||
|
tags = self._get_param("Tags")
|
||||||
|
data_catalog = self.athena_backend.create_data_catalog(
|
||||||
|
name, catalog_type, description, parameters, tags
|
||||||
|
)
|
||||||
|
if not data_catalog:
|
||||||
|
return self.error("DataCatalog already exists", 400)
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"CreateDataCatalogResponse": {
|
||||||
|
"ResponseMetadata": {
|
||||||
|
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@ -189,7 +189,6 @@ def test_get_named_query():
|
|||||||
database = "target_db"
|
database = "target_db"
|
||||||
query_string = "SELECT * FROM tbl1"
|
query_string = "SELECT * FROM tbl1"
|
||||||
description = "description of this query"
|
description = "description of this query"
|
||||||
|
|
||||||
# craete named query
|
# craete named query
|
||||||
res_create = client.create_named_query(
|
res_create = client.create_named_query(
|
||||||
Name=query_name,
|
Name=query_name,
|
||||||
@ -216,3 +215,61 @@ def create_basic_workgroup(client, name):
|
|||||||
"ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/",}
|
"ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/",}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_athena
|
||||||
|
def test_create_data_catalog():
|
||||||
|
client = boto3.client("athena", region_name="us-east-1")
|
||||||
|
response = client.create_data_catalog(
|
||||||
|
Name="athena_datacatalog",
|
||||||
|
Type="GLUE",
|
||||||
|
Description="Test data catalog",
|
||||||
|
Parameters={"catalog-id": "AWS Test account ID"},
|
||||||
|
Tags=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# The second time should throw an error
|
||||||
|
response = client.create_data_catalog(
|
||||||
|
Name="athena_datacatalog",
|
||||||
|
Type="GLUE",
|
||||||
|
Description="Test data catalog",
|
||||||
|
Parameters={"catalog-id": "AWS Test account ID"},
|
||||||
|
Tags=[],
|
||||||
|
)
|
||||||
|
except ClientError as err:
|
||||||
|
err.response["Error"]["Code"].should.equal("InvalidRequestException")
|
||||||
|
err.response["Error"]["Message"].should.equal("DataCatalog already exists")
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Should have raised ResourceNotFoundException")
|
||||||
|
|
||||||
|
# Then test the work group appears in the work group list
|
||||||
|
response = client.list_data_catalogs()
|
||||||
|
|
||||||
|
response["DataCatalogsSummary"].should.have.length_of(1)
|
||||||
|
data_catalog = response["DataCatalogsSummary"][0]
|
||||||
|
data_catalog["CatalogName"].should.equal("athena_datacatalog")
|
||||||
|
data_catalog["Type"].should.equal("GLUE")
|
||||||
|
|
||||||
|
|
||||||
|
@mock_athena
|
||||||
|
def test_create_and_get_data_catalog():
|
||||||
|
client = boto3.client("athena", region_name="us-east-1")
|
||||||
|
|
||||||
|
client.create_data_catalog(
|
||||||
|
Name="athena_datacatalog",
|
||||||
|
Type="GLUE",
|
||||||
|
Description="Test data catalog",
|
||||||
|
Parameters={"catalog-id": "AWS Test account ID"},
|
||||||
|
Tags=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
data_catalog = client.get_data_catalog(Name="athena_datacatalog")
|
||||||
|
data_catalog["DataCatalog"].should.equal(
|
||||||
|
{
|
||||||
|
"Name": "athena_datacatalog",
|
||||||
|
"Description": "Test data catalog",
|
||||||
|
"Type": "GLUE",
|
||||||
|
"Parameters": {"catalog-id": "AWS Test account ID"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user