Techdebt: Mypy DataSync (#5756)

This commit is contained in:
Bert Blommers 2022-12-11 12:47:26 -01:00 committed by GitHub
parent 56335e2d93
commit 3c7bdcc5ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 47 deletions

View File

@ -1,4 +1,5 @@
from moto.core.exceptions import JsonRESTError
from typing import Optional
class DataSyncClientError(JsonRESTError):
@ -6,6 +7,6 @@ class DataSyncClientError(JsonRESTError):
class InvalidRequestException(DataSyncClientError):
def __init__(self, msg=None):
def __init__(self, msg: Optional[str] = None):
self.code = 400
super().__init__("InvalidRequestException", msg or "The request is not valid.")

View File

@ -1,4 +1,5 @@
from collections import OrderedDict
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import InvalidRequestException
@ -6,7 +7,12 @@ from .exceptions import InvalidRequestException
class Location(BaseModel):
def __init__(
self, location_uri, region_name=None, typ=None, metadata=None, arn_counter=0
self,
location_uri: str,
region_name: str,
typ: str,
metadata: Dict[str, Any],
arn_counter: int = 0,
):
self.uri = location_uri
self.region_name = region_name
@ -19,12 +25,12 @@ class Location(BaseModel):
class Task(BaseModel):
def __init__(
self,
source_location_arn,
destination_location_arn,
name,
region_name,
arn_counter=0,
metadata=None,
source_location_arn: str,
destination_location_arn: str,
name: str,
region_name: str,
metadata: Dict[str, Any],
arn_counter: int = 0,
):
self.source_location_arn = source_location_arn
self.destination_location_arn = destination_location_arn
@ -32,7 +38,7 @@ class Task(BaseModel):
self.metadata = metadata
# For simplicity Tasks are either available or running
self.status = "AVAILABLE"
self.current_task_execution_arn = None
self.current_task_execution_arn: Optional[str] = None
# Generate ARN
self.arn = f"arn:aws:datasync:{region_name}:111222333444:task/task-{str(arn_counter).zfill(17)}"
@ -57,13 +63,13 @@ class TaskExecution(BaseModel):
TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
# Also COMPLETED state?
def __init__(self, task_arn, arn_counter=0):
def __init__(self, task_arn: str, arn_counter: int = 0):
self.task_arn = task_arn
self.arn = f"{task_arn}/execution/exec-{str(arn_counter).zfill(17)}"
self.status = self.TASK_EXECUTION_INTERMEDIATE_STATES[0]
# Simulate a task execution
def iterate_status(self):
def iterate_status(self) -> None:
if self.status in self.TASK_EXECUTION_FAILURE_STATES:
return
if self.status in self.TASK_EXECUTION_SUCCESS_STATES:
@ -78,7 +84,7 @@ class TaskExecution(BaseModel):
return
raise Exception(f"TaskExecution.iterate_status: Unknown status={self.status}")
def cancel(self):
def cancel(self) -> None:
if self.status not in self.TASK_EXECUTION_INTERMEDIATE_STATES:
raise InvalidRequestException(
f"Sync task cannot be cancelled in its current status: {self.status}"
@ -87,23 +93,25 @@ class TaskExecution(BaseModel):
class DataSyncBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
# Always increase when new things are created
# This ensures uniqueness
self.arn_counter = 0
self.locations = OrderedDict()
self.tasks = OrderedDict()
self.task_executions = OrderedDict()
self.locations: Dict[str, Location] = OrderedDict()
self.tasks: Dict[str, Task] = OrderedDict()
self.task_executions: Dict[str, TaskExecution] = OrderedDict()
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "datasync"
)
def create_location(self, location_uri, typ=None, metadata=None):
def create_location(
self, location_uri: str, typ: str, metadata: Dict[str, Any]
) -> str:
"""
# AWS DataSync allows for duplicate LocationUris
for arn, location in self.locations.items():
@ -123,7 +131,7 @@ class DataSyncBackend(BaseBackend):
self.locations[location.arn] = location
return location.arn
def _get_location(self, location_arn, typ):
def _get_location(self, location_arn: str, typ: str) -> Location:
if location_arn not in self.locations:
raise InvalidRequestException(f"Location {location_arn} is not found.")
location = self.locations[location_arn]
@ -131,15 +139,19 @@ class DataSyncBackend(BaseBackend):
raise InvalidRequestException(f"Invalid Location type: {location.typ}")
return location
def delete_location(self, location_arn):
def delete_location(self, location_arn: str) -> None:
if location_arn in self.locations:
del self.locations[location_arn]
else:
raise InvalidRequestException
def create_task(
self, source_location_arn, destination_location_arn, name, metadata=None
):
self,
source_location_arn: str,
destination_location_arn: str,
name: str,
metadata: Dict[str, Any],
) -> str:
if source_location_arn not in self.locations:
raise InvalidRequestException(f"Location {source_location_arn} not found.")
if destination_location_arn not in self.locations:
@ -158,13 +170,13 @@ class DataSyncBackend(BaseBackend):
self.tasks[task.arn] = task
return task.arn
def _get_task(self, task_arn):
def _get_task(self, task_arn: str) -> Task:
if task_arn in self.tasks:
return self.tasks[task_arn]
else:
raise InvalidRequestException
def update_task(self, task_arn, name, metadata):
def update_task(self, task_arn: str, name: str, metadata: Dict[str, Any]) -> None:
if task_arn in self.tasks:
task = self.tasks[task_arn]
task.name = name
@ -172,13 +184,13 @@ class DataSyncBackend(BaseBackend):
else:
raise InvalidRequestException(f"Sync task {task_arn} is not found.")
def delete_task(self, task_arn):
def delete_task(self, task_arn: str) -> None:
if task_arn in self.tasks:
del self.tasks[task_arn]
else:
raise InvalidRequestException
def start_task_execution(self, task_arn):
def start_task_execution(self, task_arn: str) -> str:
self.arn_counter = self.arn_counter + 1
if task_arn in self.tasks:
task = self.tasks[task_arn]
@ -190,13 +202,13 @@ class DataSyncBackend(BaseBackend):
return task_execution.arn
raise InvalidRequestException("Invalid request.")
def _get_task_execution(self, task_execution_arn):
def _get_task_execution(self, task_execution_arn: str) -> TaskExecution:
if task_execution_arn in self.task_executions:
return self.task_executions[task_execution_arn]
else:
raise InvalidRequestException
def cancel_task_execution(self, task_execution_arn):
def cancel_task_execution(self, task_execution_arn: str) -> None:
if task_execution_arn in self.task_executions:
task_execution = self.task_executions[task_execution_arn]
task_execution.cancel()

View File

@ -2,27 +2,27 @@ import json
from moto.core.responses import BaseResponse
from .models import datasync_backends
from .models import datasync_backends, DataSyncBackend, Location
class DataSyncResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="datasync")
@property
def datasync_backend(self):
def datasync_backend(self) -> DataSyncBackend:
return datasync_backends[self.current_account][self.region]
def list_locations(self):
def list_locations(self) -> str:
locations = list()
for arn, location in self.datasync_backend.locations.items():
locations.append({"LocationArn": arn, "LocationUri": location.uri})
return json.dumps({"Locations": locations})
def _get_location(self, location_arn, typ):
def _get_location(self, location_arn: str, typ: str) -> Location:
return self.datasync_backend._get_location(location_arn, typ)
def create_location_s3(self):
def create_location_s3(self) -> str:
# s3://bucket_name/folder/
s3_bucket_arn = self._get_param("S3BucketArn")
subdirectory = self._get_param("Subdirectory")
@ -36,7 +36,7 @@ class DataSyncResponse(BaseResponse):
)
return json.dumps({"LocationArn": arn})
def describe_location_s3(self):
def describe_location_s3(self) -> str:
location_arn = self._get_param("LocationArn")
location = self._get_location(location_arn, typ="S3")
return json.dumps(
@ -47,7 +47,7 @@ class DataSyncResponse(BaseResponse):
}
)
def create_location_smb(self):
def create_location_smb(self) -> str:
# smb://smb.share.fqdn/AWS_Test/
subdirectory = self._get_param("Subdirectory")
server_hostname = self._get_param("ServerHostname")
@ -64,7 +64,7 @@ class DataSyncResponse(BaseResponse):
)
return json.dumps({"LocationArn": arn})
def describe_location_smb(self):
def describe_location_smb(self) -> str:
location_arn = self._get_param("LocationArn")
location = self._get_location(location_arn, typ="SMB")
return json.dumps(
@ -78,12 +78,12 @@ class DataSyncResponse(BaseResponse):
}
)
def delete_location(self):
def delete_location(self) -> str:
location_arn = self._get_param("LocationArn")
self.datasync_backend.delete_location(location_arn)
return json.dumps({})
def create_task(self):
def create_task(self) -> str:
destination_location_arn = self._get_param("DestinationLocationArn")
source_location_arn = self._get_param("SourceLocationArn")
name = self._get_param("Name")
@ -98,7 +98,7 @@ class DataSyncResponse(BaseResponse):
)
return json.dumps({"TaskArn": arn})
def update_task(self):
def update_task(self) -> str:
task_arn = self._get_param("TaskArn")
self.datasync_backend.update_task(
task_arn,
@ -112,18 +112,18 @@ class DataSyncResponse(BaseResponse):
)
return json.dumps({})
def list_tasks(self):
def list_tasks(self) -> str:
tasks = list()
for arn, task in self.datasync_backend.tasks.items():
tasks.append({"Name": task.name, "Status": task.status, "TaskArn": arn})
return json.dumps({"Tasks": tasks})
def delete_task(self):
def delete_task(self) -> str:
task_arn = self._get_param("TaskArn")
self.datasync_backend.delete_task(task_arn)
return json.dumps({})
def describe_task(self):
def describe_task(self) -> str:
task_arn = self._get_param("TaskArn")
task = self.datasync_backend._get_task(task_arn)
return json.dumps(
@ -140,17 +140,17 @@ class DataSyncResponse(BaseResponse):
}
)
def start_task_execution(self):
def start_task_execution(self) -> str:
task_arn = self._get_param("TaskArn")
arn = self.datasync_backend.start_task_execution(task_arn)
return json.dumps({"TaskExecutionArn": arn})
def cancel_task_execution(self):
def cancel_task_execution(self) -> str:
task_execution_arn = self._get_param("TaskExecutionArn")
self.datasync_backend.cancel_task_execution(task_execution_arn)
return json.dumps({})
def describe_task_execution(self):
def describe_task_execution(self) -> str:
task_execution_arn = self._get_param("TaskExecutionArn")
task_execution = self.datasync_backend._get_task_execution(task_execution_arn)
result = json.dumps(

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/moto_api
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/moto_api
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract