Techdebt: MyPy Athena (#5578)
This commit is contained in:
parent
fd77cd4dc2
commit
6a07abbb30
@ -3,7 +3,7 @@ from moto.core.exceptions import JsonRESTError
|
|||||||
|
|
||||||
|
|
||||||
class AthenaClientError(JsonRESTError):
|
class AthenaClientError(JsonRESTError):
|
||||||
def __init__(self, code, message):
|
def __init__(self, code: str, message: str):
|
||||||
super().__init__(error_type="InvalidRequestException", message=message)
|
super().__init__(error_type="InvalidRequestException", message=message)
|
||||||
self.description = json.dumps(
|
self.description = json.dumps(
|
||||||
{
|
{
|
||||||
|
@ -3,25 +3,32 @@ import time
|
|||||||
from moto.core import BaseBackend, BaseModel
|
from moto.core import BaseBackend, BaseModel
|
||||||
from moto.core.utils import BackendDict
|
from moto.core.utils import BackendDict
|
||||||
from moto.moto_api._internal import mock_random
|
from moto.moto_api._internal import mock_random
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
class TaggableResourceMixin(object):
|
class TaggableResourceMixin(object):
|
||||||
# This mixing was copied from Redshift when initially implementing
|
# This mixing was copied from Redshift when initially implementing
|
||||||
# Athena. TBD if it's worth the overhead.
|
# Athena. TBD if it's worth the overhead.
|
||||||
|
|
||||||
def __init__(self, account_id, region_name, resource_name, tags):
|
def __init__(
|
||||||
|
self,
|
||||||
|
account_id: str,
|
||||||
|
region_name: str,
|
||||||
|
resource_name: str,
|
||||||
|
tags: List[Dict[str, str]],
|
||||||
|
):
|
||||||
self.region = region_name
|
self.region = region_name
|
||||||
self.resource_name = resource_name
|
self.resource_name = resource_name
|
||||||
self.tags = tags or []
|
self.tags = tags or []
|
||||||
self.arn = f"arn:aws:athena:{region_name}:{account_id}:{resource_name}"
|
self.arn = f"arn:aws:athena:{region_name}:{account_id}:{resource_name}"
|
||||||
|
|
||||||
def create_tags(self, tags):
|
def create_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
||||||
new_keys = [tag_set["Key"] for tag_set in tags]
|
new_keys = [tag_set["Key"] for tag_set in tags]
|
||||||
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
|
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
|
||||||
self.tags.extend(tags)
|
self.tags.extend(tags)
|
||||||
return self.tags
|
return self.tags
|
||||||
|
|
||||||
def delete_tags(self, tag_keys):
|
def delete_tags(self, tag_keys: List[str]) -> List[Dict[str, str]]:
|
||||||
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
|
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
|
||||||
return self.tags
|
return self.tags
|
||||||
|
|
||||||
@ -31,7 +38,14 @@ class WorkGroup(TaggableResourceMixin, BaseModel):
|
|||||||
resource_type = "workgroup"
|
resource_type = "workgroup"
|
||||||
state = "ENABLED"
|
state = "ENABLED"
|
||||||
|
|
||||||
def __init__(self, athena_backend, name, configuration, description, tags):
|
def __init__(
|
||||||
|
self,
|
||||||
|
athena_backend: "AthenaBackend",
|
||||||
|
name: str,
|
||||||
|
configuration: str,
|
||||||
|
description: str,
|
||||||
|
tags: List[Dict[str, str]],
|
||||||
|
):
|
||||||
self.region_name = athena_backend.region_name
|
self.region_name = athena_backend.region_name
|
||||||
super().__init__(
|
super().__init__(
|
||||||
athena_backend.account_id,
|
athena_backend.account_id,
|
||||||
@ -47,7 +61,13 @@ class WorkGroup(TaggableResourceMixin, BaseModel):
|
|||||||
|
|
||||||
class DataCatalog(TaggableResourceMixin, BaseModel):
|
class DataCatalog(TaggableResourceMixin, BaseModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, athena_backend, name, catalog_type, description, parameters, tags
|
self,
|
||||||
|
athena_backend: "AthenaBackend",
|
||||||
|
name: str,
|
||||||
|
catalog_type: str,
|
||||||
|
description: str,
|
||||||
|
parameters: str,
|
||||||
|
tags: List[Dict[str, str]],
|
||||||
):
|
):
|
||||||
self.region_name = athena_backend.region_name
|
self.region_name = athena_backend.region_name
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -64,7 +84,7 @@ class DataCatalog(TaggableResourceMixin, BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Execution(BaseModel):
|
class Execution(BaseModel):
|
||||||
def __init__(self, query, context, config, workgroup):
|
def __init__(self, query: str, context: str, config: str, workgroup: WorkGroup):
|
||||||
self.id = str(mock_random.uuid4())
|
self.id = str(mock_random.uuid4())
|
||||||
self.query = query
|
self.query = query
|
||||||
self.context = context
|
self.context = context
|
||||||
@ -75,7 +95,14 @@ class Execution(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class NamedQuery(BaseModel):
|
class NamedQuery(BaseModel):
|
||||||
def __init__(self, name, description, database, query_string, workgroup):
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
database: str,
|
||||||
|
query_string: str,
|
||||||
|
workgroup: WorkGroup,
|
||||||
|
):
|
||||||
self.id = str(mock_random.uuid4())
|
self.id = str(mock_random.uuid4())
|
||||||
self.name = name
|
self.name = name
|
||||||
self.description = description
|
self.description = description
|
||||||
@ -85,30 +112,36 @@ class NamedQuery(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AthenaBackend(BaseBackend):
|
class AthenaBackend(BaseBackend):
|
||||||
region_name = None
|
def __init__(self, region_name: str, account_id: str):
|
||||||
|
|
||||||
def __init__(self, region_name, account_id):
|
|
||||||
super().__init__(region_name, account_id)
|
super().__init__(region_name, account_id)
|
||||||
self.work_groups = {}
|
self.work_groups: Dict[str, WorkGroup] = {}
|
||||||
self.executions = {}
|
self.executions: Dict[str, Execution] = {}
|
||||||
self.named_queries = {}
|
self.named_queries: Dict[str, NamedQuery] = {}
|
||||||
self.data_catalogs = {}
|
self.data_catalogs: Dict[str, DataCatalog] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_vpc_endpoint_service(service_region, zones):
|
def default_vpc_endpoint_service(
|
||||||
|
service_region: str, zones: List[str]
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
"""Default VPC endpoint service."""
|
"""Default VPC endpoint service."""
|
||||||
return BaseBackend.default_vpc_endpoint_service_factory(
|
return BaseBackend.default_vpc_endpoint_service_factory(
|
||||||
service_region, zones, "athena"
|
service_region, zones, "athena"
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_work_group(self, name, configuration, description, tags):
|
def create_work_group(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
configuration: str,
|
||||||
|
description: str,
|
||||||
|
tags: List[Dict[str, str]],
|
||||||
|
) -> Optional[WorkGroup]:
|
||||||
if name in self.work_groups:
|
if name in self.work_groups:
|
||||||
return None
|
return None
|
||||||
work_group = WorkGroup(self, name, configuration, description, tags)
|
work_group = WorkGroup(self, name, configuration, description, tags)
|
||||||
self.work_groups[name] = work_group
|
self.work_groups[name] = work_group
|
||||||
return work_group
|
return work_group
|
||||||
|
|
||||||
def list_work_groups(self):
|
def list_work_groups(self) -> List[Dict[str, Any]]:
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"Name": wg.name,
|
"Name": wg.name,
|
||||||
@ -119,7 +152,7 @@ class AthenaBackend(BaseBackend):
|
|||||||
for wg in self.work_groups.values()
|
for wg in self.work_groups.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_work_group(self, name):
|
def get_work_group(self, name: str) -> Optional[Dict[str, Any]]:
|
||||||
if name not in self.work_groups:
|
if name not in self.work_groups:
|
||||||
return None
|
return None
|
||||||
wg = self.work_groups[name]
|
wg = self.work_groups[name]
|
||||||
@ -131,21 +164,30 @@ class AthenaBackend(BaseBackend):
|
|||||||
"CreationTime": time.time(),
|
"CreationTime": time.time(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def start_query_execution(self, query, context, config, workgroup):
|
def start_query_execution(
|
||||||
|
self, query: str, context: str, config: str, workgroup: WorkGroup
|
||||||
|
) -> str:
|
||||||
execution = Execution(
|
execution = Execution(
|
||||||
query=query, context=context, config=config, workgroup=workgroup
|
query=query, context=context, config=config, workgroup=workgroup
|
||||||
)
|
)
|
||||||
self.executions[execution.id] = execution
|
self.executions[execution.id] = execution
|
||||||
return execution.id
|
return execution.id
|
||||||
|
|
||||||
def get_execution(self, exec_id):
|
def get_execution(self, exec_id: str) -> Execution:
|
||||||
return self.executions[exec_id]
|
return self.executions[exec_id]
|
||||||
|
|
||||||
def stop_query_execution(self, exec_id):
|
def stop_query_execution(self, exec_id: str) -> None:
|
||||||
execution = self.executions[exec_id]
|
execution = self.executions[exec_id]
|
||||||
execution.status = "CANCELLED"
|
execution.status = "CANCELLED"
|
||||||
|
|
||||||
def create_named_query(self, name, description, database, query_string, workgroup):
|
def create_named_query(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
database: str,
|
||||||
|
query_string: str,
|
||||||
|
workgroup: WorkGroup,
|
||||||
|
) -> str:
|
||||||
nq = NamedQuery(
|
nq = NamedQuery(
|
||||||
name=name,
|
name=name,
|
||||||
description=description,
|
description=description,
|
||||||
@ -156,16 +198,16 @@ class AthenaBackend(BaseBackend):
|
|||||||
self.named_queries[nq.id] = nq
|
self.named_queries[nq.id] = nq
|
||||||
return nq.id
|
return nq.id
|
||||||
|
|
||||||
def get_named_query(self, query_id):
|
def get_named_query(self, query_id: str) -> Optional[NamedQuery]:
|
||||||
return self.named_queries[query_id] if query_id in self.named_queries else None
|
return self.named_queries[query_id] if query_id in self.named_queries else None
|
||||||
|
|
||||||
def list_data_catalogs(self):
|
def list_data_catalogs(self) -> List[Dict[str, str]]:
|
||||||
return [
|
return [
|
||||||
{"CatalogName": dc.name, "Type": dc.type}
|
{"CatalogName": dc.name, "Type": dc.type}
|
||||||
for dc in self.data_catalogs.values()
|
for dc in self.data_catalogs.values()
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_data_catalog(self, name):
|
def get_data_catalog(self, name: str) -> Optional[Dict[str, str]]:
|
||||||
if name not in self.data_catalogs:
|
if name not in self.data_catalogs:
|
||||||
return None
|
return None
|
||||||
dc = self.data_catalogs[name]
|
dc = self.data_catalogs[name]
|
||||||
@ -176,7 +218,14 @@ class AthenaBackend(BaseBackend):
|
|||||||
"Parameters": dc.parameters,
|
"Parameters": dc.parameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
def create_data_catalog(self, name, catalog_type, description, parameters, tags):
|
def create_data_catalog(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
catalog_type: str,
|
||||||
|
description: str,
|
||||||
|
parameters: str,
|
||||||
|
tags: List[Dict[str, str]],
|
||||||
|
) -> Optional[DataCatalog]:
|
||||||
if name in self.data_catalogs:
|
if name in self.data_catalogs:
|
||||||
return None
|
return None
|
||||||
data_catalog = DataCatalog(
|
data_catalog = DataCatalog(
|
||||||
|
@ -1,18 +1,19 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from moto.core.responses import BaseResponse
|
from moto.core.responses import BaseResponse
|
||||||
from .models import athena_backends
|
from .models import athena_backends, AthenaBackend
|
||||||
|
from typing import Dict, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
class AthenaResponse(BaseResponse):
|
class AthenaResponse(BaseResponse):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(service_name="athena")
|
super().__init__(service_name="athena")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def athena_backend(self):
|
def athena_backend(self) -> AthenaBackend:
|
||||||
return athena_backends[self.current_account][self.region]
|
return athena_backends[self.current_account][self.region]
|
||||||
|
|
||||||
def create_work_group(self):
|
def create_work_group(self) -> Union[Tuple[str, Dict[str, int]], str]:
|
||||||
name = self._get_param("Name")
|
name = self._get_param("Name")
|
||||||
description = self._get_param("Description")
|
description = self._get_param("Description")
|
||||||
configuration = self._get_param("Configuration")
|
configuration = self._get_param("Configuration")
|
||||||
@ -32,14 +33,14 @@ class AthenaResponse(BaseResponse):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def list_work_groups(self):
|
def list_work_groups(self) -> str:
|
||||||
return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()})
|
return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()})
|
||||||
|
|
||||||
def get_work_group(self):
|
def get_work_group(self) -> str:
|
||||||
name = self._get_param("WorkGroup")
|
name = self._get_param("WorkGroup")
|
||||||
return json.dumps({"WorkGroup": self.athena_backend.get_work_group(name)})
|
return json.dumps({"WorkGroup": self.athena_backend.get_work_group(name)})
|
||||||
|
|
||||||
def start_query_execution(self):
|
def start_query_execution(self) -> Union[Tuple[str, Dict[str, int]], str]:
|
||||||
query = self._get_param("QueryString")
|
query = self._get_param("QueryString")
|
||||||
context = self._get_param("QueryExecutionContext")
|
context = self._get_param("QueryExecutionContext")
|
||||||
config = self._get_param("ResultConfiguration")
|
config = self._get_param("ResultConfiguration")
|
||||||
@ -51,7 +52,7 @@ class AthenaResponse(BaseResponse):
|
|||||||
)
|
)
|
||||||
return json.dumps({"QueryExecutionId": q_exec_id})
|
return json.dumps({"QueryExecutionId": q_exec_id})
|
||||||
|
|
||||||
def get_query_execution(self):
|
def get_query_execution(self) -> str:
|
||||||
exec_id = self._get_param("QueryExecutionId")
|
exec_id = self._get_param("QueryExecutionId")
|
||||||
execution = self.athena_backend.get_execution(exec_id)
|
execution = self.athena_backend.get_execution(exec_id)
|
||||||
result = {
|
result = {
|
||||||
@ -78,18 +79,18 @@ class AthenaResponse(BaseResponse):
|
|||||||
}
|
}
|
||||||
return json.dumps(result)
|
return json.dumps(result)
|
||||||
|
|
||||||
def stop_query_execution(self):
|
def stop_query_execution(self) -> str:
|
||||||
exec_id = self._get_param("QueryExecutionId")
|
exec_id = self._get_param("QueryExecutionId")
|
||||||
self.athena_backend.stop_query_execution(exec_id)
|
self.athena_backend.stop_query_execution(exec_id)
|
||||||
return json.dumps({})
|
return json.dumps({})
|
||||||
|
|
||||||
def error(self, msg, status):
|
def error(self, msg: str, status: int) -> Tuple[str, Dict[str, int]]:
|
||||||
return (
|
return (
|
||||||
json.dumps({"__type": "InvalidRequestException", "Message": msg}),
|
json.dumps({"__type": "InvalidRequestException", "Message": msg}),
|
||||||
dict(status=status),
|
dict(status=status),
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_named_query(self):
|
def create_named_query(self) -> Union[Tuple[str, Dict[str, int]], str]:
|
||||||
name = self._get_param("Name")
|
name = self._get_param("Name")
|
||||||
description = self._get_param("Description")
|
description = self._get_param("Description")
|
||||||
database = self._get_param("Database")
|
database = self._get_param("Database")
|
||||||
@ -102,32 +103,32 @@ class AthenaResponse(BaseResponse):
|
|||||||
)
|
)
|
||||||
return json.dumps({"NamedQueryId": query_id})
|
return json.dumps({"NamedQueryId": query_id})
|
||||||
|
|
||||||
def get_named_query(self):
|
def get_named_query(self) -> str:
|
||||||
query_id = self._get_param("NamedQueryId")
|
query_id = self._get_param("NamedQueryId")
|
||||||
nq = self.athena_backend.get_named_query(query_id)
|
nq = self.athena_backend.get_named_query(query_id)
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
{
|
{
|
||||||
"NamedQuery": {
|
"NamedQuery": {
|
||||||
"Name": nq.name,
|
"Name": nq.name, # type: ignore[union-attr]
|
||||||
"Description": nq.description,
|
"Description": nq.description, # type: ignore[union-attr]
|
||||||
"Database": nq.database,
|
"Database": nq.database, # type: ignore[union-attr]
|
||||||
"QueryString": nq.query_string,
|
"QueryString": nq.query_string, # type: ignore[union-attr]
|
||||||
"NamedQueryId": nq.id,
|
"NamedQueryId": nq.id, # type: ignore[union-attr]
|
||||||
"WorkGroup": nq.workgroup,
|
"WorkGroup": nq.workgroup, # type: ignore[union-attr]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def list_data_catalogs(self):
|
def list_data_catalogs(self) -> str:
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
{"DataCatalogsSummary": self.athena_backend.list_data_catalogs()}
|
{"DataCatalogsSummary": self.athena_backend.list_data_catalogs()}
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_data_catalog(self):
|
def get_data_catalog(self) -> str:
|
||||||
name = self._get_param("Name")
|
name = self._get_param("Name")
|
||||||
return json.dumps({"DataCatalog": self.athena_backend.get_data_catalog(name)})
|
return json.dumps({"DataCatalog": self.athena_backend.get_data_catalog(name)})
|
||||||
|
|
||||||
def create_data_catalog(self):
|
def create_data_catalog(self) -> Union[Tuple[str, Dict[str, int]], str]:
|
||||||
name = self._get_param("Name")
|
name = self._get_param("Name")
|
||||||
catalog_type = self._get_param("Type")
|
catalog_type = self._get_param("Type")
|
||||||
description = self._get_param("Description")
|
description = self._get_param("Description")
|
||||||
|
@ -18,7 +18,7 @@ disable = W,C,R,E
|
|||||||
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
|
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
|
||||||
|
|
||||||
[mypy]
|
[mypy]
|
||||||
files= moto/acm,moto/amp,moto/apigateway,moto/apigatewayv2,moto/applicationautoscaling/,moto/appsync
|
files= moto/acm,moto/amp,moto/apigateway,moto/apigatewayv2,moto/applicationautoscaling/,moto/appsync,moto/athena
|
||||||
show_column_numbers=True
|
show_column_numbers=True
|
||||||
show_error_codes = True
|
show_error_codes = True
|
||||||
disable_error_code=abstract
|
disable_error_code=abstract
|
||||||
|
Loading…
Reference in New Issue
Block a user