Techdebt: MyPy DMS (#5785)

This commit is contained in:
Bert Blommers 2022-12-18 20:16:46 -01:00 committed by GitHub
parent fab91edd8d
commit 92396bce4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 65 additions and 52 deletions

View File

@ -6,15 +6,15 @@ class DmsClientError(JsonRESTError):
class ResourceNotFoundFault(DmsClientError):
def __init__(self, message):
def __init__(self, message: str):
super().__init__("ResourceNotFoundFault", message)
class InvalidResourceStateFault(DmsClientError):
def __init__(self, message):
def __init__(self, message: str):
super().__init__("InvalidResourceStateFault", message)
class ResourceAlreadyExistsFault(DmsClientError):
def __init__(self, message):
def __init__(self, message: str):
super().__init__("ResourceAlreadyExistsFault", message)

View File

@ -1,6 +1,7 @@
import json
from datetime import datetime
from typing import Any, Dict, List, Iterable, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import (
@ -12,12 +13,12 @@ from .utils import filter_tasks
class DatabaseMigrationServiceBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.replication_tasks = {}
self.replication_tasks: Dict[str, "FakeReplicationTask"] = {}
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "dms"
@ -25,14 +26,14 @@ class DatabaseMigrationServiceBackend(BaseBackend):
def create_replication_task(
self,
replication_task_identifier,
source_endpoint_arn,
target_endpoint_arn,
replication_instance_arn,
migration_type,
table_mappings,
replication_task_settings,
):
replication_task_identifier: str,
source_endpoint_arn: str,
target_endpoint_arn: str,
replication_instance_arn: str,
migration_type: str,
table_mappings: str,
replication_task_settings: str,
) -> "FakeReplicationTask":
"""
The following parameters are not yet implemented:
CDCStartTime, CDCStartPosition, CDCStopPosition, Tags, TaskData, ResourceIdentifier
@ -58,7 +59,9 @@ class DatabaseMigrationServiceBackend(BaseBackend):
return replication_task
def start_replication_task(self, replication_task_arn):
def start_replication_task(
self, replication_task_arn: str
) -> "FakeReplicationTask":
"""
The following parameters have not yet been implemented:
StartReplicationTaskType, CDCStartTime, CDCStartPosition, CDCStopPosition
@ -68,13 +71,15 @@ class DatabaseMigrationServiceBackend(BaseBackend):
return self.replication_tasks[replication_task_arn].start()
def stop_replication_task(self, replication_task_arn):
def stop_replication_task(self, replication_task_arn: str) -> "FakeReplicationTask":
if not self.replication_tasks.get(replication_task_arn):
raise ResourceNotFoundFault("Replication task could not be found.")
return self.replication_tasks[replication_task_arn].stop()
def delete_replication_task(self, replication_task_arn):
def delete_replication_task(
self, replication_task_arn: str
) -> "FakeReplicationTask":
if not self.replication_tasks.get(replication_task_arn):
raise ResourceNotFoundFault("Replication task could not be found.")
@ -84,7 +89,9 @@ class DatabaseMigrationServiceBackend(BaseBackend):
return task
def describe_replication_tasks(self, filters, max_records):
def describe_replication_tasks(
self, filters: List[Dict[str, Any]], max_records: int
) -> Iterable["FakeReplicationTask"]:
"""
The parameter WithoutSettings has not yet been implemented
"""
@ -99,15 +106,15 @@ class DatabaseMigrationServiceBackend(BaseBackend):
class FakeReplicationTask(BaseModel):
def __init__(
self,
replication_task_identifier,
migration_type,
replication_instance_arn,
source_endpoint_arn,
target_endpoint_arn,
table_mappings,
replication_task_settings,
account_id,
region_name,
replication_task_identifier: str,
migration_type: str,
replication_instance_arn: str,
source_endpoint_arn: str,
target_endpoint_arn: str,
table_mappings: str,
replication_task_settings: str,
account_id: str,
region_name: str,
):
self.id = replication_task_identifier
self.region = region_name
@ -122,10 +129,10 @@ class FakeReplicationTask(BaseModel):
self.status = "creating"
self.creation_date = datetime.utcnow()
self.start_date = None
self.stop_date = None
self.start_date: Optional[datetime] = None
self.stop_date: Optional[datetime] = None
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
start_date = self.start_date.isoformat() if self.start_date else None
stop_date = self.stop_date.isoformat() if self.stop_date else None
@ -156,17 +163,17 @@ class FakeReplicationTask(BaseModel):
},
}
def ready(self):
def ready(self) -> "FakeReplicationTask":
self.status = "ready"
return self
def start(self):
def start(self) -> "FakeReplicationTask":
self.status = "starting"
self.start_date = datetime.utcnow()
self.run()
return self
def stop(self):
def stop(self) -> "FakeReplicationTask":
if self.status != "running":
raise InvalidResourceStateFault("Replication task is not running")
@ -174,11 +181,11 @@ class FakeReplicationTask(BaseModel):
self.stop_date = datetime.utcnow()
return self
def delete(self):
def delete(self) -> "FakeReplicationTask":
self.status = "deleting"
return self
def run(self):
def run(self) -> "FakeReplicationTask":
self.status = "running"
return self

View File

@ -1,17 +1,17 @@
from moto.core.responses import BaseResponse
from .models import dms_backends
from .models import dms_backends, DatabaseMigrationServiceBackend
import json
class DatabaseMigrationServiceResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="dms")
@property
def dms_backend(self):
def dms_backend(self) -> DatabaseMigrationServiceBackend:
return dms_backends[self.current_account][self.region]
def create_replication_task(self):
def create_replication_task(self) -> str:
replication_task_identifier = self._get_param("ReplicationTaskIdentifier")
source_endpoint_arn = self._get_param("SourceEndpointArn")
target_endpoint_arn = self._get_param("TargetEndpointArn")
@ -31,7 +31,7 @@ class DatabaseMigrationServiceResponse(BaseResponse):
return json.dumps({"ReplicationTask": replication_task.to_dict()})
def start_replication_task(self):
def start_replication_task(self) -> str:
replication_task_arn = self._get_param("ReplicationTaskArn")
replication_task = self.dms_backend.start_replication_task(
replication_task_arn=replication_task_arn
@ -39,7 +39,7 @@ class DatabaseMigrationServiceResponse(BaseResponse):
return json.dumps({"ReplicationTask": replication_task.to_dict()})
def stop_replication_task(self):
def stop_replication_task(self) -> str:
replication_task_arn = self._get_param("ReplicationTaskArn")
replication_task = self.dms_backend.stop_replication_task(
replication_task_arn=replication_task_arn
@ -47,7 +47,7 @@ class DatabaseMigrationServiceResponse(BaseResponse):
return json.dumps({"ReplicationTask": replication_task.to_dict()})
def delete_replication_task(self):
def delete_replication_task(self) -> str:
replication_task_arn = self._get_param("ReplicationTaskArn")
replication_task = self.dms_backend.delete_replication_task(
replication_task_arn=replication_task_arn
@ -55,7 +55,7 @@ class DatabaseMigrationServiceResponse(BaseResponse):
return json.dumps({"ReplicationTask": replication_task.to_dict()})
def describe_replication_tasks(self):
def describe_replication_tasks(self) -> str:
filters = self._get_list_prefix("Filters.member")
max_records = self._get_int_param("MaxRecords")
replication_tasks = self.dms_backend.describe_replication_tasks(

View File

@ -1,23 +1,28 @@
def match_task_arn(task, arns):
from typing import Any, Dict, List, Iterable
def match_task_arn(task: Dict[str, Any], arns: List[str]) -> bool:
return task["ReplicationTaskArn"] in arns
def match_task_id(task, ids):
def match_task_id(task: Dict[str, Any], ids: List[str]) -> bool:
return task["ReplicationTaskIdentifier"] in ids
def match_task_migration_type(task, migration_types):
def match_task_migration_type(task: Dict[str, Any], migration_types: List[str]) -> bool:
return task["MigrationType"] in migration_types
def match_task_endpoint_arn(task, endpoint_arns):
def match_task_endpoint_arn(task: Dict[str, Any], endpoint_arns: List[str]) -> bool:
return (
task["SourceEndpointArn"] in endpoint_arns
or task["TargetEndpointArn"] in endpoint_arns
)
def match_task_replication_instance_arn(task, replication_instance_arns):
def match_task_replication_instance_arn(
task: Dict[str, Any], replication_instance_arns: List[str]
) -> bool:
return task["ReplicationInstanceArn"] in replication_instance_arns
@ -30,17 +35,18 @@ task_filter_functions = {
}
def filter_tasks(tasks, filters):
def filter_tasks(tasks: Iterable[Any], filters: List[Dict[str, Any]]) -> Any:
matching_tasks = tasks
for f in filters:
filter_function = task_filter_functions[f["Name"]]
filter_function = task_filter_functions.get(f["Name"])
if not filter_function:
continue
# https://github.com/python/mypy/issues/12682
matching_tasks = filter(
lambda task: filter_function(task, f["Values"]), matching_tasks
lambda task: filter_function(task, f["Values"]), matching_tasks # type: ignore[arg-type]
)
return matching_tasks

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/moto_api
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/dax,moto/dms,moto/moto_api
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract