Techdebt: MyPy DMS (#5785)

This commit is contained in:
Bert Blommers 2022-12-18 20:16:46 -01:00 committed by GitHub
parent fab91edd8d
commit 92396bce4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 65 additions and 52 deletions

View File

@ -6,15 +6,15 @@ class DmsClientError(JsonRESTError):
class ResourceNotFoundFault(DmsClientError): class ResourceNotFoundFault(DmsClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("ResourceNotFoundFault", message) super().__init__("ResourceNotFoundFault", message)
class InvalidResourceStateFault(DmsClientError): class InvalidResourceStateFault(DmsClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("InvalidResourceStateFault", message) super().__init__("InvalidResourceStateFault", message)
class ResourceAlreadyExistsFault(DmsClientError): class ResourceAlreadyExistsFault(DmsClientError):
def __init__(self, message): def __init__(self, message: str):
super().__init__("ResourceAlreadyExistsFault", message) super().__init__("ResourceAlreadyExistsFault", message)

View File

@ -1,6 +1,7 @@
import json import json
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Iterable, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import ( from .exceptions import (
@ -12,12 +13,12 @@ from .utils import filter_tasks
class DatabaseMigrationServiceBackend(BaseBackend): class DatabaseMigrationServiceBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.replication_tasks = {} self.replication_tasks: Dict[str, "FakeReplicationTask"] = {}
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "dms" service_region, zones, "dms"
@ -25,14 +26,14 @@ class DatabaseMigrationServiceBackend(BaseBackend):
def create_replication_task( def create_replication_task(
self, self,
replication_task_identifier, replication_task_identifier: str,
source_endpoint_arn, source_endpoint_arn: str,
target_endpoint_arn, target_endpoint_arn: str,
replication_instance_arn, replication_instance_arn: str,
migration_type, migration_type: str,
table_mappings, table_mappings: str,
replication_task_settings, replication_task_settings: str,
): ) -> "FakeReplicationTask":
""" """
The following parameters are not yet implemented: The following parameters are not yet implemented:
CDCStartTime, CDCStartPosition, CDCStopPosition, Tags, TaskData, ResourceIdentifier CDCStartTime, CDCStartPosition, CDCStopPosition, Tags, TaskData, ResourceIdentifier
@ -58,7 +59,9 @@ class DatabaseMigrationServiceBackend(BaseBackend):
return replication_task return replication_task
def start_replication_task(self, replication_task_arn): def start_replication_task(
self, replication_task_arn: str
) -> "FakeReplicationTask":
""" """
The following parameters have not yet been implemented: The following parameters have not yet been implemented:
StartReplicationTaskType, CDCStartTime, CDCStartPosition, CDCStopPosition StartReplicationTaskType, CDCStartTime, CDCStartPosition, CDCStopPosition
@ -68,13 +71,15 @@ class DatabaseMigrationServiceBackend(BaseBackend):
return self.replication_tasks[replication_task_arn].start() return self.replication_tasks[replication_task_arn].start()
def stop_replication_task(self, replication_task_arn): def stop_replication_task(self, replication_task_arn: str) -> "FakeReplicationTask":
if not self.replication_tasks.get(replication_task_arn): if not self.replication_tasks.get(replication_task_arn):
raise ResourceNotFoundFault("Replication task could not be found.") raise ResourceNotFoundFault("Replication task could not be found.")
return self.replication_tasks[replication_task_arn].stop() return self.replication_tasks[replication_task_arn].stop()
def delete_replication_task(self, replication_task_arn): def delete_replication_task(
self, replication_task_arn: str
) -> "FakeReplicationTask":
if not self.replication_tasks.get(replication_task_arn): if not self.replication_tasks.get(replication_task_arn):
raise ResourceNotFoundFault("Replication task could not be found.") raise ResourceNotFoundFault("Replication task could not be found.")
@ -84,7 +89,9 @@ class DatabaseMigrationServiceBackend(BaseBackend):
return task return task
def describe_replication_tasks(self, filters, max_records): def describe_replication_tasks(
self, filters: List[Dict[str, Any]], max_records: int
) -> Iterable["FakeReplicationTask"]:
""" """
The parameter WithoutSettings has not yet been implemented The parameter WithoutSettings has not yet been implemented
""" """
@ -99,15 +106,15 @@ class DatabaseMigrationServiceBackend(BaseBackend):
class FakeReplicationTask(BaseModel): class FakeReplicationTask(BaseModel):
def __init__( def __init__(
self, self,
replication_task_identifier, replication_task_identifier: str,
migration_type, migration_type: str,
replication_instance_arn, replication_instance_arn: str,
source_endpoint_arn, source_endpoint_arn: str,
target_endpoint_arn, target_endpoint_arn: str,
table_mappings, table_mappings: str,
replication_task_settings, replication_task_settings: str,
account_id, account_id: str,
region_name, region_name: str,
): ):
self.id = replication_task_identifier self.id = replication_task_identifier
self.region = region_name self.region = region_name
@ -122,10 +129,10 @@ class FakeReplicationTask(BaseModel):
self.status = "creating" self.status = "creating"
self.creation_date = datetime.utcnow() self.creation_date = datetime.utcnow()
self.start_date = None self.start_date: Optional[datetime] = None
self.stop_date = None self.stop_date: Optional[datetime] = None
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
start_date = self.start_date.isoformat() if self.start_date else None start_date = self.start_date.isoformat() if self.start_date else None
stop_date = self.stop_date.isoformat() if self.stop_date else None stop_date = self.stop_date.isoformat() if self.stop_date else None
@ -156,17 +163,17 @@ class FakeReplicationTask(BaseModel):
}, },
} }
def ready(self): def ready(self) -> "FakeReplicationTask":
self.status = "ready" self.status = "ready"
return self return self
def start(self): def start(self) -> "FakeReplicationTask":
self.status = "starting" self.status = "starting"
self.start_date = datetime.utcnow() self.start_date = datetime.utcnow()
self.run() self.run()
return self return self
def stop(self): def stop(self) -> "FakeReplicationTask":
if self.status != "running": if self.status != "running":
raise InvalidResourceStateFault("Replication task is not running") raise InvalidResourceStateFault("Replication task is not running")
@ -174,11 +181,11 @@ class FakeReplicationTask(BaseModel):
self.stop_date = datetime.utcnow() self.stop_date = datetime.utcnow()
return self return self
def delete(self): def delete(self) -> "FakeReplicationTask":
self.status = "deleting" self.status = "deleting"
return self return self
def run(self): def run(self) -> "FakeReplicationTask":
self.status = "running" self.status = "running"
return self return self

View File

@ -1,17 +1,17 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import dms_backends from .models import dms_backends, DatabaseMigrationServiceBackend
import json import json
class DatabaseMigrationServiceResponse(BaseResponse): class DatabaseMigrationServiceResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="dms") super().__init__(service_name="dms")
@property @property
def dms_backend(self): def dms_backend(self) -> DatabaseMigrationServiceBackend:
return dms_backends[self.current_account][self.region] return dms_backends[self.current_account][self.region]
def create_replication_task(self): def create_replication_task(self) -> str:
replication_task_identifier = self._get_param("ReplicationTaskIdentifier") replication_task_identifier = self._get_param("ReplicationTaskIdentifier")
source_endpoint_arn = self._get_param("SourceEndpointArn") source_endpoint_arn = self._get_param("SourceEndpointArn")
target_endpoint_arn = self._get_param("TargetEndpointArn") target_endpoint_arn = self._get_param("TargetEndpointArn")
@ -31,7 +31,7 @@ class DatabaseMigrationServiceResponse(BaseResponse):
return json.dumps({"ReplicationTask": replication_task.to_dict()}) return json.dumps({"ReplicationTask": replication_task.to_dict()})
def start_replication_task(self): def start_replication_task(self) -> str:
replication_task_arn = self._get_param("ReplicationTaskArn") replication_task_arn = self._get_param("ReplicationTaskArn")
replication_task = self.dms_backend.start_replication_task( replication_task = self.dms_backend.start_replication_task(
replication_task_arn=replication_task_arn replication_task_arn=replication_task_arn
@ -39,7 +39,7 @@ class DatabaseMigrationServiceResponse(BaseResponse):
return json.dumps({"ReplicationTask": replication_task.to_dict()}) return json.dumps({"ReplicationTask": replication_task.to_dict()})
def stop_replication_task(self): def stop_replication_task(self) -> str:
replication_task_arn = self._get_param("ReplicationTaskArn") replication_task_arn = self._get_param("ReplicationTaskArn")
replication_task = self.dms_backend.stop_replication_task( replication_task = self.dms_backend.stop_replication_task(
replication_task_arn=replication_task_arn replication_task_arn=replication_task_arn
@ -47,7 +47,7 @@ class DatabaseMigrationServiceResponse(BaseResponse):
return json.dumps({"ReplicationTask": replication_task.to_dict()}) return json.dumps({"ReplicationTask": replication_task.to_dict()})
def delete_replication_task(self): def delete_replication_task(self) -> str:
replication_task_arn = self._get_param("ReplicationTaskArn") replication_task_arn = self._get_param("ReplicationTaskArn")
replication_task = self.dms_backend.delete_replication_task( replication_task = self.dms_backend.delete_replication_task(
replication_task_arn=replication_task_arn replication_task_arn=replication_task_arn
@ -55,7 +55,7 @@ class DatabaseMigrationServiceResponse(BaseResponse):
return json.dumps({"ReplicationTask": replication_task.to_dict()}) return json.dumps({"ReplicationTask": replication_task.to_dict()})
def describe_replication_tasks(self): def describe_replication_tasks(self) -> str:
filters = self._get_list_prefix("Filters.member") filters = self._get_list_prefix("Filters.member")
max_records = self._get_int_param("MaxRecords") max_records = self._get_int_param("MaxRecords")
replication_tasks = self.dms_backend.describe_replication_tasks( replication_tasks = self.dms_backend.describe_replication_tasks(

View File

@ -1,23 +1,28 @@
def match_task_arn(task, arns): from typing import Any, Dict, List, Iterable
def match_task_arn(task: Dict[str, Any], arns: List[str]) -> bool:
return task["ReplicationTaskArn"] in arns return task["ReplicationTaskArn"] in arns
def match_task_id(task, ids): def match_task_id(task: Dict[str, Any], ids: List[str]) -> bool:
return task["ReplicationTaskIdentifier"] in ids return task["ReplicationTaskIdentifier"] in ids
def match_task_migration_type(task, migration_types): def match_task_migration_type(task: Dict[str, Any], migration_types: List[str]) -> bool:
return task["MigrationType"] in migration_types return task["MigrationType"] in migration_types
def match_task_endpoint_arn(task, endpoint_arns): def match_task_endpoint_arn(task: Dict[str, Any], endpoint_arns: List[str]) -> bool:
return ( return (
task["SourceEndpointArn"] in endpoint_arns task["SourceEndpointArn"] in endpoint_arns
or task["TargetEndpointArn"] in endpoint_arns or task["TargetEndpointArn"] in endpoint_arns
) )
def match_task_replication_instance_arn(task, replication_instance_arns): def match_task_replication_instance_arn(
task: Dict[str, Any], replication_instance_arns: List[str]
) -> bool:
return task["ReplicationInstanceArn"] in replication_instance_arns return task["ReplicationInstanceArn"] in replication_instance_arns
@ -30,17 +35,18 @@ task_filter_functions = {
} }
def filter_tasks(tasks, filters): def filter_tasks(tasks: Iterable[Any], filters: List[Dict[str, Any]]) -> Any:
matching_tasks = tasks matching_tasks = tasks
for f in filters: for f in filters:
filter_function = task_filter_functions[f["Name"]] filter_function = task_filter_functions.get(f["Name"])
if not filter_function: if not filter_function:
continue continue
# https://github.com/python/mypy/issues/12682
matching_tasks = filter( matching_tasks = filter(
lambda task: filter_function(task, f["Values"]), matching_tasks lambda task: filter_function(task, f["Values"]), matching_tasks # type: ignore[arg-type]
) )
return matching_tasks return matching_tasks

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/moto_api files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/moto_api
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract