2019-10-22 21:37:29 +00:00
|
|
|
import time
|
|
|
|
|
2020-05-16 14:00:06 +00:00
|
|
|
from moto.core import BaseBackend, BaseModel, ACCOUNT_ID
|
2021-12-24 21:02:45 +00:00
|
|
|
from moto.core.utils import BackendDict
|
2019-12-26 16:12:22 +00:00
|
|
|
|
2020-05-16 14:00:06 +00:00
|
|
|
from uuid import uuid4
|
2019-10-22 21:37:29 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TaggableResourceMixin:
    """Mixin that gives an Athena model an ARN and an editable tag list.

    Tags are kept as a list of ``{"Key": ..., "Value": ...}`` dicts.
    """

    # This mixin was copied from Redshift when initially implementing
    # Athena. TBD if it's worth the overhead.

    def __init__(self, region_name, resource_name, tags):
        """Initialise the taggable resource.

        :param region_name: region of the resource; used when building the ARN.
        :param resource_name: resource part of the ARN, e.g. ``"workgroup/<name>"``.
        :param tags: initial list of tag dicts; a falsy value means no tags.
        """
        self.region = region_name
        self.resource_name = resource_name
        self.tags = tags or []

    @property
    def arn(self):
        """ARN of the form ``arn:aws:athena:<region>:<account_id>:<resource_name>``."""
        return "arn:aws:athena:{region}:{account_id}:{resource_name}".format(
            region=self.region, account_id=ACCOUNT_ID, resource_name=self.resource_name
        )

    def create_tags(self, tags):
        """Add *tags*, replacing any existing tag that has the same key.

        Existing tags that are not overwritten keep their relative order and
        come first; the new tags are appended in the order given.

        :param tags: list of tag dicts, each with a ``"Key"`` entry.
        :returns: the updated tag list.
        """
        # A set makes each membership check O(1) while filtering existing tags.
        new_keys = {tag_set["Key"] for tag_set in tags}
        self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
        self.tags.extend(tags)
        return self.tags

    def delete_tags(self, tag_keys):
        """Remove every tag whose key is contained in *tag_keys*.

        :param tag_keys: collection of tag keys to remove.
        :returns: the remaining tag list.
        """
        self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
        return self.tags
|
|
|
|
|
|
|
|
|
|
|
|
class WorkGroup(TaggableResourceMixin, BaseModel):
    """An Athena work group kept in memory by an AthenaBackend."""

    # Resource type label for this model (matches the "workgroup/" ARN prefix).
    resource_type = "workgroup"
    # Work groups are always reported as enabled; nothing in this module changes it.
    state = "ENABLED"

    def __init__(self, athena_backend, name, configuration, description, tags):
        """Create a work group owned by *athena_backend*.

        :param athena_backend: owning backend; its ``region_name`` is used for the ARN.
        :param name: work group name.
        :param configuration: work group configuration, stored as given.
        :param description: free-text description, stored as given.
        :param tags: initial tag list (see TaggableResourceMixin).
        """
        region = athena_backend.region_name
        self.region_name = region
        arn_resource = "workgroup/{}".format(name)
        super().__init__(region, arn_resource, tags)
        self.athena_backend = athena_backend
        self.name, self.description = name, description
        self.configuration = configuration
|
|
|
|
|
|
|
|
|
2022-02-14 19:11:39 +00:00
|
|
|
class DataCatalog(TaggableResourceMixin, BaseModel):
    """An Athena data catalog kept in memory by an AthenaBackend."""

    # Resource type label, mirroring WorkGroup.resource_type; also used as the
    # ARN resource prefix ("datacatalog/<name>").
    resource_type = "datacatalog"

    def __init__(
        self, athena_backend, name, catalog_type, description, parameters, tags
    ):
        """Create a data catalog owned by *athena_backend*.

        :param athena_backend: owning backend; its ``region_name`` is used for the ARN.
        :param name: catalog name.
        :param catalog_type: catalog type, exposed as ``self.type``.
        :param description: free-text description, stored as given.
        :param parameters: catalog parameters, stored as given.
        :param tags: initial tag list (see TaggableResourceMixin).
        """
        self.region_name = athena_backend.region_name
        super().__init__(
            self.region_name, "{}/{}".format(self.resource_type, name), tags
        )
        self.athena_backend = athena_backend
        self.name = name
        self.type = catalog_type
        self.description = description
        self.parameters = parameters
|
|
|
|
|
|
|
|
|
2020-05-16 14:00:06 +00:00
|
|
|
class Execution(BaseModel):
    """A single query execution tracked by AthenaBackend."""

    def __init__(self, query, context, config, workgroup):
        """Record a newly submitted query execution.

        :param query: query string that was submitted.
        :param context: query execution context, stored as given.
        :param config: result configuration, stored as given.
        :param workgroup: work group the query was submitted to, stored as given.
        """
        # A random UUID string serves as the execution id.
        self.id = str(uuid4())
        self.query, self.context, self.config = query, context, config
        self.workgroup = workgroup
        # Epoch seconds at which this execution object was created.
        self.start_time = time.time()
        # Executions start out queued; AthenaBackend.stop_query_execution may
        # later set this to "CANCELLED".
        self.status = "QUEUED"
|
|
|
|
|
|
|
|
|
2020-06-11 16:27:29 +00:00
|
|
|
class NamedQuery(BaseModel):
    """A saved ("named") query registered with AthenaBackend."""

    def __init__(self, name, description, database, query_string, workgroup):
        """Store a named query under a freshly generated id.

        :param name: name of the saved query.
        :param description: free-text description, stored as given.
        :param database: database the query targets, stored as given.
        :param query_string: the SQL text of the query.
        :param workgroup: work group the query belongs to, stored as given.
        """
        # A random UUID string serves as the named query id.
        self.id = str(uuid4())
        self.name, self.description = name, description
        self.database, self.query_string = database, query_string
        self.workgroup = workgroup
|
|
|
|
|
|
|
|
|
2019-10-22 21:37:29 +00:00
|
|
|
class AthenaBackend(BaseBackend):
    """In-memory Athena backend for one region.

    Work groups, query executions, named queries and data catalogs are kept in
    plain dicts keyed by name (work groups, catalogs) or generated id
    (executions, named queries).
    """

    region_name = None

    def __init__(self, region_name=None):
        """Create an empty backend; *region_name* overrides the class default."""
        if region_name is not None:
            self.region_name = region_name
        self.work_groups = {}  # work group name -> WorkGroup
        self.executions = {}  # execution id -> Execution
        self.named_queries = {}  # named query id -> NamedQuery
        self.data_catalogs = {}  # catalog name -> DataCatalog

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "athena"
        )

    @staticmethod
    def _creation_time(resource):
        """Return the creation timestamp recorded on *resource*.

        Objects created through this backend's create_* methods carry a
        ``creation_time`` attribute; anything else falls back to the current
        time (the previous behaviour for every object).
        """
        created = getattr(resource, "creation_time", None)
        return time.time() if created is None else created

    def create_work_group(self, name, configuration, description, tags):
        """Create and register a work group.

        :returns: the new WorkGroup, or None if *name* is already taken.
        """
        if name in self.work_groups:
            return None
        work_group = WorkGroup(self, name, configuration, description, tags)
        # Record the creation time once so List/GetWorkGroup report a stable
        # timestamp rather than the time of each request.
        work_group.creation_time = time.time()
        self.work_groups[name] = work_group
        return work_group

    def list_work_groups(self):
        """Return a summary dict for every registered work group."""
        return [
            {
                "Name": wg.name,
                "State": wg.state,
                "Description": wg.description,
                "CreationTime": self._creation_time(wg),
            }
            for wg in self.work_groups.values()
        ]

    def get_work_group(self, name):
        """Return the description dict of work group *name*, or None if unknown."""
        wg = self.work_groups.get(name)
        if wg is None:
            return None
        return {
            "Name": wg.name,
            "State": wg.state,
            "Configuration": wg.configuration,
            "Description": wg.description,
            "CreationTime": self._creation_time(wg),
        }

    def start_query_execution(self, query, context, config, workgroup):
        """Register a new (queued) query execution and return its id."""
        execution = Execution(
            query=query, context=context, config=config, workgroup=workgroup
        )
        self.executions[execution.id] = execution
        return execution.id

    def get_execution(self, exec_id):
        """Return the Execution with id *exec_id*; raises KeyError if unknown."""
        return self.executions[exec_id]

    def stop_query_execution(self, exec_id):
        """Mark execution *exec_id* as CANCELLED; raises KeyError if unknown."""
        execution = self.executions[exec_id]
        execution.status = "CANCELLED"

    def create_named_query(self, name, description, database, query_string, workgroup):
        """Register a named query and return its generated id."""
        nq = NamedQuery(
            name=name,
            description=description,
            database=database,
            query_string=query_string,
            workgroup=workgroup,
        )
        self.named_queries[nq.id] = nq
        return nq.id

    def get_named_query(self, query_id):
        """Return the NamedQuery with id *query_id*, or None if unknown."""
        return self.named_queries.get(query_id)

    def list_data_catalogs(self):
        """Return a name/type summary dict for every registered data catalog."""
        return [
            {"CatalogName": dc.name, "Type": dc.type}
            for dc in self.data_catalogs.values()
        ]

    def get_data_catalog(self, name):
        """Return the description dict of data catalog *name*, or None if unknown."""
        dc = self.data_catalogs.get(name)
        if dc is None:
            return None
        return {
            "Name": dc.name,
            "Description": dc.description,
            "Type": dc.type,
            "Parameters": dc.parameters,
        }

    def create_data_catalog(self, name, catalog_type, description, parameters, tags):
        """Create and register a data catalog.

        :returns: the new DataCatalog, or None if *name* is already taken.
        """
        if name in self.data_catalogs:
            return None
        data_catalog = DataCatalog(
            self, name, catalog_type, description, parameters, tags
        )
        self.data_catalogs[name] = data_catalog
        return data_catalog
|
|
|
|
|
2019-10-22 21:37:29 +00:00
|
|
|
|
2021-12-24 21:02:45 +00:00
|
|
|
# Module-level registry of Athena backends, built from the AthenaBackend class
# and the service name "athena". Presumably BackendDict maps region names to
# AthenaBackend instances — confirm against moto.core.utils.BackendDict.
athena_backends = BackendDict(AthenaBackend, "athena")
|