Techdebt: MyPy Athena (#5578)

This commit is contained in:
Bert Blommers 2022-10-18 12:57:37 +00:00 committed by GitHub
parent fd77cd4dc2
commit 6a07abbb30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 99 additions and 49 deletions

View File

@ -3,7 +3,7 @@ from moto.core.exceptions import JsonRESTError
class AthenaClientError(JsonRESTError): class AthenaClientError(JsonRESTError):
def __init__(self, code, message): def __init__(self, code: str, message: str):
super().__init__(error_type="InvalidRequestException", message=message) super().__init__(error_type="InvalidRequestException", message=message)
self.description = json.dumps( self.description = json.dumps(
{ {

View File

@ -3,25 +3,32 @@ import time
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import BackendDict from moto.core.utils import BackendDict
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
from typing import Any, Dict, List, Optional
class TaggableResourceMixin(object): class TaggableResourceMixin(object):
# This mixing was copied from Redshift when initially implementing # This mixing was copied from Redshift when initially implementing
# Athena. TBD if it's worth the overhead. # Athena. TBD if it's worth the overhead.
def __init__(self, account_id, region_name, resource_name, tags): def __init__(
self,
account_id: str,
region_name: str,
resource_name: str,
tags: List[Dict[str, str]],
):
self.region = region_name self.region = region_name
self.resource_name = resource_name self.resource_name = resource_name
self.tags = tags or [] self.tags = tags or []
self.arn = f"arn:aws:athena:{region_name}:{account_id}:{resource_name}" self.arn = f"arn:aws:athena:{region_name}:{account_id}:{resource_name}"
def create_tags(self, tags): def create_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]:
new_keys = [tag_set["Key"] for tag_set in tags] new_keys = [tag_set["Key"] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
self.tags.extend(tags) self.tags.extend(tags)
return self.tags return self.tags
def delete_tags(self, tag_keys): def delete_tags(self, tag_keys: List[str]) -> List[Dict[str, str]]:
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
return self.tags return self.tags
@ -31,7 +38,14 @@ class WorkGroup(TaggableResourceMixin, BaseModel):
resource_type = "workgroup" resource_type = "workgroup"
state = "ENABLED" state = "ENABLED"
def __init__(self, athena_backend, name, configuration, description, tags): def __init__(
self,
athena_backend: "AthenaBackend",
name: str,
configuration: str,
description: str,
tags: List[Dict[str, str]],
):
self.region_name = athena_backend.region_name self.region_name = athena_backend.region_name
super().__init__( super().__init__(
athena_backend.account_id, athena_backend.account_id,
@ -47,7 +61,13 @@ class WorkGroup(TaggableResourceMixin, BaseModel):
class DataCatalog(TaggableResourceMixin, BaseModel): class DataCatalog(TaggableResourceMixin, BaseModel):
def __init__( def __init__(
self, athena_backend, name, catalog_type, description, parameters, tags self,
athena_backend: "AthenaBackend",
name: str,
catalog_type: str,
description: str,
parameters: str,
tags: List[Dict[str, str]],
): ):
self.region_name = athena_backend.region_name self.region_name = athena_backend.region_name
super().__init__( super().__init__(
@ -64,7 +84,7 @@ class DataCatalog(TaggableResourceMixin, BaseModel):
class Execution(BaseModel): class Execution(BaseModel):
def __init__(self, query, context, config, workgroup): def __init__(self, query: str, context: str, config: str, workgroup: WorkGroup):
self.id = str(mock_random.uuid4()) self.id = str(mock_random.uuid4())
self.query = query self.query = query
self.context = context self.context = context
@ -75,7 +95,14 @@ class Execution(BaseModel):
class NamedQuery(BaseModel): class NamedQuery(BaseModel):
def __init__(self, name, description, database, query_string, workgroup): def __init__(
self,
name: str,
description: str,
database: str,
query_string: str,
workgroup: WorkGroup,
):
self.id = str(mock_random.uuid4()) self.id = str(mock_random.uuid4())
self.name = name self.name = name
self.description = description self.description = description
@ -85,30 +112,36 @@ class NamedQuery(BaseModel):
class AthenaBackend(BaseBackend): class AthenaBackend(BaseBackend):
region_name = None def __init__(self, region_name: str, account_id: str):
def __init__(self, region_name, account_id):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.work_groups = {} self.work_groups: Dict[str, WorkGroup] = {}
self.executions = {} self.executions: Dict[str, Execution] = {}
self.named_queries = {} self.named_queries: Dict[str, NamedQuery] = {}
self.data_catalogs = {} self.data_catalogs: Dict[str, DataCatalog] = {}
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "athena" service_region, zones, "athena"
) )
def create_work_group(self, name, configuration, description, tags): def create_work_group(
self,
name: str,
configuration: str,
description: str,
tags: List[Dict[str, str]],
) -> Optional[WorkGroup]:
if name in self.work_groups: if name in self.work_groups:
return None return None
work_group = WorkGroup(self, name, configuration, description, tags) work_group = WorkGroup(self, name, configuration, description, tags)
self.work_groups[name] = work_group self.work_groups[name] = work_group
return work_group return work_group
def list_work_groups(self): def list_work_groups(self) -> List[Dict[str, Any]]:
return [ return [
{ {
"Name": wg.name, "Name": wg.name,
@ -119,7 +152,7 @@ class AthenaBackend(BaseBackend):
for wg in self.work_groups.values() for wg in self.work_groups.values()
] ]
def get_work_group(self, name): def get_work_group(self, name: str) -> Optional[Dict[str, Any]]:
if name not in self.work_groups: if name not in self.work_groups:
return None return None
wg = self.work_groups[name] wg = self.work_groups[name]
@ -131,21 +164,30 @@ class AthenaBackend(BaseBackend):
"CreationTime": time.time(), "CreationTime": time.time(),
} }
def start_query_execution(self, query, context, config, workgroup): def start_query_execution(
self, query: str, context: str, config: str, workgroup: WorkGroup
) -> str:
execution = Execution( execution = Execution(
query=query, context=context, config=config, workgroup=workgroup query=query, context=context, config=config, workgroup=workgroup
) )
self.executions[execution.id] = execution self.executions[execution.id] = execution
return execution.id return execution.id
def get_execution(self, exec_id): def get_execution(self, exec_id: str) -> Execution:
return self.executions[exec_id] return self.executions[exec_id]
def stop_query_execution(self, exec_id): def stop_query_execution(self, exec_id: str) -> None:
execution = self.executions[exec_id] execution = self.executions[exec_id]
execution.status = "CANCELLED" execution.status = "CANCELLED"
def create_named_query(self, name, description, database, query_string, workgroup): def create_named_query(
self,
name: str,
description: str,
database: str,
query_string: str,
workgroup: WorkGroup,
) -> str:
nq = NamedQuery( nq = NamedQuery(
name=name, name=name,
description=description, description=description,
@ -156,16 +198,16 @@ class AthenaBackend(BaseBackend):
self.named_queries[nq.id] = nq self.named_queries[nq.id] = nq
return nq.id return nq.id
def get_named_query(self, query_id): def get_named_query(self, query_id: str) -> Optional[NamedQuery]:
return self.named_queries[query_id] if query_id in self.named_queries else None return self.named_queries[query_id] if query_id in self.named_queries else None
def list_data_catalogs(self): def list_data_catalogs(self) -> List[Dict[str, str]]:
return [ return [
{"CatalogName": dc.name, "Type": dc.type} {"CatalogName": dc.name, "Type": dc.type}
for dc in self.data_catalogs.values() for dc in self.data_catalogs.values()
] ]
def get_data_catalog(self, name): def get_data_catalog(self, name: str) -> Optional[Dict[str, str]]:
if name not in self.data_catalogs: if name not in self.data_catalogs:
return None return None
dc = self.data_catalogs[name] dc = self.data_catalogs[name]
@ -176,7 +218,14 @@ class AthenaBackend(BaseBackend):
"Parameters": dc.parameters, "Parameters": dc.parameters,
} }
def create_data_catalog(self, name, catalog_type, description, parameters, tags): def create_data_catalog(
self,
name: str,
catalog_type: str,
description: str,
parameters: str,
tags: List[Dict[str, str]],
) -> Optional[DataCatalog]:
if name in self.data_catalogs: if name in self.data_catalogs:
return None return None
data_catalog = DataCatalog( data_catalog = DataCatalog(

View File

@ -1,18 +1,19 @@
import json import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import athena_backends from .models import athena_backends, AthenaBackend
from typing import Dict, Tuple, Union
class AthenaResponse(BaseResponse): class AthenaResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="athena") super().__init__(service_name="athena")
@property @property
def athena_backend(self): def athena_backend(self) -> AthenaBackend:
return athena_backends[self.current_account][self.region] return athena_backends[self.current_account][self.region]
def create_work_group(self): def create_work_group(self) -> Union[Tuple[str, Dict[str, int]], str]:
name = self._get_param("Name") name = self._get_param("Name")
description = self._get_param("Description") description = self._get_param("Description")
configuration = self._get_param("Configuration") configuration = self._get_param("Configuration")
@ -32,14 +33,14 @@ class AthenaResponse(BaseResponse):
} }
) )
def list_work_groups(self): def list_work_groups(self) -> str:
return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()}) return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()})
def get_work_group(self): def get_work_group(self) -> str:
name = self._get_param("WorkGroup") name = self._get_param("WorkGroup")
return json.dumps({"WorkGroup": self.athena_backend.get_work_group(name)}) return json.dumps({"WorkGroup": self.athena_backend.get_work_group(name)})
def start_query_execution(self): def start_query_execution(self) -> Union[Tuple[str, Dict[str, int]], str]:
query = self._get_param("QueryString") query = self._get_param("QueryString")
context = self._get_param("QueryExecutionContext") context = self._get_param("QueryExecutionContext")
config = self._get_param("ResultConfiguration") config = self._get_param("ResultConfiguration")
@ -51,7 +52,7 @@ class AthenaResponse(BaseResponse):
) )
return json.dumps({"QueryExecutionId": q_exec_id}) return json.dumps({"QueryExecutionId": q_exec_id})
def get_query_execution(self): def get_query_execution(self) -> str:
exec_id = self._get_param("QueryExecutionId") exec_id = self._get_param("QueryExecutionId")
execution = self.athena_backend.get_execution(exec_id) execution = self.athena_backend.get_execution(exec_id)
result = { result = {
@ -78,18 +79,18 @@ class AthenaResponse(BaseResponse):
} }
return json.dumps(result) return json.dumps(result)
def stop_query_execution(self): def stop_query_execution(self) -> str:
exec_id = self._get_param("QueryExecutionId") exec_id = self._get_param("QueryExecutionId")
self.athena_backend.stop_query_execution(exec_id) self.athena_backend.stop_query_execution(exec_id)
return json.dumps({}) return json.dumps({})
def error(self, msg, status): def error(self, msg: str, status: int) -> Tuple[str, Dict[str, int]]:
return ( return (
json.dumps({"__type": "InvalidRequestException", "Message": msg}), json.dumps({"__type": "InvalidRequestException", "Message": msg}),
dict(status=status), dict(status=status),
) )
def create_named_query(self): def create_named_query(self) -> Union[Tuple[str, Dict[str, int]], str]:
name = self._get_param("Name") name = self._get_param("Name")
description = self._get_param("Description") description = self._get_param("Description")
database = self._get_param("Database") database = self._get_param("Database")
@ -102,32 +103,32 @@ class AthenaResponse(BaseResponse):
) )
return json.dumps({"NamedQueryId": query_id}) return json.dumps({"NamedQueryId": query_id})
def get_named_query(self): def get_named_query(self) -> str:
query_id = self._get_param("NamedQueryId") query_id = self._get_param("NamedQueryId")
nq = self.athena_backend.get_named_query(query_id) nq = self.athena_backend.get_named_query(query_id)
return json.dumps( return json.dumps(
{ {
"NamedQuery": { "NamedQuery": {
"Name": nq.name, "Name": nq.name, # type: ignore[union-attr]
"Description": nq.description, "Description": nq.description, # type: ignore[union-attr]
"Database": nq.database, "Database": nq.database, # type: ignore[union-attr]
"QueryString": nq.query_string, "QueryString": nq.query_string, # type: ignore[union-attr]
"NamedQueryId": nq.id, "NamedQueryId": nq.id, # type: ignore[union-attr]
"WorkGroup": nq.workgroup, "WorkGroup": nq.workgroup, # type: ignore[union-attr]
} }
} }
) )
def list_data_catalogs(self): def list_data_catalogs(self) -> str:
return json.dumps( return json.dumps(
{"DataCatalogsSummary": self.athena_backend.list_data_catalogs()} {"DataCatalogsSummary": self.athena_backend.list_data_catalogs()}
) )
def get_data_catalog(self): def get_data_catalog(self) -> str:
name = self._get_param("Name") name = self._get_param("Name")
return json.dumps({"DataCatalog": self.athena_backend.get_data_catalog(name)}) return json.dumps({"DataCatalog": self.athena_backend.get_data_catalog(name)})
def create_data_catalog(self): def create_data_catalog(self) -> Union[Tuple[str, Dict[str, int]], str]:
name = self._get_param("Name") name = self._get_param("Name")
catalog_type = self._get_param("Type") catalog_type = self._get_param("Type")
description = self._get_param("Description") description = self._get_param("Description")

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/acm,moto/amp,moto/apigateway,moto/apigatewayv2,moto/applicationautoscaling/,moto/appsync files= moto/acm,moto/amp,moto/apigateway,moto/apigatewayv2,moto/applicationautoscaling/,moto/appsync,moto/athena
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract