Techdebt: Mypy DataSync (#5756)

This commit is contained in:
Bert Blommers 2022-12-11 12:47:26 -01:00 committed by GitHub
parent 56335e2d93
commit 3c7bdcc5ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 47 deletions

View File

@ -1,4 +1,5 @@
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from typing import Optional
class DataSyncClientError(JsonRESTError): class DataSyncClientError(JsonRESTError):
@ -6,6 +7,6 @@ class DataSyncClientError(JsonRESTError):
class InvalidRequestException(DataSyncClientError): class InvalidRequestException(DataSyncClientError):
def __init__(self, msg=None): def __init__(self, msg: Optional[str] = None):
self.code = 400 self.code = 400
super().__init__("InvalidRequestException", msg or "The request is not valid.") super().__init__("InvalidRequestException", msg or "The request is not valid.")

View File

@ -1,4 +1,5 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, List, Optional
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import InvalidRequestException from .exceptions import InvalidRequestException
@ -6,7 +7,12 @@ from .exceptions import InvalidRequestException
class Location(BaseModel): class Location(BaseModel):
def __init__( def __init__(
self, location_uri, region_name=None, typ=None, metadata=None, arn_counter=0 self,
location_uri: str,
region_name: str,
typ: str,
metadata: Dict[str, Any],
arn_counter: int = 0,
): ):
self.uri = location_uri self.uri = location_uri
self.region_name = region_name self.region_name = region_name
@ -19,12 +25,12 @@ class Location(BaseModel):
class Task(BaseModel): class Task(BaseModel):
def __init__( def __init__(
self, self,
source_location_arn, source_location_arn: str,
destination_location_arn, destination_location_arn: str,
name, name: str,
region_name, region_name: str,
arn_counter=0, metadata: Dict[str, Any],
metadata=None, arn_counter: int = 0,
): ):
self.source_location_arn = source_location_arn self.source_location_arn = source_location_arn
self.destination_location_arn = destination_location_arn self.destination_location_arn = destination_location_arn
@ -32,7 +38,7 @@ class Task(BaseModel):
self.metadata = metadata self.metadata = metadata
# For simplicity Tasks are either available or running # For simplicity Tasks are either available or running
self.status = "AVAILABLE" self.status = "AVAILABLE"
self.current_task_execution_arn = None self.current_task_execution_arn: Optional[str] = None
# Generate ARN # Generate ARN
self.arn = f"arn:aws:datasync:{region_name}:111222333444:task/task-{str(arn_counter).zfill(17)}" self.arn = f"arn:aws:datasync:{region_name}:111222333444:task/task-{str(arn_counter).zfill(17)}"
@ -57,13 +63,13 @@ class TaskExecution(BaseModel):
TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",) TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
# Also COMPLETED state? # Also COMPLETED state?
def __init__(self, task_arn, arn_counter=0): def __init__(self, task_arn: str, arn_counter: int = 0):
self.task_arn = task_arn self.task_arn = task_arn
self.arn = f"{task_arn}/execution/exec-{str(arn_counter).zfill(17)}" self.arn = f"{task_arn}/execution/exec-{str(arn_counter).zfill(17)}"
self.status = self.TASK_EXECUTION_INTERMEDIATE_STATES[0] self.status = self.TASK_EXECUTION_INTERMEDIATE_STATES[0]
# Simulate a task execution # Simulate a task execution
def iterate_status(self): def iterate_status(self) -> None:
if self.status in self.TASK_EXECUTION_FAILURE_STATES: if self.status in self.TASK_EXECUTION_FAILURE_STATES:
return return
if self.status in self.TASK_EXECUTION_SUCCESS_STATES: if self.status in self.TASK_EXECUTION_SUCCESS_STATES:
@ -78,7 +84,7 @@ class TaskExecution(BaseModel):
return return
raise Exception(f"TaskExecution.iterate_status: Unknown status={self.status}") raise Exception(f"TaskExecution.iterate_status: Unknown status={self.status}")
def cancel(self): def cancel(self) -> None:
if self.status not in self.TASK_EXECUTION_INTERMEDIATE_STATES: if self.status not in self.TASK_EXECUTION_INTERMEDIATE_STATES:
raise InvalidRequestException( raise InvalidRequestException(
f"Sync task cannot be cancelled in its current status: {self.status}" f"Sync task cannot be cancelled in its current status: {self.status}"
@ -87,23 +93,25 @@ class TaskExecution(BaseModel):
class DataSyncBackend(BaseBackend): class DataSyncBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
# Always increase when new things are created # Always increase when new things are created
# This ensures uniqueness # This ensures uniqueness
self.arn_counter = 0 self.arn_counter = 0
self.locations = OrderedDict() self.locations: Dict[str, Location] = OrderedDict()
self.tasks = OrderedDict() self.tasks: Dict[str, Task] = OrderedDict()
self.task_executions = OrderedDict() self.task_executions: Dict[str, TaskExecution] = OrderedDict()
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(service_region: str, zones: List[str]) -> List[Dict[str, Any]]: # type: ignore[misc]
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "datasync" service_region, zones, "datasync"
) )
def create_location(self, location_uri, typ=None, metadata=None): def create_location(
self, location_uri: str, typ: str, metadata: Dict[str, Any]
) -> str:
""" """
# AWS DataSync allows for duplicate LocationUris # AWS DataSync allows for duplicate LocationUris
for arn, location in self.locations.items(): for arn, location in self.locations.items():
@ -123,7 +131,7 @@ class DataSyncBackend(BaseBackend):
self.locations[location.arn] = location self.locations[location.arn] = location
return location.arn return location.arn
def _get_location(self, location_arn, typ): def _get_location(self, location_arn: str, typ: str) -> Location:
if location_arn not in self.locations: if location_arn not in self.locations:
raise InvalidRequestException(f"Location {location_arn} is not found.") raise InvalidRequestException(f"Location {location_arn} is not found.")
location = self.locations[location_arn] location = self.locations[location_arn]
@ -131,15 +139,19 @@ class DataSyncBackend(BaseBackend):
raise InvalidRequestException(f"Invalid Location type: {location.typ}") raise InvalidRequestException(f"Invalid Location type: {location.typ}")
return location return location
def delete_location(self, location_arn): def delete_location(self, location_arn: str) -> None:
if location_arn in self.locations: if location_arn in self.locations:
del self.locations[location_arn] del self.locations[location_arn]
else: else:
raise InvalidRequestException raise InvalidRequestException
def create_task( def create_task(
self, source_location_arn, destination_location_arn, name, metadata=None self,
): source_location_arn: str,
destination_location_arn: str,
name: str,
metadata: Dict[str, Any],
) -> str:
if source_location_arn not in self.locations: if source_location_arn not in self.locations:
raise InvalidRequestException(f"Location {source_location_arn} not found.") raise InvalidRequestException(f"Location {source_location_arn} not found.")
if destination_location_arn not in self.locations: if destination_location_arn not in self.locations:
@ -158,13 +170,13 @@ class DataSyncBackend(BaseBackend):
self.tasks[task.arn] = task self.tasks[task.arn] = task
return task.arn return task.arn
def _get_task(self, task_arn): def _get_task(self, task_arn: str) -> Task:
if task_arn in self.tasks: if task_arn in self.tasks:
return self.tasks[task_arn] return self.tasks[task_arn]
else: else:
raise InvalidRequestException raise InvalidRequestException
def update_task(self, task_arn, name, metadata): def update_task(self, task_arn: str, name: str, metadata: Dict[str, Any]) -> None:
if task_arn in self.tasks: if task_arn in self.tasks:
task = self.tasks[task_arn] task = self.tasks[task_arn]
task.name = name task.name = name
@ -172,13 +184,13 @@ class DataSyncBackend(BaseBackend):
else: else:
raise InvalidRequestException(f"Sync task {task_arn} is not found.") raise InvalidRequestException(f"Sync task {task_arn} is not found.")
def delete_task(self, task_arn): def delete_task(self, task_arn: str) -> None:
if task_arn in self.tasks: if task_arn in self.tasks:
del self.tasks[task_arn] del self.tasks[task_arn]
else: else:
raise InvalidRequestException raise InvalidRequestException
def start_task_execution(self, task_arn): def start_task_execution(self, task_arn: str) -> str:
self.arn_counter = self.arn_counter + 1 self.arn_counter = self.arn_counter + 1
if task_arn in self.tasks: if task_arn in self.tasks:
task = self.tasks[task_arn] task = self.tasks[task_arn]
@ -190,13 +202,13 @@ class DataSyncBackend(BaseBackend):
return task_execution.arn return task_execution.arn
raise InvalidRequestException("Invalid request.") raise InvalidRequestException("Invalid request.")
def _get_task_execution(self, task_execution_arn): def _get_task_execution(self, task_execution_arn: str) -> TaskExecution:
if task_execution_arn in self.task_executions: if task_execution_arn in self.task_executions:
return self.task_executions[task_execution_arn] return self.task_executions[task_execution_arn]
else: else:
raise InvalidRequestException raise InvalidRequestException
def cancel_task_execution(self, task_execution_arn): def cancel_task_execution(self, task_execution_arn: str) -> None:
if task_execution_arn in self.task_executions: if task_execution_arn in self.task_executions:
task_execution = self.task_executions[task_execution_arn] task_execution = self.task_executions[task_execution_arn]
task_execution.cancel() task_execution.cancel()

View File

@ -2,27 +2,27 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import datasync_backends from .models import datasync_backends, DataSyncBackend, Location
class DataSyncResponse(BaseResponse): class DataSyncResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="datasync") super().__init__(service_name="datasync")
@property @property
def datasync_backend(self): def datasync_backend(self) -> DataSyncBackend:
return datasync_backends[self.current_account][self.region] return datasync_backends[self.current_account][self.region]
def list_locations(self): def list_locations(self) -> str:
locations = list() locations = list()
for arn, location in self.datasync_backend.locations.items(): for arn, location in self.datasync_backend.locations.items():
locations.append({"LocationArn": arn, "LocationUri": location.uri}) locations.append({"LocationArn": arn, "LocationUri": location.uri})
return json.dumps({"Locations": locations}) return json.dumps({"Locations": locations})
def _get_location(self, location_arn, typ): def _get_location(self, location_arn: str, typ: str) -> Location:
return self.datasync_backend._get_location(location_arn, typ) return self.datasync_backend._get_location(location_arn, typ)
def create_location_s3(self): def create_location_s3(self) -> str:
# s3://bucket_name/folder/ # s3://bucket_name/folder/
s3_bucket_arn = self._get_param("S3BucketArn") s3_bucket_arn = self._get_param("S3BucketArn")
subdirectory = self._get_param("Subdirectory") subdirectory = self._get_param("Subdirectory")
@ -36,7 +36,7 @@ class DataSyncResponse(BaseResponse):
) )
return json.dumps({"LocationArn": arn}) return json.dumps({"LocationArn": arn})
def describe_location_s3(self): def describe_location_s3(self) -> str:
location_arn = self._get_param("LocationArn") location_arn = self._get_param("LocationArn")
location = self._get_location(location_arn, typ="S3") location = self._get_location(location_arn, typ="S3")
return json.dumps( return json.dumps(
@ -47,7 +47,7 @@ class DataSyncResponse(BaseResponse):
} }
) )
def create_location_smb(self): def create_location_smb(self) -> str:
# smb://smb.share.fqdn/AWS_Test/ # smb://smb.share.fqdn/AWS_Test/
subdirectory = self._get_param("Subdirectory") subdirectory = self._get_param("Subdirectory")
server_hostname = self._get_param("ServerHostname") server_hostname = self._get_param("ServerHostname")
@ -64,7 +64,7 @@ class DataSyncResponse(BaseResponse):
) )
return json.dumps({"LocationArn": arn}) return json.dumps({"LocationArn": arn})
def describe_location_smb(self): def describe_location_smb(self) -> str:
location_arn = self._get_param("LocationArn") location_arn = self._get_param("LocationArn")
location = self._get_location(location_arn, typ="SMB") location = self._get_location(location_arn, typ="SMB")
return json.dumps( return json.dumps(
@ -78,12 +78,12 @@ class DataSyncResponse(BaseResponse):
} }
) )
def delete_location(self): def delete_location(self) -> str:
location_arn = self._get_param("LocationArn") location_arn = self._get_param("LocationArn")
self.datasync_backend.delete_location(location_arn) self.datasync_backend.delete_location(location_arn)
return json.dumps({}) return json.dumps({})
def create_task(self): def create_task(self) -> str:
destination_location_arn = self._get_param("DestinationLocationArn") destination_location_arn = self._get_param("DestinationLocationArn")
source_location_arn = self._get_param("SourceLocationArn") source_location_arn = self._get_param("SourceLocationArn")
name = self._get_param("Name") name = self._get_param("Name")
@ -98,7 +98,7 @@ class DataSyncResponse(BaseResponse):
) )
return json.dumps({"TaskArn": arn}) return json.dumps({"TaskArn": arn})
def update_task(self): def update_task(self) -> str:
task_arn = self._get_param("TaskArn") task_arn = self._get_param("TaskArn")
self.datasync_backend.update_task( self.datasync_backend.update_task(
task_arn, task_arn,
@ -112,18 +112,18 @@ class DataSyncResponse(BaseResponse):
) )
return json.dumps({}) return json.dumps({})
def list_tasks(self): def list_tasks(self) -> str:
tasks = list() tasks = list()
for arn, task in self.datasync_backend.tasks.items(): for arn, task in self.datasync_backend.tasks.items():
tasks.append({"Name": task.name, "Status": task.status, "TaskArn": arn}) tasks.append({"Name": task.name, "Status": task.status, "TaskArn": arn})
return json.dumps({"Tasks": tasks}) return json.dumps({"Tasks": tasks})
def delete_task(self): def delete_task(self) -> str:
task_arn = self._get_param("TaskArn") task_arn = self._get_param("TaskArn")
self.datasync_backend.delete_task(task_arn) self.datasync_backend.delete_task(task_arn)
return json.dumps({}) return json.dumps({})
def describe_task(self): def describe_task(self) -> str:
task_arn = self._get_param("TaskArn") task_arn = self._get_param("TaskArn")
task = self.datasync_backend._get_task(task_arn) task = self.datasync_backend._get_task(task_arn)
return json.dumps( return json.dumps(
@ -140,17 +140,17 @@ class DataSyncResponse(BaseResponse):
} }
) )
def start_task_execution(self): def start_task_execution(self) -> str:
task_arn = self._get_param("TaskArn") task_arn = self._get_param("TaskArn")
arn = self.datasync_backend.start_task_execution(task_arn) arn = self.datasync_backend.start_task_execution(task_arn)
return json.dumps({"TaskExecutionArn": arn}) return json.dumps({"TaskExecutionArn": arn})
def cancel_task_execution(self): def cancel_task_execution(self) -> str:
task_execution_arn = self._get_param("TaskExecutionArn") task_execution_arn = self._get_param("TaskExecutionArn")
self.datasync_backend.cancel_task_execution(task_execution_arn) self.datasync_backend.cancel_task_execution(task_execution_arn)
return json.dumps({}) return json.dumps({})
def describe_task_execution(self): def describe_task_execution(self) -> str:
task_execution_arn = self._get_param("TaskExecutionArn") task_execution_arn = self._get_param("TaskExecutionArn")
task_execution = self.datasync_backend._get_task_execution(task_execution_arn) task_execution = self.datasync_backend._get_task_execution(task_execution_arn)
result = json.dumps( result = json.dumps(

View File

@ -18,7 +18,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/moto_api files= moto/a*,moto/b*,moto/c*,moto/databrew,moto/datapipeline,moto/datasync,moto/moto_api
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract