Techdebt: MyPy F-services (#6021)

Bert Blommers 2023-03-06 21:41:24 -01:00 committed by GitHub
parent 26904fdb36
commit 33ce02056d
7 changed files with 158 additions and 127 deletions

moto/firehose/exceptions.py · View File

@@ -7,7 +7,7 @@ class ConcurrentModificationException(JsonRESTError):
     code = 400

-    def __init__(self, message):
+    def __init__(self, message: str):
         super().__init__("ConcurrentModificationException", message)

@@ -16,7 +16,7 @@ class InvalidArgumentException(JsonRESTError):
     code = 400

-    def __init__(self, message):
+    def __init__(self, message: str):
         super().__init__("InvalidArgumentException", message)

@@ -25,7 +25,7 @@ class LimitExceededException(JsonRESTError):
     code = 400

-    def __init__(self, message):
+    def __init__(self, message: str):
         super().__init__("LimitExceededException", message)

@@ -34,7 +34,7 @@ class ResourceInUseException(JsonRESTError):
     code = 400

-    def __init__(self, message):
+    def __init__(self, message: str):
         super().__init__("ResourceInUseException", message)

@@ -43,7 +43,7 @@ class ResourceNotFoundException(JsonRESTError):
     code = 400

-    def __init__(self, message):
+    def __init__(self, message: str):
         super().__init__("ResourceNotFoundException", message)

@@ -52,5 +52,5 @@ class ValidationException(JsonRESTError):
     code = 400

-    def __init__(self, message):
+    def __init__(self, message: str):
         super().__init__("ValidationException", message)

moto/firehose/models.py · View File

@@ -18,6 +18,7 @@ Incomplete list of unfinished items:
 from base64 import b64decode, b64encode
 from datetime import datetime, timezone
 from gzip import GzipFile
+from typing import Any, Dict, List, Optional, Tuple
 import io
 import json
 from time import time

@@ -50,8 +51,8 @@ DESTINATION_TYPES_TO_NAMES = {
 }

-def find_destination_config_in_args(api_args):
-    """Return (config_arg, config_name) tuple for destination config.
+def find_destination_config_in_args(api_args: Dict[str, Any]) -> Tuple[str, Any]:
+    """Return (config_name, config) tuple for destination config.

     Determines which destination config(s) have been specified. The
     alternative is to use a bunch of 'if' statements to check each
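
The docstring describes the trick: the backend hands this helper its locals(), so one loop over a name-mapping table replaces a chain of 'if' checks on each *_destination_configuration argument. A hedged, simplified sketch; the table contents and error handling here are illustrative, not moto's exact code:

    from typing import Any, Dict, Tuple

    # Illustrative subset of the DESTINATION_TYPES_TO_NAMES mapping.
    DESTINATION_TYPES_TO_NAMES = {
        "s3_destination_configuration": "S3",
        "extended_s3_destination_configuration": "ExtendedS3",
    }

    def find_destination_config_in_args(api_args: Dict[str, Any]) -> Tuple[str, Any]:
        """Return the (config_name, config) pair for the one config provided."""
        configs = [
            (name, api_args[arg])
            for arg, name in DESTINATION_TYPES_TO_NAMES.items()
            if api_args.get(arg)
        ]
        if len(configs) != 1:
            raise ValueError("Exactly one destination configuration is expected")
        return configs[0]

    # E.g. as called with locals() inside create_delivery_stream():
    assert find_destination_config_in_args(
        {"extended_s3_destination_configuration": {"BucketARN": "arn:aws:s3:::b"}}
    ) == ("ExtendedS3", {"BucketARN": "arn:aws:s3:::b"})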
@@ -83,7 +84,9 @@ def find_destination_config_in_args(api_args):
     return configs[0]

-def create_s3_destination_config(extended_s3_destination_config):
+def create_s3_destination_config(
+    extended_s3_destination_config: Dict[str, Any]
+) -> Dict[str, Any]:
     """Return dict with selected fields copied from ExtendedS3 config.

     When an ExtendedS3 config is chosen, AWS tacks on a S3 config as

@@ -115,13 +118,13 @@ class DeliveryStream(
     def __init__(
         self,
-        account_id,
-        region,
-        delivery_stream_name,
-        delivery_stream_type,
-        kinesis_stream_source_configuration,
-        destination_name,
-        destination_config,
+        account_id: str,
+        region: str,
+        delivery_stream_name: str,
+        delivery_stream_type: str,
+        kinesis_stream_source_configuration: Dict[str, Any],
+        destination_name: str,
+        destination_config: Dict[str, Any],
     ):  # pylint: disable=too-many-arguments
         self.delivery_stream_status = "CREATING"
         self.delivery_stream_name = delivery_stream_name

@@ -130,7 +133,7 @@ class DeliveryStream(
         )
         self.source = kinesis_stream_source_configuration
-        self.destinations = [
+        self.destinations: List[Dict[str, Any]] = [
             {
                 "destination_id": "destinationId-000000000001",
                 destination_name: destination_config,
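
The new List[Dict[str, Any]] annotation pins down the destinations structure: a one-element list whose dict holds a fixed destination id plus the chosen destination config keyed by its name. An illustration with hypothetical values:

    from typing import Any, Dict, List

    destination_name = "ExtendedS3"  # illustrative
    destination_config: Dict[str, Any] = {"BucketARN": "arn:aws:s3:::my-bucket"}

    destinations: List[Dict[str, Any]] = [
        {
            "destination_id": "destinationId-000000000001",
            destination_name: destination_config,
        }
    ]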
@@ -162,33 +165,35 @@ class DeliveryStream(
 class FirehoseBackend(BaseBackend):
     """Implementation of Firehose APIs."""

-    def __init__(self, region_name, account_id):
+    def __init__(self, region_name: str, account_id: str):
         super().__init__(region_name, account_id)
-        self.delivery_streams = {}
+        self.delivery_streams: Dict[str, DeliveryStream] = {}
         self.tagger = TaggingService()

     @staticmethod
-    def default_vpc_endpoint_service(service_region, zones):
+    def default_vpc_endpoint_service(
+        service_region: str, zones: List[str]
+    ) -> List[Dict[str, str]]:
         """Default VPC endpoint service."""
         return BaseBackend.default_vpc_endpoint_service_factory(
             service_region, zones, "firehose", special_service_name="kinesis-firehose"
         )

-    def create_delivery_stream(
+    def create_delivery_stream(  # pylint: disable=unused-argument
         self,
-        region,
-        delivery_stream_name,
-        delivery_stream_type,
-        kinesis_stream_source_configuration,
-        delivery_stream_encryption_configuration_input,
-        s3_destination_configuration,
-        extended_s3_destination_configuration,
-        redshift_destination_configuration,
-        elasticsearch_destination_configuration,
-        splunk_destination_configuration,
-        http_endpoint_destination_configuration,
-        tags,
-    ):  # pylint: disable=too-many-arguments,too-many-locals,unused-argument
+        region: str,
+        delivery_stream_name: str,
+        delivery_stream_type: str,
+        kinesis_stream_source_configuration: Dict[str, Any],
+        delivery_stream_encryption_configuration_input: Dict[str, Any],
+        s3_destination_configuration: Dict[str, Any],
+        extended_s3_destination_configuration: Dict[str, Any],
+        redshift_destination_configuration: Dict[str, Any],
+        elasticsearch_destination_configuration: Dict[str, Any],
+        splunk_destination_configuration: Dict[str, Any],
+        http_endpoint_destination_configuration: Dict[str, Any],
+        tags: List[Dict[str, str]],
+    ) -> str:
         """Create a Kinesis Data Firehose delivery stream."""
         (destination_name, destination_config) = find_destination_config_in_args(
             locals()

@@ -255,9 +260,7 @@ class FirehoseBackend(BaseBackend):
         self.delivery_streams[delivery_stream_name] = delivery_stream
         return self.delivery_streams[delivery_stream_name].delivery_stream_arn

-    def delete_delivery_stream(
-        self, delivery_stream_name, allow_force_delete=False
-    ):  # pylint: disable=unused-argument
+    def delete_delivery_stream(self, delivery_stream_name: str) -> None:
         """Delete a delivery stream and its data.

         AllowForceDelete option is ignored as we only superficially

@@ -275,9 +278,7 @@ class FirehoseBackend(BaseBackend):
         delivery_stream.delivery_stream_status = "DELETING"
         self.delivery_streams.pop(delivery_stream_name)

-    def describe_delivery_stream(
-        self, delivery_stream_name, limit, exclusive_start_destination_id
-    ):  # pylint: disable=unused-argument
+    def describe_delivery_stream(self, delivery_stream_name: str) -> Dict[str, Any]:
         """Return description of specified delivery stream and its status.

         Note: the 'limit' and 'exclusive_start_destination_id' parameters

@@ -290,7 +291,9 @@ class FirehoseBackend(BaseBackend):
                 f"not found."
             )

-        result = {"DeliveryStreamDescription": {"HasMoreDestinations": False}}
+        result: Dict[str, Any] = {
+            "DeliveryStreamDescription": {"HasMoreDestinations": False}
+        }
         for attribute, attribute_value in vars(delivery_stream).items():
             if not attribute_value:
                 continue

@@ -326,8 +329,11 @@ class FirehoseBackend(BaseBackend):
         return result

     def list_delivery_streams(
-        self, limit, delivery_stream_type, exclusive_start_delivery_stream_name
-    ):
+        self,
+        limit: Optional[int],
+        delivery_stream_type: str,
+        exclusive_start_delivery_stream_name: str,
+    ) -> Dict[str, Any]:
         """Return list of delivery streams in alphabetic order of names."""
         result = {"DeliveryStreamNames": [], "HasMoreDeliveryStreams": False}
         if not self.delivery_streams:

@@ -335,7 +341,7 @@ class FirehoseBackend(BaseBackend):
         # If delivery_stream_type is specified, filter out any stream that's
         # not of that type.
-        stream_list = self.delivery_streams.keys()
+        stream_list = list(self.delivery_streams.keys())
         if delivery_stream_type:
             stream_list = [
                 x

@@ -363,8 +369,11 @@ class FirehoseBackend(BaseBackend):
         return result

     def list_tags_for_delivery_stream(
-        self, delivery_stream_name, exclusive_start_tag_key, limit
-    ):
+        self,
+        delivery_stream_name: str,
+        exclusive_start_tag_key: str,
+        limit: Optional[int],
+    ) -> Dict[str, Any]:
         """Return list of tags."""
         result = {"Tags": [], "HasMoreTags": False}
         delivery_stream = self.delivery_streams.get(delivery_stream_name)

@@ -391,7 +400,9 @@ class FirehoseBackend(BaseBackend):
             result["HasMoreTags"] = True
         return result

-    def put_record(self, delivery_stream_name, record):
+    def put_record(
+        self, delivery_stream_name: str, record: Dict[str, bytes]
+    ) -> Dict[str, Any]:
         """Write a single data record into a Kinesis Data firehose stream."""
         result = self.put_record_batch(delivery_stream_name, [record])
         return {
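
The new annotations make the delegation explicit: a record is a Dict[str, bytes] (e.g. {"Data": b"..."}), and PutRecord is just PutRecordBatch with a one-item list. A standalone sketch of that shape, outside moto, with illustrative response fields:

    from typing import Any, Dict, List
    from uuid import uuid4

    def put_record_batch(records: List[Dict[str, bytes]]) -> Dict[str, Any]:
        responses = [{"RecordId": str(uuid4())} for _ in records]
        return {"FailedPutCount": 0, "RequestResponses": responses}

    def put_record(record: Dict[str, bytes]) -> Dict[str, Any]:
        result = put_record_batch([record])  # a batch of one
        return {"RecordId": result["RequestResponses"][0]["RecordId"]}

    print(put_record({"Data": b"hello"}))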
@@ -400,7 +411,7 @@ class FirehoseBackend(BaseBackend):
         }

     @staticmethod
-    def put_http_records(http_destination, records):
+    def put_http_records(http_destination: Dict[str, Any], records: List[Dict[str, bytes]]) -> List[Dict[str, str]]:  # type: ignore[misc]
         """Put records to a HTTP destination."""
         # Mostly copied from localstack
         url = http_destination["EndpointConfiguration"]["Url"]

@@ -420,7 +431,9 @@ class FirehoseBackend(BaseBackend):
         return [{"RecordId": str(mock_random.uuid4())} for _ in range(len(records))]

     @staticmethod
-    def _format_s3_object_path(delivery_stream_name, version_id, prefix):
+    def _format_s3_object_path(
+        delivery_stream_name: str, version_id: str, prefix: str
+    ) -> str:
         """Return a S3 object path in the expected format."""
         # Taken from LocalStack's firehose logic, with minor changes.
         # See https://docs.aws.amazon.com/firehose/latest/dev/basic-deliver.html#s3-object-name
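
Per the AWS document linked above, delivered objects are named prefix + YYYY/MM/DD/HH + stream-version-timestamp-uuid. A hedged reimplementation of that naming scheme; details may differ from moto's:

    from datetime import datetime, timezone
    from uuid import uuid4

    def format_s3_object_path(stream_name: str, version_id: str, prefix: str) -> str:
        """Illustrative <prefix>YYYY/MM/DD/HH/<stream>-<version>-<ts>-<uuid> path."""
        now = datetime.now(timezone.utc)
        return (
            f"{prefix}{now.strftime('%Y/%m/%d/%H')}/"
            f"{stream_name}-{version_id}-"
            f"{now.strftime('%Y-%m-%d-%H-%M-%S')}-{uuid4()}"
        )

    # e.g. 'logs/2023/03/06/21/mystream-1-2023-03-06-21-41-24-<uuid>'
    print(format_s3_object_path("mystream", "1", "logs/"))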
@@ -435,7 +448,13 @@ class FirehoseBackend(BaseBackend):
             f"{now.strftime('%Y-%m-%d-%H-%M-%S')}-{str(mock_random.uuid4())}"
         )

-    def put_s3_records(self, delivery_stream_name, version_id, s3_destination, records):
+    def put_s3_records(
+        self,
+        delivery_stream_name: str,
+        version_id: str,
+        s3_destination: Dict[str, Any],
+        records: List[Dict[str, bytes]],
+    ) -> List[Dict[str, str]]:
         """Put records to a ExtendedS3 or S3 destination."""
         # Taken from LocalStack's firehose logic, with minor changes.
         bucket_name = s3_destination["BucketARN"].split(":")[-1]

@@ -456,7 +475,9 @@ class FirehoseBackend(BaseBackend):
             ) from exc
         return [{"RecordId": str(mock_random.uuid4())} for _ in range(len(records))]

-    def put_record_batch(self, delivery_stream_name, records):
+    def put_record_batch(
+        self, delivery_stream_name: str, records: List[Dict[str, bytes]]
+    ) -> Dict[str, Any]:
         """Write multiple data records into a Kinesis Data firehose stream."""
         delivery_stream = self.delivery_streams.get(delivery_stream_name)
         if not delivery_stream:

@@ -502,7 +523,9 @@ class FirehoseBackend(BaseBackend):
             "RequestResponses": request_responses,
         }

-    def tag_delivery_stream(self, delivery_stream_name, tags):
+    def tag_delivery_stream(
+        self, delivery_stream_name: str, tags: List[Dict[str, str]]
+    ) -> None:
         """Add/update tags for specified delivery stream."""
         delivery_stream = self.delivery_streams.get(delivery_stream_name)
         if not delivery_stream:

@@ -524,7 +547,9 @@ class FirehoseBackend(BaseBackend):
         self.tagger.tag_resource(delivery_stream.delivery_stream_arn, tags)

-    def untag_delivery_stream(self, delivery_stream_name, tag_keys):
+    def untag_delivery_stream(
+        self, delivery_stream_name: str, tag_keys: List[str]
+    ) -> None:
         """Removes tags from specified delivery stream."""
         delivery_stream = self.delivery_streams.get(delivery_stream_name)
         if not delivery_stream:

@@ -537,19 +562,19 @@ class FirehoseBackend(BaseBackend):
             delivery_stream.delivery_stream_arn, tag_keys
         )

-    def update_destination(
+    def update_destination(  # pylint: disable=unused-argument
         self,
-        delivery_stream_name,
-        current_delivery_stream_version_id,
-        destination_id,
-        s3_destination_update,
-        extended_s3_destination_update,
-        s3_backup_mode,
-        redshift_destination_update,
-        elasticsearch_destination_update,
-        splunk_destination_update,
-        http_endpoint_destination_update,
-    ):  # pylint: disable=unused-argument,too-many-arguments,too-many-locals
+        delivery_stream_name: str,
+        current_delivery_stream_version_id: str,
+        destination_id: str,
+        s3_destination_update: Dict[str, Any],
+        extended_s3_destination_update: Dict[str, Any],
+        s3_backup_mode: str,
+        redshift_destination_update: Dict[str, Any],
+        elasticsearch_destination_update: Dict[str, Any],
+        splunk_destination_update: Dict[str, Any],
+        http_endpoint_destination_update: Dict[str, Any],
+    ) -> None:
         """Updates specified destination of specified delivery stream."""
         (destination_name, destination_config) = find_destination_config_in_args(
             locals()

@@ -628,18 +653,18 @@ class FirehoseBackend(BaseBackend):
         # S3 backup if it is disabled. If backup is enabled, you can't update
         # the delivery stream to disable it."

-    def lookup_name_from_arn(self, arn):
+    def lookup_name_from_arn(self, arn: str) -> Optional[DeliveryStream]:
         """Given an ARN, return the associated delivery stream name."""
         return self.delivery_streams.get(arn.split("/")[-1])

     def send_log_event(
         self,
-        delivery_stream_arn,
-        filter_name,
-        log_group_name,
-        log_stream_name,
-        log_events,
-    ):  # pylint: disable=too-many-arguments
+        delivery_stream_arn: str,
+        filter_name: str,
+        log_group_name: str,
+        log_stream_name: str,
+        log_events: List[Dict[str, Any]],
+    ) -> None:
         """Send log events to a S3 bucket after encoding and gzipping it."""
         data = {
             "logEvents": log_events,

@@ -653,9 +678,9 @@ class FirehoseBackend(BaseBackend):
         output = io.BytesIO()
         with GzipFile(fileobj=output, mode="w") as fhandle:
             fhandle.write(json.dumps(data, separators=(",", ":")).encode("utf-8"))
-        gzipped_payload = b64encode(output.getvalue()).decode("utf-8")
-        delivery_stream = self.lookup_name_from_arn(delivery_stream_arn)
+        gzipped_payload = b64encode(output.getvalue())
+        delivery_stream: DeliveryStream = self.lookup_name_from_arn(delivery_stream_arn)  # type: ignore[assignment]
         self.put_s3_records(
             delivery_stream.delivery_stream_name,
             delivery_stream.version_id,
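
send_log_event() gzips the JSON payload and base64-encodes the bytes; the fix above keeps the result as bytes rather than decoding it to str. A self-contained round-trip of that encoding, independent of moto:

    import io
    import json
    from base64 import b64decode, b64encode
    from gzip import GzipFile, decompress

    data = {"logEvents": [{"id": "1", "timestamp": 0, "message": "hello"}]}

    output = io.BytesIO()
    with GzipFile(fileobj=output, mode="w") as fhandle:
        fhandle.write(json.dumps(data, separators=(",", ":")).encode("utf-8"))
    gzipped_payload = b64encode(output.getvalue())  # bytes, as in the fix above

    # Decoding reverses both layers and recovers the original structure.
    assert json.loads(decompress(b64decode(gzipped_payload))) == data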

moto/firehose/responses.py · View File

@@ -2,21 +2,21 @@
 import json

 from moto.core.responses import BaseResponse
-from .models import firehose_backends
+from .models import firehose_backends, FirehoseBackend

 class FirehoseResponse(BaseResponse):
     """Handler for Firehose requests and responses."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(service_name="firehose")

     @property
-    def firehose_backend(self):
+    def firehose_backend(self) -> FirehoseBackend:
         """Return backend instance specific to this region."""
         return firehose_backends[self.current_account][self.region]

-    def create_delivery_stream(self):
+    def create_delivery_stream(self) -> str:
         """Prepare arguments and respond to CreateDeliveryStream request."""
         delivery_stream_arn = self.firehose_backend.create_delivery_stream(
             self.region,

@@ -34,23 +34,21 @@ class FirehoseResponse(BaseResponse):
         )
         return json.dumps({"DeliveryStreamARN": delivery_stream_arn})

-    def delete_delivery_stream(self):
+    def delete_delivery_stream(self) -> str:
         """Prepare arguments and respond to DeleteDeliveryStream request."""
         self.firehose_backend.delete_delivery_stream(
-            self._get_param("DeliveryStreamName"), self._get_param("AllowForceDelete")
+            self._get_param("DeliveryStreamName")
         )
         return json.dumps({})

-    def describe_delivery_stream(self):
+    def describe_delivery_stream(self) -> str:
         """Prepare arguments and respond to DescribeDeliveryStream request."""
         result = self.firehose_backend.describe_delivery_stream(
-            self._get_param("DeliveryStreamName"),
-            self._get_param("Limit"),
-            self._get_param("ExclusiveStartDestinationId"),
+            self._get_param("DeliveryStreamName")
         )
         return json.dumps(result)

-    def list_delivery_streams(self):
+    def list_delivery_streams(self) -> str:
         """Prepare arguments and respond to ListDeliveryStreams request."""
         stream_list = self.firehose_backend.list_delivery_streams(
             self._get_param("Limit"),

@@ -59,7 +57,7 @@ class FirehoseResponse(BaseResponse):
         )
         return json.dumps(stream_list)

-    def list_tags_for_delivery_stream(self):
+    def list_tags_for_delivery_stream(self) -> str:
         """Prepare arguments and respond to ListTagsForDeliveryStream()."""
         result = self.firehose_backend.list_tags_for_delivery_stream(
             self._get_param("DeliveryStreamName"),

@@ -68,35 +66,35 @@ class FirehoseResponse(BaseResponse):
         )
         return json.dumps(result)

-    def put_record(self):
+    def put_record(self) -> str:
         """Prepare arguments and response to PutRecord()."""
         result = self.firehose_backend.put_record(
             self._get_param("DeliveryStreamName"), self._get_param("Record")
         )
         return json.dumps(result)

-    def put_record_batch(self):
+    def put_record_batch(self) -> str:
         """Prepare arguments and response to PutRecordBatch()."""
         result = self.firehose_backend.put_record_batch(
             self._get_param("DeliveryStreamName"), self._get_param("Records")
         )
         return json.dumps(result)

-    def tag_delivery_stream(self):
+    def tag_delivery_stream(self) -> str:
         """Prepare arguments and respond to TagDeliveryStream request."""
         self.firehose_backend.tag_delivery_stream(
             self._get_param("DeliveryStreamName"), self._get_param("Tags")
         )
         return json.dumps({})

-    def untag_delivery_stream(self):
+    def untag_delivery_stream(self) -> str:
         """Prepare arguments and respond to UntagDeliveryStream()."""
         self.firehose_backend.untag_delivery_stream(
             self._get_param("DeliveryStreamName"), self._get_param("TagKeys")
         )
         return json.dumps({})

-    def update_destination(self):
+    def update_destination(self) -> str:
         """Prepare arguments and respond to UpdateDestination()."""
         self.firehose_backend.update_destination(
             self._get_param("DeliveryStreamName"),

moto/forecast/models.py · View File

@@ -1,5 +1,6 @@
 import re
 from datetime import datetime
+from typing import Dict, List, Optional

 from moto.core import BaseBackend, BackendDict
 from moto.core.utils import iso_8601_datetime_without_milliseconds

@@ -26,12 +27,12 @@ class DatasetGroup:
     def __init__(
         self,
-        account_id,
-        region_name,
-        dataset_arns,
-        dataset_group_name,
-        domain,
-        tags=None,
+        account_id: str,
+        region_name: str,
+        dataset_arns: List[str],
+        dataset_group_name: str,
+        domain: str,
+        tags: Optional[List[Dict[str, str]]] = None,
     ):
         self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now())
         self.modified_date = self.creation_date

@@ -43,11 +44,11 @@ class DatasetGroup:
         self.tags = tags
         self._validate()

-    def update(self, dataset_arns):
+    def update(self, dataset_arns: List[str]) -> None:
         self.dataset_arns = dataset_arns
         self.last_modified_date = iso_8601_datetime_without_milliseconds(datetime.now())

-    def _validate(self):
+    def _validate(self) -> None:
         errors = []

         errors.extend(self._validate_dataset_group_name())

@@ -62,7 +63,7 @@ class DatasetGroup:
             message += "; ".join(errors)
             raise ValidationException(message)

-    def _validate_dataset_group_name(self):
+    def _validate_dataset_group_name(self) -> List[str]:
         errors = []
         if not re.match(
             self.accepted_dataset_group_name_format, self.dataset_group_name

@@ -75,7 +76,7 @@ class DatasetGroup:
             )
         return errors

-    def _validate_dataset_group_name_len(self):
+    def _validate_dataset_group_name_len(self) -> List[str]:
         errors = []
         if len(self.dataset_group_name) >= 64:
             errors.append(

@@ -85,7 +86,7 @@ class DatasetGroup:
             )
         return errors

-    def _validate_dataset_group_domain(self):
+    def _validate_dataset_group_domain(self) -> List[str]:
         errors = []
         if self.domain not in self.accepted_dataset_types:
             errors.append(
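
The _validate_* helpers each return a list of error strings, which is what the new -> List[str] annotations document, and _validate() raises a single ValidationException joining them with '; '. A simplified standalone sketch of that collect-then-raise pattern; the message wording is illustrative:

    from typing import List

    class ValidationException(Exception):
        pass

    def _validate_name_len(name: str) -> List[str]:
        if len(name) >= 64:
            return [f"Value '{name}' has length greater than 63"]
        return []

    def validate(name: str) -> None:
        errors: List[str] = []
        errors.extend(_validate_name_len(name))
        if errors:
            raise ValidationException(
                f"{len(errors)} validation errors detected: " + "; ".join(errors)
            )

    validate("short-name")  # passes silently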
@@ -98,12 +99,18 @@ class DatasetGroup:
 class ForecastBackend(BaseBackend):
-    def __init__(self, region_name, account_id):
+    def __init__(self, region_name: str, account_id: str):
         super().__init__(region_name, account_id)
-        self.dataset_groups = {}
-        self.datasets = {}
+        self.dataset_groups: Dict[str, DatasetGroup] = {}
+        self.datasets: Dict[str, str] = {}

-    def create_dataset_group(self, dataset_group_name, domain, dataset_arns, tags):
+    def create_dataset_group(
+        self,
+        dataset_group_name: str,
+        domain: str,
+        dataset_arns: List[str],
+        tags: List[Dict[str, str]],
+    ) -> DatasetGroup:
         dataset_group = DatasetGroup(
             account_id=self.account_id,
             region_name=self.region_name,

@@ -128,20 +135,21 @@ class ForecastBackend(BaseBackend):
         self.dataset_groups[dataset_group.arn] = dataset_group
         return dataset_group

-    def describe_dataset_group(self, dataset_group_arn):
+    def describe_dataset_group(self, dataset_group_arn: str) -> DatasetGroup:
         try:
-            dataset_group = self.dataset_groups[dataset_group_arn]
+            return self.dataset_groups[dataset_group_arn]
         except KeyError:
             raise ResourceNotFoundException("No resource found " + dataset_group_arn)
-        return dataset_group

-    def delete_dataset_group(self, dataset_group_arn):
+    def delete_dataset_group(self, dataset_group_arn: str) -> None:
         try:
             del self.dataset_groups[dataset_group_arn]
         except KeyError:
             raise ResourceNotFoundException("No resource found " + dataset_group_arn)

-    def update_dataset_group(self, dataset_group_arn, dataset_arns):
+    def update_dataset_group(
+        self, dataset_group_arn: str, dataset_arns: List[str]
+    ) -> None:
         try:
             dsg = self.dataset_groups[dataset_group_arn]
         except KeyError:

@@ -155,7 +163,7 @@ class ForecastBackend(BaseBackend):
         dsg.update(dataset_arns)

-    def list_dataset_groups(self):
+    def list_dataset_groups(self) -> List[DatasetGroup]:
         return [v for (_, v) in self.dataset_groups.items()]

moto/forecast/responses.py · View File

@@ -1,20 +1,21 @@
 import json

+from moto.core.common_types import TYPE_RESPONSE
 from moto.core.responses import BaseResponse
 from moto.utilities.aws_headers import amzn_request_id
-from .models import forecast_backends
+from .models import forecast_backends, ForecastBackend

 class ForecastResponse(BaseResponse):
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__(service_name="forecast")

     @property
-    def forecast_backend(self):
+    def forecast_backend(self) -> ForecastBackend:
         return forecast_backends[self.current_account][self.region]

     @amzn_request_id
-    def create_dataset_group(self):
+    def create_dataset_group(self) -> TYPE_RESPONSE:
         dataset_group_name = self._get_param("DatasetGroupName")
         domain = self._get_param("Domain")
         dataset_arns = self._get_param("DatasetArns")
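
TYPE_RESPONSE, imported above from moto.core.common_types, is a (status, headers, body) tuple alias, which is also why the hunks below swap `return 200, {}, None` for `return 200, {}, ""`: the body slot must be a str. A sketch of the shape; the alias written out here mirrors that tuple for illustration rather than quoting moto's definition:

    from typing import Dict, Tuple

    # Mirrors the (status, headers, body) shape of TYPE_RESPONSE (illustrative).
    TYPE_RESPONSE = Tuple[int, Dict[str, str], str]

    def delete_dataset_group_response() -> TYPE_RESPONSE:
        # An empty-string body type-checks; None would not satisfy the str slot.
        return 200, {}, ""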
@@ -30,7 +31,7 @@ class ForecastResponse(BaseResponse):
         return 200, {}, json.dumps(response)

     @amzn_request_id
-    def describe_dataset_group(self):
+    def describe_dataset_group(self) -> TYPE_RESPONSE:
         dataset_group_arn = self._get_param("DatasetGroupArn")
         dataset_group = self.forecast_backend.describe_dataset_group(

@@ -48,21 +49,20 @@ class ForecastResponse(BaseResponse):
         return 200, {}, json.dumps(response)

     @amzn_request_id
-    def delete_dataset_group(self):
+    def delete_dataset_group(self) -> TYPE_RESPONSE:
         dataset_group_arn = self._get_param("DatasetGroupArn")
         self.forecast_backend.delete_dataset_group(dataset_group_arn)
-        return 200, {}, None
+        return 200, {}, ""

     @amzn_request_id
-    def update_dataset_group(self):
+    def update_dataset_group(self) -> TYPE_RESPONSE:
         dataset_group_arn = self._get_param("DatasetGroupArn")
         dataset_arns = self._get_param("DatasetArns")
         self.forecast_backend.update_dataset_group(dataset_group_arn, dataset_arns)
-        return 200, {}, None
+        return 200, {}, ""

     @amzn_request_id
-    def list_dataset_groups(self):
-        list_all = self.forecast_backend.list_dataset_groups()
+    def list_dataset_groups(self) -> TYPE_RESPONSE:
         list_all = sorted(
             [
                 {

@@ -71,9 +71,9 @@ class ForecastResponse(BaseResponse):
                     "CreationTime": dsg.creation_date,
                     "LastModificationTime": dsg.creation_date,
                 }
-                for dsg in list_all
+                for dsg in self.forecast_backend.list_dataset_groups()
             ],
-            key=lambda x: x["LastModificationTime"],
+            key=lambda x: x["LastModificationTime"],  # type: ignore
             reverse=True,
         )
         response = {"DatasetGroups": list_all}

moto/utilities/tagging_service.py · View File

@@ -86,7 +86,7 @@ class TaggingService:
                 # If both key and value are provided, match both before deletion
                 del current_tags[tag[self.key_name]]

-    def extract_tag_names(self, tags: Dict[str, str]) -> None:
+    def extract_tag_names(self, tags: List[Dict[str, str]]) -> List[str]:
         """Return list of key names in list of 'tags' key/value dicts."""
         results = []
         if len(tags) == 0:
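
The corrected signature documents the real shapes: AWS-style tags arrive as a list of {'Key': ..., 'Value': ...} dicts, and the method returns the key names; the old Dict[str, str] / None annotation was simply wrong. A standalone illustration of those shapes, not moto's exact code:

    from typing import Dict, List

    def extract_tag_names(tags: List[Dict[str, str]]) -> List[str]:
        """Return the key names from a list of {'Key': ..., 'Value': ...} dicts."""
        return [tag["Key"] for tag in tags if "Key" in tag]

    assert extract_tag_names([{"Key": "env", "Value": "test"}]) == ["env"]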

setup.cfg · View File

@@ -229,7 +229,7 @@ disable = W,C,R,E
 enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import

 [mypy]
-files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/moto_api,moto/neptune
+files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/moto_api,moto/neptune
 show_column_numbers=True
 show_error_codes = True
 disable_error_code=abstract
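
Adding moto/f* to the files pattern is what pulls the firehose and forecast packages, the "F-services" of the commit title, into mypy's checked set, so the annotations above are enforced whenever mypy runs against this config. A local spot-check of just these packages might look like the following; the invocation is illustrative and the project's own tooling may wrap it differently:

    mypy moto/firehose moto/forecast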