From 6a07abbb30e3d58c860a60160f4b0a8d7ba8fa8e Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Tue, 18 Oct 2022 12:57:37 +0000 Subject: [PATCH] Techdebt: MyPy Athena (#5578) --- moto/athena/exceptions.py | 2 +- moto/athena/models.py | 101 ++++++++++++++++++++++++++++---------- moto/athena/responses.py | 43 ++++++++-------- setup.cfg | 2 +- 4 files changed, 99 insertions(+), 49 deletions(-) diff --git a/moto/athena/exceptions.py b/moto/athena/exceptions.py index c0a5260e2..bcbc48c3b 100644 --- a/moto/athena/exceptions.py +++ b/moto/athena/exceptions.py @@ -3,7 +3,7 @@ from moto.core.exceptions import JsonRESTError class AthenaClientError(JsonRESTError): - def __init__(self, code, message): + def __init__(self, code: str, message: str): super().__init__(error_type="InvalidRequestException", message=message) self.description = json.dumps( { diff --git a/moto/athena/models.py b/moto/athena/models.py index beb7923e6..0c872d6af 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -3,25 +3,32 @@ import time from moto.core import BaseBackend, BaseModel from moto.core.utils import BackendDict from moto.moto_api._internal import mock_random +from typing import Any, Dict, List, Optional class TaggableResourceMixin(object): # This mixing was copied from Redshift when initially implementing # Athena. TBD if it's worth the overhead. - def __init__(self, account_id, region_name, resource_name, tags): + def __init__( + self, + account_id: str, + region_name: str, + resource_name: str, + tags: List[Dict[str, str]], + ): self.region = region_name self.resource_name = resource_name self.tags = tags or [] self.arn = f"arn:aws:athena:{region_name}:{account_id}:{resource_name}" - def create_tags(self, tags): + def create_tags(self, tags: List[Dict[str, str]]) -> List[Dict[str, str]]: new_keys = [tag_set["Key"] for tag_set in tags] self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags - def delete_tags(self, tag_keys): + def delete_tags(self, tag_keys: List[str]) -> List[Dict[str, str]]: self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] return self.tags @@ -31,7 +38,14 @@ class WorkGroup(TaggableResourceMixin, BaseModel): resource_type = "workgroup" state = "ENABLED" - def __init__(self, athena_backend, name, configuration, description, tags): + def __init__( + self, + athena_backend: "AthenaBackend", + name: str, + configuration: str, + description: str, + tags: List[Dict[str, str]], + ): self.region_name = athena_backend.region_name super().__init__( athena_backend.account_id, @@ -47,7 +61,13 @@ class WorkGroup(TaggableResourceMixin, BaseModel): class DataCatalog(TaggableResourceMixin, BaseModel): def __init__( - self, athena_backend, name, catalog_type, description, parameters, tags + self, + athena_backend: "AthenaBackend", + name: str, + catalog_type: str, + description: str, + parameters: str, + tags: List[Dict[str, str]], ): self.region_name = athena_backend.region_name super().__init__( @@ -64,7 +84,7 @@ class DataCatalog(TaggableResourceMixin, BaseModel): class Execution(BaseModel): - def __init__(self, query, context, config, workgroup): + def __init__(self, query: str, context: str, config: str, workgroup: WorkGroup): self.id = str(mock_random.uuid4()) self.query = query self.context = context @@ -75,7 +95,14 @@ class Execution(BaseModel): class NamedQuery(BaseModel): - def __init__(self, name, description, database, query_string, workgroup): + def __init__( + self, + name: str, + description: str, + database: str, + query_string: str, + workgroup: WorkGroup, + ): self.id = str(mock_random.uuid4()) self.name = name self.description = description @@ -85,30 +112,36 @@ class NamedQuery(BaseModel): class AthenaBackend(BaseBackend): - region_name = None - - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.work_groups = {} - self.executions = {} - self.named_queries = {} - self.data_catalogs = {} + self.work_groups: Dict[str, WorkGroup] = {} + self.executions: Dict[str, Execution] = {} + self.named_queries: Dict[str, NamedQuery] = {} + self.data_catalogs: Dict[str, DataCatalog] = {} @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "athena" ) - def create_work_group(self, name, configuration, description, tags): + def create_work_group( + self, + name: str, + configuration: str, + description: str, + tags: List[Dict[str, str]], + ) -> Optional[WorkGroup]: if name in self.work_groups: return None work_group = WorkGroup(self, name, configuration, description, tags) self.work_groups[name] = work_group return work_group - def list_work_groups(self): + def list_work_groups(self) -> List[Dict[str, Any]]: return [ { "Name": wg.name, @@ -119,7 +152,7 @@ class AthenaBackend(BaseBackend): for wg in self.work_groups.values() ] - def get_work_group(self, name): + def get_work_group(self, name: str) -> Optional[Dict[str, Any]]: if name not in self.work_groups: return None wg = self.work_groups[name] @@ -131,21 +164,30 @@ class AthenaBackend(BaseBackend): "CreationTime": time.time(), } - def start_query_execution(self, query, context, config, workgroup): + def start_query_execution( + self, query: str, context: str, config: str, workgroup: WorkGroup + ) -> str: execution = Execution( query=query, context=context, config=config, workgroup=workgroup ) self.executions[execution.id] = execution return execution.id - def get_execution(self, exec_id): + def get_execution(self, exec_id: str) -> Execution: return self.executions[exec_id] - def stop_query_execution(self, exec_id): + def stop_query_execution(self, exec_id: str) -> None: execution = self.executions[exec_id] execution.status = "CANCELLED" - def create_named_query(self, name, description, database, query_string, workgroup): + def create_named_query( + self, + name: str, + description: str, + database: str, + query_string: str, + workgroup: WorkGroup, + ) -> str: nq = NamedQuery( name=name, description=description, @@ -156,16 +198,16 @@ class AthenaBackend(BaseBackend): self.named_queries[nq.id] = nq return nq.id - def get_named_query(self, query_id): + def get_named_query(self, query_id: str) -> Optional[NamedQuery]: return self.named_queries[query_id] if query_id in self.named_queries else None - def list_data_catalogs(self): + def list_data_catalogs(self) -> List[Dict[str, str]]: return [ {"CatalogName": dc.name, "Type": dc.type} for dc in self.data_catalogs.values() ] - def get_data_catalog(self, name): + def get_data_catalog(self, name: str) -> Optional[Dict[str, str]]: if name not in self.data_catalogs: return None dc = self.data_catalogs[name] @@ -176,7 +218,14 @@ class AthenaBackend(BaseBackend): "Parameters": dc.parameters, } - def create_data_catalog(self, name, catalog_type, description, parameters, tags): + def create_data_catalog( + self, + name: str, + catalog_type: str, + description: str, + parameters: str, + tags: List[Dict[str, str]], + ) -> Optional[DataCatalog]: if name in self.data_catalogs: return None data_catalog = DataCatalog( diff --git a/moto/athena/responses.py b/moto/athena/responses.py index b47f5ed7b..09000727b 100644 --- a/moto/athena/responses.py +++ b/moto/athena/responses.py @@ -1,18 +1,19 @@ import json from moto.core.responses import BaseResponse -from .models import athena_backends +from .models import athena_backends, AthenaBackend +from typing import Dict, Tuple, Union class AthenaResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="athena") @property - def athena_backend(self): + def athena_backend(self) -> AthenaBackend: return athena_backends[self.current_account][self.region] - def create_work_group(self): + def create_work_group(self) -> Union[Tuple[str, Dict[str, int]], str]: name = self._get_param("Name") description = self._get_param("Description") configuration = self._get_param("Configuration") @@ -32,14 +33,14 @@ class AthenaResponse(BaseResponse): } ) - def list_work_groups(self): + def list_work_groups(self) -> str: return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()}) - def get_work_group(self): + def get_work_group(self) -> str: name = self._get_param("WorkGroup") return json.dumps({"WorkGroup": self.athena_backend.get_work_group(name)}) - def start_query_execution(self): + def start_query_execution(self) -> Union[Tuple[str, Dict[str, int]], str]: query = self._get_param("QueryString") context = self._get_param("QueryExecutionContext") config = self._get_param("ResultConfiguration") @@ -51,7 +52,7 @@ class AthenaResponse(BaseResponse): ) return json.dumps({"QueryExecutionId": q_exec_id}) - def get_query_execution(self): + def get_query_execution(self) -> str: exec_id = self._get_param("QueryExecutionId") execution = self.athena_backend.get_execution(exec_id) result = { @@ -78,18 +79,18 @@ class AthenaResponse(BaseResponse): } return json.dumps(result) - def stop_query_execution(self): + def stop_query_execution(self) -> str: exec_id = self._get_param("QueryExecutionId") self.athena_backend.stop_query_execution(exec_id) return json.dumps({}) - def error(self, msg, status): + def error(self, msg: str, status: int) -> Tuple[str, Dict[str, int]]: return ( json.dumps({"__type": "InvalidRequestException", "Message": msg}), dict(status=status), ) - def create_named_query(self): + def create_named_query(self) -> Union[Tuple[str, Dict[str, int]], str]: name = self._get_param("Name") description = self._get_param("Description") database = self._get_param("Database") @@ -102,32 +103,32 @@ class AthenaResponse(BaseResponse): ) return json.dumps({"NamedQueryId": query_id}) - def get_named_query(self): + def get_named_query(self) -> str: query_id = self._get_param("NamedQueryId") nq = self.athena_backend.get_named_query(query_id) return json.dumps( { "NamedQuery": { - "Name": nq.name, - "Description": nq.description, - "Database": nq.database, - "QueryString": nq.query_string, - "NamedQueryId": nq.id, - "WorkGroup": nq.workgroup, + "Name": nq.name, # type: ignore[union-attr] + "Description": nq.description, # type: ignore[union-attr] + "Database": nq.database, # type: ignore[union-attr] + "QueryString": nq.query_string, # type: ignore[union-attr] + "NamedQueryId": nq.id, # type: ignore[union-attr] + "WorkGroup": nq.workgroup, # type: ignore[union-attr] } } ) - def list_data_catalogs(self): + def list_data_catalogs(self) -> str: return json.dumps( {"DataCatalogsSummary": self.athena_backend.list_data_catalogs()} ) - def get_data_catalog(self): + def get_data_catalog(self) -> str: name = self._get_param("Name") return json.dumps({"DataCatalog": self.athena_backend.get_data_catalog(name)}) - def create_data_catalog(self): + def create_data_catalog(self) -> Union[Tuple[str, Dict[str, int]], str]: name = self._get_param("Name") catalog_type = self._get_param("Type") description = self._get_param("Description") diff --git a/setup.cfg b/setup.cfg index 9911d789d..e4e853db1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/acm,moto/amp,moto/apigateway,moto/apigatewayv2,moto/applicationautoscaling/,moto/appsync +files= moto/acm,moto/amp,moto/apigateway,moto/apigatewayv2,moto/applicationautoscaling/,moto/appsync,moto/athena show_column_numbers=True show_error_codes = True disable_error_code=abstract