From 92396bce4f036cf879da5897c62cb96d48ed9d07 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 18 Dec 2022 20:16:46 -0100 Subject: [PATCH] Techdebt: MyPy DMS (#5785) --- moto/dms/exceptions.py | 6 ++-- moto/dms/models.py | 71 +++++++++++++++++++++++------------------- moto/dms/responses.py | 16 +++++----- moto/dms/utils.py | 22 ++++++++----- setup.cfg | 2 +- 5 files changed, 65 insertions(+), 52 deletions(-) diff --git a/moto/dms/exceptions.py b/moto/dms/exceptions.py index f0c4dc29b..6ade6bae2 100644 --- a/moto/dms/exceptions.py +++ b/moto/dms/exceptions.py @@ -6,15 +6,15 @@ class DmsClientError(JsonRESTError): class ResourceNotFoundFault(DmsClientError): - def __init__(self, message): + def __init__(self, message: str): super().__init__("ResourceNotFoundFault", message) class InvalidResourceStateFault(DmsClientError): - def __init__(self, message): + def __init__(self, message: str): super().__init__("InvalidResourceStateFault", message) class ResourceAlreadyExistsFault(DmsClientError): - def __init__(self, message): + def __init__(self, message: str): super().__init__("ResourceAlreadyExistsFault", message) diff --git a/moto/dms/models.py b/moto/dms/models.py index 8456ab382..2a174a7d3 100644 --- a/moto/dms/models.py +++ b/moto/dms/models.py @@ -1,6 +1,7 @@ import json from datetime import datetime +from typing import Any, Dict, List, Iterable, Optional from moto.core import BaseBackend, BackendDict, BaseModel from .exceptions import ( @@ -12,12 +13,12 @@ from .utils import filter_tasks class DatabaseMigrationServiceBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.replication_tasks = {} + self.replication_tasks: Dict[str, "FakeReplicationTask"] = {} @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc] """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "dms" @@ -25,14 +26,14 @@ class DatabaseMigrationServiceBackend(BaseBackend): def create_replication_task( self, - replication_task_identifier, - source_endpoint_arn, - target_endpoint_arn, - replication_instance_arn, - migration_type, - table_mappings, - replication_task_settings, - ): + replication_task_identifier: str, + source_endpoint_arn: str, + target_endpoint_arn: str, + replication_instance_arn: str, + migration_type: str, + table_mappings: str, + replication_task_settings: str, + ) -> "FakeReplicationTask": """ The following parameters are not yet implemented: CDCStartTime, CDCStartPosition, CDCStopPosition, Tags, TaskData, ResourceIdentifier @@ -58,7 +59,9 @@ class DatabaseMigrationServiceBackend(BaseBackend): return replication_task - def start_replication_task(self, replication_task_arn): + def start_replication_task( + self, replication_task_arn: str + ) -> "FakeReplicationTask": """ The following parameters have not yet been implemented: StartReplicationTaskType, CDCStartTime, CDCStartPosition, CDCStopPosition @@ -68,13 +71,15 @@ class DatabaseMigrationServiceBackend(BaseBackend): return self.replication_tasks[replication_task_arn].start() - def stop_replication_task(self, replication_task_arn): + def stop_replication_task(self, replication_task_arn: str) -> "FakeReplicationTask": if not self.replication_tasks.get(replication_task_arn): raise ResourceNotFoundFault("Replication task could not be found.") return self.replication_tasks[replication_task_arn].stop() - def delete_replication_task(self, replication_task_arn): + def delete_replication_task( + self, replication_task_arn: str + ) -> "FakeReplicationTask": if not self.replication_tasks.get(replication_task_arn): raise ResourceNotFoundFault("Replication task could not be found.") @@ -84,7 +89,9 @@ class DatabaseMigrationServiceBackend(BaseBackend): return task - def describe_replication_tasks(self, filters, max_records): + def describe_replication_tasks( + self, filters: List[Dict[str, Any]], max_records: int + ) -> Iterable["FakeReplicationTask"]: """ The parameter WithoutSettings has not yet been implemented """ @@ -99,15 +106,15 @@ class DatabaseMigrationServiceBackend(BaseBackend): class FakeReplicationTask(BaseModel): def __init__( self, - replication_task_identifier, - migration_type, - replication_instance_arn, - source_endpoint_arn, - target_endpoint_arn, - table_mappings, - replication_task_settings, - account_id, - region_name, + replication_task_identifier: str, + migration_type: str, + replication_instance_arn: str, + source_endpoint_arn: str, + target_endpoint_arn: str, + table_mappings: str, + replication_task_settings: str, + account_id: str, + region_name: str, ): self.id = replication_task_identifier self.region = region_name @@ -122,10 +129,10 @@ class FakeReplicationTask(BaseModel): self.status = "creating" self.creation_date = datetime.utcnow() - self.start_date = None - self.stop_date = None + self.start_date: Optional[datetime] = None + self.stop_date: Optional[datetime] = None - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: start_date = self.start_date.isoformat() if self.start_date else None stop_date = self.stop_date.isoformat() if self.stop_date else None @@ -156,17 +163,17 @@ class FakeReplicationTask(BaseModel): }, } - def ready(self): + def ready(self) -> "FakeReplicationTask": self.status = "ready" return self - def start(self): + def start(self) -> "FakeReplicationTask": self.status = "starting" self.start_date = datetime.utcnow() self.run() return self - def stop(self): + def stop(self) -> "FakeReplicationTask": if self.status != "running": raise InvalidResourceStateFault("Replication task is not running") @@ -174,11 +181,11 @@ class FakeReplicationTask(BaseModel): self.stop_date = datetime.utcnow() return self - def delete(self): + def delete(self) -> "FakeReplicationTask": self.status = "deleting" return self - def run(self): + def run(self) -> "FakeReplicationTask": self.status = "running" return self diff --git a/moto/dms/responses.py b/moto/dms/responses.py index c51741c2a..cdb99e386 100644 --- a/moto/dms/responses.py +++ b/moto/dms/responses.py @@ -1,17 +1,17 @@ from moto.core.responses import BaseResponse -from .models import dms_backends +from .models import dms_backends, DatabaseMigrationServiceBackend import json class DatabaseMigrationServiceResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="dms") @property - def dms_backend(self): + def dms_backend(self) -> DatabaseMigrationServiceBackend: return dms_backends[self.current_account][self.region] - def create_replication_task(self): + def create_replication_task(self) -> str: replication_task_identifier = self._get_param("ReplicationTaskIdentifier") source_endpoint_arn = self._get_param("SourceEndpointArn") target_endpoint_arn = self._get_param("TargetEndpointArn") @@ -31,7 +31,7 @@ class DatabaseMigrationServiceResponse(BaseResponse): return json.dumps({"ReplicationTask": replication_task.to_dict()}) - def start_replication_task(self): + def start_replication_task(self) -> str: replication_task_arn = self._get_param("ReplicationTaskArn") replication_task = self.dms_backend.start_replication_task( replication_task_arn=replication_task_arn @@ -39,7 +39,7 @@ class DatabaseMigrationServiceResponse(BaseResponse): return json.dumps({"ReplicationTask": replication_task.to_dict()}) - def stop_replication_task(self): + def stop_replication_task(self) -> str: replication_task_arn = self._get_param("ReplicationTaskArn") replication_task = self.dms_backend.stop_replication_task( replication_task_arn=replication_task_arn @@ -47,7 +47,7 @@ class DatabaseMigrationServiceResponse(BaseResponse): return json.dumps({"ReplicationTask": replication_task.to_dict()}) - def delete_replication_task(self): + def delete_replication_task(self) -> str: replication_task_arn = self._get_param("ReplicationTaskArn") replication_task = self.dms_backend.delete_replication_task( replication_task_arn=replication_task_arn @@ -55,7 +55,7 @@ class DatabaseMigrationServiceResponse(BaseResponse): return json.dumps({"ReplicationTask": replication_task.to_dict()}) - def describe_replication_tasks(self): + def describe_replication_tasks(self) -> str: filters = self._get_list_prefix("Filters.member") max_records = self._get_int_param("MaxRecords") replication_tasks = self.dms_backend.describe_replication_tasks( diff --git a/moto/dms/utils.py b/moto/dms/utils.py index cff278e12..3cbf33e2d 100644 --- a/moto/dms/utils.py +++ b/moto/dms/utils.py @@ -1,23 +1,28 @@ -def match_task_arn(task, arns): +from typing import Any, Dict, List, Iterable + + +def match_task_arn(task: Dict[str, Any], arns: List[str]) -> bool: return task["ReplicationTaskArn"] in arns -def match_task_id(task, ids): +def match_task_id(task: Dict[str, Any], ids: List[str]) -> bool: return task["ReplicationTaskIdentifier"] in ids -def match_task_migration_type(task, migration_types): +def match_task_migration_type(task: Dict[str, Any], migration_types: List[str]) -> bool: return task["MigrationType"] in migration_types -def match_task_endpoint_arn(task, endpoint_arns): +def match_task_endpoint_arn(task: Dict[str, Any], endpoint_arns: List[str]) -> bool: return ( task["SourceEndpointArn"] in endpoint_arns or task["TargetEndpointArn"] in endpoint_arns ) -def match_task_replication_instance_arn(task, replication_instance_arns): +def match_task_replication_instance_arn( + task: Dict[str, Any], replication_instance_arns: List[str] +) -> bool: return task["ReplicationInstanceArn"] in replication_instance_arns @@ -30,17 +35,18 @@ task_filter_functions = { } -def filter_tasks(tasks, filters): +def filter_tasks(tasks: Iterable[Any], filters: List[Dict[str, Any]]) -> Any: matching_tasks = tasks for f in filters: - filter_function = task_filter_functions[f["Name"]] + filter_function = task_filter_functions.get(f["Name"]) if not filter_function: continue + # https://github.com/python/mypy/issues/12682 matching_tasks = filter( - lambda task: filter_function(task, f["Values"]), matching_tasks + lambda task: filter_function(task, f["Values"]), matching_tasks # type: ignore[arg-type] ) return matching_tasks diff --git a/setup.cfg b/setup.cfg index 8febb1afc..9e857dfdb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/moto_api +files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/moto_api show_column_numbers=True show_error_codes = True disable_error_code=abstract