Techdebt: MyPy RedshiftData (#6218)

This commit is contained in:
Bert Blommers 2023-04-17 15:20:19 +00:00 committed by GitHub
parent f2b6384f28
commit b999d658e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 40 deletions

View File

@ -2,10 +2,10 @@ from moto.core.exceptions import JsonRESTError
class ResourceNotFoundException(JsonRESTError): class ResourceNotFoundException(JsonRESTError):
def __init__(self): def __init__(self) -> None:
super().__init__("ResourceNotFoundException", "Query does not exist.") super().__init__("ResourceNotFoundException", "Query does not exist.")
class ValidationException(JsonRESTError): class ValidationException(JsonRESTError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("ValidationException", message) super().__init__("ValidationException", message)

View File

@ -1,5 +1,6 @@
import re import re
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Iterable, Iterator, List, Tuple, Optional
from moto.core import BaseBackend, BackendDict from moto.core import BaseBackend, BackendDict
from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.core.utils import iso_8601_datetime_without_milliseconds
@ -7,15 +8,15 @@ from moto.moto_api._internal import mock_random as random
from moto.redshiftdata.exceptions import ValidationException, ResourceNotFoundException from moto.redshiftdata.exceptions import ValidationException, ResourceNotFoundException
class Statement: class Statement(Iterable[Tuple[str, Any]]):
def __init__( def __init__(
self, self,
cluster_identifier, cluster_identifier: str,
database, database: str,
db_user, db_user: str,
query_parameters, query_parameters: List[Dict[str, str]],
query_string, query_string: str,
secret_arn, secret_arn: str,
): ):
now = iso_8601_datetime_without_milliseconds(datetime.now()) now = iso_8601_datetime_without_milliseconds(datetime.now())
@ -34,10 +35,10 @@ class Statement:
self.result_size = -1 self.result_size = -1
self.secret_arn = secret_arn self.secret_arn = secret_arn
self.status = "STARTED" self.status = "STARTED"
self.sub_statements = [] self.sub_statements: List[str] = []
self.updated_at = now self.updated_at = now
def __iter__(self): def __iter__(self) -> Iterator[Tuple[str, Any]]:
yield "Id", self.id yield "Id", self.id
yield "ClusterIdentifier", self.cluster_identifier yield "ClusterIdentifier", self.cluster_identifier
yield "CreatedAt", self.created_at yield "CreatedAt", self.created_at
@ -57,29 +58,42 @@ class Statement:
yield "UpdatedAt", self.updated_at yield "UpdatedAt", self.updated_at
class StatementResult: class StatementResult(Iterable[Tuple[str, Any]]):
def __init__(self, column_metadata, records, total_number_rows, next_token=None): def __init__(
self,
column_metadata: List[Dict[str, Any]],
records: List[List[Dict[str, Any]]],
total_number_rows: int,
next_token: Optional[str] = None,
):
self.column_metadata = column_metadata self.column_metadata = column_metadata
self.records = records self.records = records
self.total_number_rows = total_number_rows self.total_number_rows = total_number_rows
self.next_token = next_token self.next_token = next_token
def __iter__(self): def __iter__(self) -> Iterator[Tuple[str, Any]]:
yield "ColumnMetadata", self.column_metadata yield "ColumnMetadata", self.column_metadata
yield "Records", self.records yield "Records", self.records
yield "TotalNumberRows", self.total_number_rows yield "TotalNumberRows", self.total_number_rows
yield "NextToken", self.next_token yield "NextToken", self.next_token
class ColumnMetadata: class ColumnMetadata(Iterable[Tuple[str, Any]]):
def __init__(self, column_default, is_case_sensitive, is_signed, name, nullable): def __init__(
self,
column_default: Optional[str],
is_case_sensitive: bool,
is_signed: bool,
name: str,
nullable: int,
):
self.column_default = column_default self.column_default = column_default
self.is_case_sensitive = is_case_sensitive self.is_case_sensitive = is_case_sensitive
self.is_signed = is_signed self.is_signed = is_signed
self.name = name self.name = name
self.nullable = nullable self.nullable = nullable
def __iter__(self): def __iter__(self) -> Iterator[Tuple[str, Any]]:
yield "columnDefault", self.column_default yield "columnDefault", self.column_default
yield "isCaseSensitive", self.is_case_sensitive yield "isCaseSensitive", self.is_case_sensitive
yield "isSigned", self.is_signed yield "isSigned", self.is_signed
@ -87,11 +101,11 @@ class ColumnMetadata:
yield "nullable", self.nullable yield "nullable", self.nullable
class Record: class Record(Iterable[Tuple[str, Any]]):
def __init__(self, **kwargs): def __init__(self, **kwargs: Any):
self.kwargs = kwargs self.kwargs = kwargs
def __iter__(self): def __iter__(self) -> Iterator[Tuple[str, Any]]:
if "long_value" in self.kwargs: if "long_value" in self.kwargs:
yield "longValue", self.kwargs["long_value"] yield "longValue", self.kwargs["long_value"]
elif "string_value" in self.kwargs: elif "string_value" in self.kwargs:
@ -99,11 +113,11 @@ class Record:
class RedshiftDataAPIServiceBackend(BaseBackend): class RedshiftDataAPIServiceBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.statements = {} self.statements: Dict[str, Statement] = {}
def cancel_statement(self, statement_id): def cancel_statement(self, statement_id: str) -> None:
_validate_uuid(statement_id) _validate_uuid(statement_id)
try: try:
@ -122,9 +136,7 @@ class RedshiftDataAPIServiceBackend(BaseBackend):
# Statement does not exist. # Statement does not exist.
raise ResourceNotFoundException() raise ResourceNotFoundException()
return True def describe_statement(self, statement_id: str) -> Statement:
def describe_statement(self, statement_id):
_validate_uuid(statement_id) _validate_uuid(statement_id)
try: try:
@ -135,8 +147,14 @@ class RedshiftDataAPIServiceBackend(BaseBackend):
raise ResourceNotFoundException() raise ResourceNotFoundException()
def execute_statement( def execute_statement(
self, cluster_identifier, database, db_user, parameters, secret_arn, sql self,
): cluster_identifier: str,
database: str,
db_user: str,
parameters: List[Dict[str, str]],
secret_arn: str,
sql: str,
) -> Statement:
""" """
Runs an SQL statement Runs an SQL statement
Validation of parameters is very limited because there is no redshift integration Validation of parameters is very limited because there is no redshift integration
@ -152,7 +170,7 @@ class RedshiftDataAPIServiceBackend(BaseBackend):
self.statements[statement.id] = statement self.statements[statement.id] = statement
return statement return statement
def get_statement_result(self, statement_id): def get_statement_result(self, statement_id: str) -> StatementResult:
""" """
Return static statement result Return static statement result
StatementResult is the result of the SQL query "sql" passed as parameter when calling "execute_statement" StatementResult is the result of the SQL query "sql" passed as parameter when calling "execute_statement"
@ -190,7 +208,7 @@ class RedshiftDataAPIServiceBackend(BaseBackend):
) )
def _validate_uuid(uuid): def _validate_uuid(uuid: str) -> None:
match = re.search(r"^[a-z0-9]{8}(-[a-z0-9]{4}){3}-[a-z0-9]{12}(:\d+)?$", uuid) match = re.search(r"^[a-z0-9]{8}(-[a-z0-9]{4}){3}-[a-z0-9]{12}(:\d+)?$", uuid)
if not match: if not match:
raise ValidationException( raise ValidationException(

View File

@ -1,29 +1,30 @@
import json import json
from moto.core.common_types import TYPE_RESPONSE
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import redshiftdata_backends from .models import redshiftdata_backends, RedshiftDataAPIServiceBackend
class RedshiftDataAPIServiceResponse(BaseResponse): class RedshiftDataAPIServiceResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="redshift-data") super().__init__(service_name="redshift-data")
@property @property
def redshiftdata_backend(self): def redshiftdata_backend(self) -> RedshiftDataAPIServiceBackend:
return redshiftdata_backends[self.current_account][self.region] return redshiftdata_backends[self.current_account][self.region]
def cancel_statement(self): def cancel_statement(self) -> TYPE_RESPONSE:
statement_id = self._get_param("Id") statement_id = self._get_param("Id")
status = self.redshiftdata_backend.cancel_statement(statement_id=statement_id) self.redshiftdata_backend.cancel_statement(statement_id=statement_id)
return 200, {}, json.dumps({"Status": status}) return 200, {}, json.dumps({"Status": True})
def describe_statement(self): def describe_statement(self) -> TYPE_RESPONSE:
statement_id = self._get_param("Id") statement_id = self._get_param("Id")
statement = self.redshiftdata_backend.describe_statement( statement = self.redshiftdata_backend.describe_statement(
statement_id=statement_id statement_id=statement_id
) )
return 200, {}, json.dumps(dict(statement)) return 200, {}, json.dumps(dict(statement))
def execute_statement(self): def execute_statement(self) -> TYPE_RESPONSE:
cluster_identifier = self._get_param("ClusterIdentifier") cluster_identifier = self._get_param("ClusterIdentifier")
database = self._get_param("Database") database = self._get_param("Database")
db_user = self._get_param("DbUser") db_user = self._get_param("DbUser")
@ -54,7 +55,7 @@ class RedshiftDataAPIServiceResponse(BaseResponse):
), ),
) )
def get_statement_result(self): def get_statement_result(self) -> TYPE_RESPONSE:
statement_id = self._get_param("Id") statement_id = self._get_param("Id")
statement_result = self.redshiftdata_backend.get_statement_result( statement_result = self.redshiftdata_backend.get_statement_result(
statement_id=statement_id statement_id=statement_id

View File

@ -239,7 +239,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/ram,moto/rds,moto/rdsdata,moto/redshift,moto/scheduler files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/l*,moto/m*,moto/n*,moto/o*,moto/p*,moto/q*,moto/ram,moto/rds*,moto/redshift*,moto/scheduler
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract