From c3460b8a1aeafb76c8b3ae99f46ed3ffc67c1eef Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Thu, 23 Mar 2023 09:17:40 -0100 Subject: [PATCH] Techdebt: MyPy K (#6111) --- moto/kinesis/exceptions.py | 29 +- moto/kinesis/models.py | 305 ++++++++++++-------- moto/kinesis/responses.py | 224 +++++++------- moto/kinesis/utils.py | 16 +- moto/kinesisvideo/exceptions.py | 4 +- moto/kinesisvideo/models.py | 62 ++-- moto/kinesisvideo/responses.py | 20 +- moto/kinesisvideoarchivedmedia/models.py | 25 +- moto/kinesisvideoarchivedmedia/responses.py | 13 +- moto/kms/exceptions.py | 16 +- moto/kms/models.py | 236 ++++++++------- moto/kms/policy_validator.py | 11 +- moto/kms/responses.py | 252 ++++++++-------- moto/kms/utils.py | 28 +- moto/sts/utils.py | 2 +- setup.cfg | 2 +- 16 files changed, 680 insertions(+), 565 deletions(-) diff --git a/moto/kinesis/exceptions.py b/moto/kinesis/exceptions.py index c1e61ba23..e4c96c23a 100644 --- a/moto/kinesis/exceptions.py +++ b/moto/kinesis/exceptions.py @@ -1,46 +1,47 @@ from moto.core.exceptions import JsonRESTError +from typing import Optional class ResourceNotFoundError(JsonRESTError): - def __init__(self, message): + def __init__(self, message: str): super().__init__(error_type="ResourceNotFoundException", message=message) class ResourceInUseError(JsonRESTError): - def __init__(self, message): + def __init__(self, message: str): super().__init__(error_type="ResourceInUseException", message=message) class StreamNotFoundError(ResourceNotFoundError): - def __init__(self, stream_name, account_id): + def __init__(self, stream_name: str, account_id: str): super().__init__(f"Stream {stream_name} under account {account_id} not found.") class StreamCannotBeUpdatedError(JsonRESTError): - def __init__(self, stream_name, account_id): + def __init__(self, stream_name: str, account_id: str): message = f"Request is invalid. Stream {stream_name} under account {account_id} is in On-Demand mode." super().__init__(error_type="ValidationException", message=message) class ShardNotFoundError(ResourceNotFoundError): - def __init__(self, shard_id, stream, account_id): + def __init__(self, shard_id: str, stream: str, account_id: str): super().__init__( f"Could not find shard {shard_id} in stream {stream} under account {account_id}." ) class ConsumerNotFound(ResourceNotFoundError): - def __init__(self, consumer, account_id): + def __init__(self, consumer: str, account_id: str): super().__init__(f"Consumer {consumer}, account {account_id} not found.") class InvalidArgumentError(JsonRESTError): - def __init__(self, message): + def __init__(self, message: str): super().__init__(error_type="InvalidArgumentException", message=message) class InvalidRetentionPeriod(InvalidArgumentError): - def __init__(self, hours, too_short): + def __init__(self, hours: int, too_short: bool): if too_short: msg = f"Minimum allowed retention period is 24 hours. Requested retention period ({hours} hours) is too short." else: @@ -49,31 +50,31 @@ class InvalidRetentionPeriod(InvalidArgumentError): class InvalidDecreaseRetention(InvalidArgumentError): - def __init__(self, name, requested, existing): + def __init__(self, name: Optional[str], requested: int, existing: int): msg = f"Requested retention period ({requested} hours) for stream {name} can not be longer than existing retention period ({existing} hours). Use IncreaseRetentionPeriod API." super().__init__(msg) class InvalidIncreaseRetention(InvalidArgumentError): - def __init__(self, name, requested, existing): + def __init__(self, name: Optional[str], requested: int, existing: int): msg = f"Requested retention period ({requested} hours) for stream {name} can not be shorter than existing retention period ({existing} hours). Use DecreaseRetentionPeriod API." super().__init__(msg) class ValidationException(JsonRESTError): - def __init__(self, value, position, regex_to_match): + def __init__(self, value: str, position: str, regex_to_match: str): msg = f"1 validation error detected: Value '{value}' at '{position}' failed to satisfy constraint: Member must satisfy regular expression pattern: {regex_to_match}" super().__init__(error_type="ValidationException", message=msg) class RecordSizeExceedsLimit(JsonRESTError): - def __init__(self, position): + def __init__(self, position: int): msg = f"1 validation error detected: Value at 'records.{position}.member.data' failed to satisfy constraint: Member must have length less than or equal to 1048576" super().__init__(error_type="ValidationException", message=msg) class TotalRecordsSizeExceedsLimit(JsonRESTError): - def __init__(self): + def __init__(self) -> None: super().__init__( error_type="InvalidArgumentException", message="Records size exceeds 5 MB limit", @@ -81,6 +82,6 @@ class TotalRecordsSizeExceedsLimit(JsonRESTError): class TooManyRecords(JsonRESTError): - def __init__(self): + def __init__(self) -> None: msg = "1 validation error detected: Value at 'records' failed to satisfy constraint: Member must have length less than or equal to 500" super().__init__(error_type="ValidationException", message=msg) diff --git a/moto/kinesis/models.py b/moto/kinesis/models.py index 8efedbdb9..77689eb62 100644 --- a/moto/kinesis/models.py +++ b/moto/kinesis/models.py @@ -4,7 +4,7 @@ import re import itertools from operator import attrgetter -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Iterable from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core.utils import unix_time @@ -35,14 +35,16 @@ from .utils import ( class Consumer(BaseModel): - def __init__(self, consumer_name, account_id, region_name, stream_arn): + def __init__( + self, consumer_name: str, account_id: str, region_name: str, stream_arn: str + ): self.consumer_name = consumer_name self.created = unix_time() self.stream_arn = stream_arn stream_name = stream_arn.split("/")[-1] self.consumer_arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}/consumer/{consumer_name}" - def to_json(self, include_stream_arn=False): + def to_json(self, include_stream_arn: bool = False) -> Dict[str, Any]: resp = { "ConsumerName": self.consumer_name, "ConsumerARN": self.consumer_arn, @@ -55,7 +57,13 @@ class Consumer(BaseModel): class Record(BaseModel): - def __init__(self, partition_key, data, sequence_number, explicit_hash_key): + def __init__( + self, + partition_key: str, + data: str, + sequence_number: int, + explicit_hash_key: str, + ): self.partition_key = partition_key self.data = data self.sequence_number = sequence_number @@ -63,7 +71,7 @@ class Record(BaseModel): self.created_at_datetime = datetime.datetime.utcnow() self.created_at = unix_time(self.created_at_datetime) - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "Data": self.data, "PartitionKey": self.partition_key, @@ -74,29 +82,36 @@ class Record(BaseModel): class Shard(BaseModel): def __init__( - self, shard_id, starting_hash, ending_hash, parent=None, adjacent_parent=None + self, + shard_id: int, + starting_hash: int, + ending_hash: int, + parent: Optional[str] = None, + adjacent_parent: Optional[str] = None, ): self._shard_id = shard_id self.starting_hash = starting_hash self.ending_hash = ending_hash - self.records = OrderedDict() + self.records: Dict[int, Record] = OrderedDict() self.is_open = True self.parent = parent self.adjacent_parent = adjacent_parent @property - def shard_id(self): + def shard_id(self) -> str: return f"shardId-{str(self._shard_id).zfill(12)}" - def get_records(self, last_sequence_id, limit): - last_sequence_id = int(last_sequence_id) + def get_records( + self, last_sequence_id: str, limit: Optional[int] + ) -> Tuple[List[Record], int, int]: + last_sequence_int = int(last_sequence_id) results = [] - secs_behind_latest = 0 + secs_behind_latest = 0.0 for sequence_number, record in self.records.items(): - if sequence_number > last_sequence_id: + if sequence_number > last_sequence_int: results.append(record) - last_sequence_id = sequence_number + last_sequence_int = sequence_number very_last_record = self.records[next(reversed(self.records))] secs_behind_latest = very_last_record.created_at - record.created_at @@ -105,9 +120,9 @@ class Shard(BaseModel): break millis_behind_latest = int(secs_behind_latest * 1000) - return results, last_sequence_id, millis_behind_latest + return results, last_sequence_int, millis_behind_latest - def put_record(self, partition_key, data, explicit_hash_key): + def put_record(self, partition_key: str, data: str, explicit_hash_key: str) -> str: # Note: this function is not safe for concurrency if self.records: last_sequence_number = self.get_max_sequence_number() @@ -119,17 +134,17 @@ class Shard(BaseModel): ) return str(sequence_number) - def get_min_sequence_number(self): + def get_min_sequence_number(self) -> int: if self.records: return list(self.records.keys())[0] return 0 - def get_max_sequence_number(self): + def get_max_sequence_number(self) -> int: if self.records: return list(self.records.keys())[-1] return 0 - def get_sequence_number_at(self, at_timestamp): + def get_sequence_number_at(self, at_timestamp: float) -> int: if not self.records or at_timestamp < list(self.records.values())[0].created_at: return 0 else: @@ -143,10 +158,10 @@ class Shard(BaseModel): ), None, ) - return r.sequence_number + return r.sequence_number # type: ignore - def to_json(self): - response = { + def to_json(self) -> Dict[str, Any]: + response: Dict[str, Any] = { "HashKeyRange": { "EndingHashKey": str(self.ending_hash), "StartingHashKey": str(self.starting_hash), @@ -170,12 +185,12 @@ class Shard(BaseModel): class Stream(CloudFormationModel): def __init__( self, - stream_name, - shard_count, - stream_mode, - retention_period_hours, - account_id, - region_name, + stream_name: str, + shard_count: int, + stream_mode: Optional[Dict[str, str]], + retention_period_hours: Optional[int], + account_id: str, + region_name: str, ): self.stream_name = stream_name self.creation_datetime = datetime.datetime.now().strftime( @@ -184,27 +199,27 @@ class Stream(CloudFormationModel): self.region = region_name self.account_id = account_id self.arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}" - self.shards = {} - self.tags = {} + self.shards: Dict[str, Shard] = {} + self.tags: Dict[str, str] = {} self.status = "ACTIVE" - self.shard_count = None + self.shard_count: Optional[int] = None self.stream_mode = stream_mode or {"StreamMode": "PROVISIONED"} if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": shard_count = 4 self.init_shards(shard_count) self.retention_period_hours = retention_period_hours or 24 - self.shard_level_metrics = [] + self.shard_level_metrics: List[str] = [] self.encryption_type = "NONE" - self.key_id = None - self.consumers = [] + self.key_id: Optional[str] = None + self.consumers: List[Consumer] = [] - def delete_consumer(self, consumer_arn): + def delete_consumer(self, consumer_arn: str) -> None: self.consumers = [c for c in self.consumers if c.consumer_arn != consumer_arn] - def get_consumer_by_arn(self, consumer_arn): + def get_consumer_by_arn(self, consumer_arn: str) -> Optional[Consumer]: return next((c for c in self.consumers if c.consumer_arn == consumer_arn), None) - def init_shards(self, shard_count): + def init_shards(self, shard_count: int) -> None: self.shard_count = shard_count step = 2**128 // shard_count @@ -216,16 +231,16 @@ class Stream(CloudFormationModel): shard = Shard(index, start, end) self.shards[shard.shard_id] = shard - def split_shard(self, shard_to_split, new_starting_hash_key): - new_starting_hash_key = int(new_starting_hash_key) + def split_shard(self, shard_to_split: str, new_starting_hash_key: str) -> None: + new_starting_hash_int = int(new_starting_hash_key) shard = self.shards[shard_to_split] - if shard.starting_hash < new_starting_hash_key < shard.ending_hash: + if shard.starting_hash < new_starting_hash_int < shard.ending_hash: pass else: raise InvalidArgumentError( - message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}." + message=f"NewStartingHashKey {new_starting_hash_int} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}." ) if not shard.is_open: @@ -241,12 +256,12 @@ class Stream(CloudFormationModel): new_shard_1 = Shard( last_id + 1, starting_hash=shard.starting_hash, - ending_hash=new_starting_hash_key - 1, + ending_hash=new_starting_hash_int - 1, parent=shard.shard_id, ) new_shard_2 = Shard( last_id + 2, - starting_hash=new_starting_hash_key, + starting_hash=new_starting_hash_int, ending_hash=shard.ending_hash, parent=shard.shard_id, ) @@ -261,7 +276,7 @@ class Stream(CloudFormationModel): record = records[index] self.put_record(record.partition_key, record.explicit_hash_key, record.data) - def merge_shards(self, shard_to_merge, adjacent_shard_to_merge): + def merge_shards(self, shard_to_merge: str, adjacent_shard_to_merge: str) -> None: shard1 = self.shards[shard_to_merge] shard2 = self.shards[adjacent_shard_to_merge] @@ -300,8 +315,8 @@ class Stream(CloudFormationModel): record.partition_key, record.data, record.explicit_hash_key ) - def update_shard_count(self, target_shard_count): - if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": + def update_shard_count(self, target_shard_count: int) -> None: + if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": # type: ignore raise StreamCannotBeUpdatedError( stream_name=self.stream_name, account_id=self.account_id ) @@ -351,13 +366,15 @@ class Stream(CloudFormationModel): self.shard_count = target_shard_count - def get_shard(self, shard_id): + def get_shard(self, shard_id: str) -> Shard: if shard_id in self.shards: return self.shards[shard_id] else: raise ShardNotFoundError(shard_id, stream="", account_id=self.account_id) - def get_shard_for_key(self, partition_key, explicit_hash_key): + def get_shard_for_key( + self, partition_key: str, explicit_hash_key: str + ) -> Optional[Shard]: if not isinstance(partition_key, str): raise InvalidArgumentError("partition_key") if len(partition_key) > 256: @@ -367,25 +384,28 @@ class Stream(CloudFormationModel): if not isinstance(explicit_hash_key, str): raise InvalidArgumentError("explicit_hash_key") - key = int(explicit_hash_key) + int_key = int(explicit_hash_key) - if key >= 2**128: + if int_key >= 2**128: raise InvalidArgumentError("explicit_hash_key") else: - key = int(md5_hash(partition_key.encode("utf-8")).hexdigest(), 16) + int_key = int(md5_hash(partition_key.encode("utf-8")).hexdigest(), 16) for shard in self.shards.values(): - if shard.starting_hash <= key < shard.ending_hash: + if shard.starting_hash <= int_key < shard.ending_hash: return shard + return None - def put_record(self, partition_key, explicit_hash_key, data): + def put_record( + self, partition_key: str, explicit_hash_key: str, data: str + ) -> Tuple[str, str]: shard = self.get_shard_for_key(partition_key, explicit_hash_key) - sequence_number = shard.put_record(partition_key, data, explicit_hash_key) - return sequence_number, shard.shard_id + sequence_number = shard.put_record(partition_key, data, explicit_hash_key) # type: ignore + return sequence_number, shard.shard_id # type: ignore - def to_json(self, shard_limit=None): + def to_json(self, shard_limit: Optional[int] = None) -> Dict[str, Any]: all_shards = list(self.shards.values()) requested_shards = all_shards[0 : shard_limit or len(all_shards)] return { @@ -403,7 +423,7 @@ class Stream(CloudFormationModel): } } - def to_json_summary(self): + def to_json_summary(self) -> Dict[str, Any]: return { "StreamDescriptionSummary": { "StreamARN": self.arn, @@ -420,18 +440,23 @@ class Stream(CloudFormationModel): } @staticmethod - def cloudformation_name_type(): + def cloudformation_name_type() -> str: return "Name" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kinesis-stream.html return "AWS::Kinesis::Stream" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Stream": properties = cloudformation_json.get("Properties", {}) shard_count = properties.get("ShardCount", 1) retention_period_hours = properties.get("RetentionPeriodHours", resource_name) @@ -451,14 +476,14 @@ class Stream(CloudFormationModel): return stream @classmethod - def update_from_cloudformation_json( + def update_from_cloudformation_json( # type: ignore[misc] cls, - original_resource, - new_resource_name, - cloudformation_json, - account_id, - region_name, - ): + original_resource: Any, + new_resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> "Stream": properties = cloudformation_json["Properties"] if Stream.is_replacement_update(properties): @@ -466,11 +491,17 @@ class Stream(CloudFormationModel): if resource_name_property not in properties: properties[resource_name_property] = new_resource_name new_resource = cls.create_from_cloudformation_json( - properties[resource_name_property], cloudformation_json, region_name + resource_name=properties[resource_name_property], + cloudformation_json=cloudformation_json, + account_id=account_id, + region_name=region_name, ) properties[resource_name_property] = original_resource.name cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name + resource_name=original_resource.name, + cloudformation_json=cloudformation_json, + account_id=account_id, + region_name=region_name, ) return new_resource @@ -489,14 +520,18 @@ class Stream(CloudFormationModel): return original_resource @classmethod - def delete_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name - ): + def delete_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + ) -> None: backend: KinesisBackend = kinesis_backends[account_id][region_name] backend.delete_stream(stream_arn=None, stream_name=resource_name) @staticmethod - def is_replacement_update(properties): + def is_replacement_update(properties: List[str]) -> bool: properties_requiring_replacement_update = ["BucketName", "ObjectLockEnabled"] return any( [ @@ -506,10 +541,10 @@ class Stream(CloudFormationModel): ) @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Arn"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": @@ -517,25 +552,31 @@ class Stream(CloudFormationModel): raise UnformattedGetAttTemplateException() @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.stream_name class KinesisBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self.streams: Dict[str, Stream] = OrderedDict() @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "kinesis", special_service_name="kinesis-streams" ) def create_stream( - self, stream_name, shard_count, stream_mode=None, retention_period_hours=None - ): + self, + stream_name: str, + shard_count: int, + stream_mode: Optional[Dict[str, str]] = None, + retention_period_hours: Optional[int] = None, + ) -> Stream: if stream_name in self.streams: raise ResourceInUseError(stream_name) stream = Stream( @@ -560,14 +601,14 @@ class KinesisBackend(BaseBackend): return stream if stream_arn: stream_name = stream_arn.split("/")[1] - raise StreamNotFoundError(stream_name, self.account_id) + raise StreamNotFoundError(stream_name, self.account_id) # type: ignore def describe_stream_summary( self, stream_arn: Optional[str], stream_name: Optional[str] ) -> Stream: return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) - def list_streams(self): + def list_streams(self) -> Iterable[Stream]: return self.streams.values() def delete_stream( @@ -582,9 +623,9 @@ class KinesisBackend(BaseBackend): stream_name: Optional[str], shard_id: str, shard_iterator_type: str, - starting_sequence_number: str, - at_timestamp: str, - ): + starting_sequence_number: int, + at_timestamp: datetime.datetime, + ) -> str: # Validate params stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) try: @@ -605,31 +646,31 @@ class KinesisBackend(BaseBackend): def get_records( self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int] - ): + ) -> Tuple[str, List[Record], int]: decomposed = decompose_shard_iterator(shard_iterator) stream_name, shard_id, last_sequence_id = decomposed stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) shard = stream.get_shard(shard_id) - records, last_sequence_id, millis_behind_latest = shard.get_records( + records, last_sequence_id, millis_behind_latest = shard.get_records( # type: ignore last_sequence_id, limit ) next_shard_iterator = compose_shard_iterator( - stream_name, shard, last_sequence_id + stream_name, shard, last_sequence_id # type: ignore ) return next_shard_iterator, records, millis_behind_latest def put_record( self, - stream_arn, - stream_name, - partition_key, - explicit_hash_key, - data, - ): + stream_arn: str, + stream_name: str, + partition_key: str, + explicit_hash_key: str, + data: str, + ) -> Tuple[str, str]: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) sequence_number, shard_id = stream.put_record( @@ -638,10 +679,12 @@ class KinesisBackend(BaseBackend): return sequence_number, shard_id - def put_records(self, stream_arn, stream_name, records): + def put_records( + self, stream_arn: str, stream_name: str, records: List[Dict[str, Any]] + ) -> Dict[str, Any]: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) - response = {"FailedRecordCount": 0, "Records": []} + response: Dict[str, Any] = {"FailedRecordCount": 0, "Records": []} if len(records) > 500: raise TooManyRecords @@ -660,7 +703,7 @@ class KinesisBackend(BaseBackend): data = record.get("Data") sequence_number, shard_id = stream.put_record( - partition_key, explicit_hash_key, data + partition_key, explicit_hash_key, data # type: ignore[arg-type] ) response["Records"].append( {"SequenceNumber": sequence_number, "ShardId": shard_id} @@ -669,8 +712,12 @@ class KinesisBackend(BaseBackend): return response def split_shard( - self, stream_arn, stream_name, shard_to_split, new_starting_hash_key - ): + self, + stream_arn: str, + stream_name: str, + shard_to_split: str, + new_starting_hash_key: str, + ) -> None: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) if not re.match("[a-zA-Z0-9_.-]+", shard_to_split): @@ -695,8 +742,12 @@ class KinesisBackend(BaseBackend): stream.split_shard(shard_to_split, new_starting_hash_key) def merge_shards( - self, stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge - ): + self, + stream_arn: str, + stream_name: str, + shard_to_merge: str, + adjacent_shard_to_merge: str, + ) -> None: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) if shard_to_merge not in stream.shards: @@ -713,7 +764,9 @@ class KinesisBackend(BaseBackend): stream.merge_shards(shard_to_merge, adjacent_shard_to_merge) - def update_shard_count(self, stream_arn, stream_name, target_shard_count): + def update_shard_count( + self, stream_arn: str, stream_name: str, target_shard_count: int + ) -> int: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) current_shard_count = len([s for s in stream.shards.values() if s.is_open]) @@ -722,7 +775,7 @@ class KinesisBackend(BaseBackend): return current_shard_count @paginate(pagination_model=PAGINATION_MODEL) - def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]): + def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]) -> List[Dict[str, Any]]: # type: ignore stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) shards = sorted(stream.shards.values(), key=lambda x: x.shard_id) return [shard.to_json() for shard in shards] @@ -766,12 +819,16 @@ class KinesisBackend(BaseBackend): stream.retention_period_hours = retention_period_hours def list_tags_for_stream( - self, stream_arn, stream_name, exclusive_start_tag_key=None, limit=None - ): + self, + stream_arn: str, + stream_name: str, + exclusive_start_tag_key: Optional[str] = None, + limit: Optional[int] = None, + ) -> Dict[str, Any]: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) - tags = [] - result = {"HasMoreTags": False, "Tags": tags} + tags: List[Dict[str, str]] = [] + result: Dict[str, Any] = {"HasMoreTags": False, "Tags": tags} for key, val in sorted(stream.tags.items(), key=lambda x: x[0]): if limit and len(tags) >= limit: result["HasMoreTags"] = True @@ -805,7 +862,7 @@ class KinesisBackend(BaseBackend): stream_arn: Optional[str], stream_name: Optional[str], shard_level_metrics: List[str], - ) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]: + ) -> Tuple[str, str, List[str], List[str]]: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) current_shard_level_metrics = stream.shard_level_metrics desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics)) @@ -822,7 +879,7 @@ class KinesisBackend(BaseBackend): stream_arn: Optional[str], stream_name: Optional[str], to_be_disabled: List[str], - ) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]: + ) -> Tuple[str, str, List[str], List[str]]: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) current_metrics = stream.shard_level_metrics if "ALL" in to_be_disabled: @@ -834,19 +891,19 @@ class KinesisBackend(BaseBackend): stream.shard_level_metrics = desired_metrics return stream.arn, stream.stream_name, current_metrics, desired_metrics - def _find_stream_by_arn(self, stream_arn): + def _find_stream_by_arn(self, stream_arn: str) -> Stream: # type: ignore[return] for stream in self.streams.values(): if stream.arn == stream_arn: return stream - def list_stream_consumers(self, stream_arn): + def list_stream_consumers(self, stream_arn: str) -> List[Consumer]: """ Pagination is not yet implemented """ stream = self._find_stream_by_arn(stream_arn) return stream.consumers - def register_stream_consumer(self, stream_arn, consumer_name): + def register_stream_consumer(self, stream_arn: str, consumer_name: str) -> Consumer: consumer = Consumer( consumer_name, self.account_id, self.region_name, stream_arn ) @@ -854,7 +911,9 @@ class KinesisBackend(BaseBackend): stream.consumers.append(consumer) return consumer - def describe_stream_consumer(self, stream_arn, consumer_name, consumer_arn): + def describe_stream_consumer( + self, stream_arn: str, consumer_name: str, consumer_arn: str + ) -> Consumer: if stream_arn: stream = self._find_stream_by_arn(stream_arn) for consumer in stream.consumers: @@ -862,14 +921,16 @@ class KinesisBackend(BaseBackend): return consumer if consumer_arn: for stream in self.streams.values(): - consumer = stream.get_consumer_by_arn(consumer_arn) - if consumer: - return consumer + _consumer = stream.get_consumer_by_arn(consumer_arn) + if _consumer: + return _consumer raise ConsumerNotFound( consumer=consumer_name or consumer_arn, account_id=self.account_id ) - def deregister_stream_consumer(self, stream_arn, consumer_name, consumer_arn): + def deregister_stream_consumer( + self, stream_arn: str, consumer_name: str, consumer_arn: str + ) -> None: if stream_arn: stream = self._find_stream_by_arn(stream_arn) stream.consumers = [ @@ -881,17 +942,19 @@ class KinesisBackend(BaseBackend): # It will be a noop for other streams stream.delete_consumer(consumer_arn) - def start_stream_encryption(self, stream_arn, stream_name, encryption_type, key_id): + def start_stream_encryption( + self, stream_arn: str, stream_name: str, encryption_type: str, key_id: str + ) -> None: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream.encryption_type = encryption_type stream.key_id = key_id - def stop_stream_encryption(self, stream_arn, stream_name): + def stop_stream_encryption(self, stream_arn: str, stream_name: str) -> None: stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream.encryption_type = "NONE" stream.key_id = None - def update_stream_mode(self, stream_arn, stream_mode): + def update_stream_mode(self, stream_arn: str, stream_mode: Dict[str, str]) -> None: stream = self._find_stream_by_arn(stream_arn) stream.stream_mode = stream_mode diff --git a/moto/kinesis/responses.py b/moto/kinesis/responses.py index 92f8501c1..167be3f89 100644 --- a/moto/kinesis/responses.py +++ b/moto/kinesis/responses.py @@ -5,45 +5,41 @@ from .models import kinesis_backends, KinesisBackend class KinesisResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="kinesis") - @property - def parameters(self): - return json.loads(self.body) - @property def kinesis_backend(self) -> KinesisBackend: return kinesis_backends[self.current_account][self.region] - def create_stream(self): - stream_name = self.parameters.get("StreamName") - shard_count = self.parameters.get("ShardCount") - stream_mode = self.parameters.get("StreamModeDetails") + def create_stream(self) -> str: + stream_name = self._get_param("StreamName") + shard_count = self._get_param("ShardCount") + stream_mode = self._get_param("StreamModeDetails") self.kinesis_backend.create_stream( stream_name, shard_count, stream_mode=stream_mode ) return "" - def describe_stream(self): - stream_name = self.parameters.get("StreamName") - stream_arn = self.parameters.get("StreamARN") - limit = self.parameters.get("Limit") + def describe_stream(self) -> str: + stream_name = self._get_param("StreamName") + stream_arn = self._get_param("StreamARN") + limit = self._get_param("Limit") stream = self.kinesis_backend.describe_stream(stream_arn, stream_name) return json.dumps(stream.to_json(shard_limit=limit)) - def describe_stream_summary(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") + def describe_stream_summary(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") stream = self.kinesis_backend.describe_stream_summary(stream_arn, stream_name) return json.dumps(stream.to_json_summary()) - def list_streams(self): + def list_streams(self) -> str: streams = self.kinesis_backend.list_streams() stream_names = [stream.stream_name for stream in streams] max_streams = self._get_param("Limit", 10) try: - token = self.parameters.get("ExclusiveStartStreamName") + token = self._get_param("ExclusiveStartStreamName") except ValueError: token = self._get_param("ExclusiveStartStreamName") if token: @@ -59,19 +55,19 @@ class KinesisResponse(BaseResponse): {"HasMoreStreams": has_more_streams, "StreamNames": streams_resp} ) - def delete_stream(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") + def delete_stream(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") self.kinesis_backend.delete_stream(stream_arn, stream_name) return "" - def get_shard_iterator(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - shard_id = self.parameters.get("ShardId") - shard_iterator_type = self.parameters.get("ShardIteratorType") - starting_sequence_number = self.parameters.get("StartingSequenceNumber") - at_timestamp = self.parameters.get("Timestamp") + def get_shard_iterator(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + shard_id = self._get_param("ShardId") + shard_iterator_type = self._get_param("ShardIteratorType") + starting_sequence_number = self._get_param("StartingSequenceNumber") + at_timestamp = self._get_param("Timestamp") shard_iterator = self.kinesis_backend.get_shard_iterator( stream_arn, @@ -84,10 +80,10 @@ class KinesisResponse(BaseResponse): return json.dumps({"ShardIterator": shard_iterator}) - def get_records(self): - stream_arn = self.parameters.get("StreamARN") - shard_iterator = self.parameters.get("ShardIterator") - limit = self.parameters.get("Limit") + def get_records(self) -> str: + stream_arn = self._get_param("StreamARN") + shard_iterator = self._get_param("ShardIterator") + limit = self._get_param("Limit") ( next_shard_iterator, @@ -103,12 +99,12 @@ class KinesisResponse(BaseResponse): } ) - def put_record(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - partition_key = self.parameters.get("PartitionKey") - explicit_hash_key = self.parameters.get("ExplicitHashKey") - data = self.parameters.get("Data") + def put_record(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + partition_key = self._get_param("PartitionKey") + explicit_hash_key = self._get_param("ExplicitHashKey") + data = self._get_param("Data") sequence_number, shard_id = self.kinesis_backend.put_record( stream_arn, @@ -120,40 +116,40 @@ class KinesisResponse(BaseResponse): return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id}) - def put_records(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - records = self.parameters.get("Records") + def put_records(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + records = self._get_param("Records") response = self.kinesis_backend.put_records(stream_arn, stream_name, records) return json.dumps(response) - def split_shard(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - shard_to_split = self.parameters.get("ShardToSplit") - new_starting_hash_key = self.parameters.get("NewStartingHashKey") + def split_shard(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + shard_to_split = self._get_param("ShardToSplit") + new_starting_hash_key = self._get_param("NewStartingHashKey") self.kinesis_backend.split_shard( stream_arn, stream_name, shard_to_split, new_starting_hash_key ) return "" - def merge_shards(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - shard_to_merge = self.parameters.get("ShardToMerge") - adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge") + def merge_shards(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + shard_to_merge = self._get_param("ShardToMerge") + adjacent_shard_to_merge = self._get_param("AdjacentShardToMerge") self.kinesis_backend.merge_shards( stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge ) return "" - def list_shards(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - next_token = self.parameters.get("NextToken") - max_results = self.parameters.get("MaxResults", 10000) + def list_shards(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + next_token = self._get_param("NextToken") + max_results = self._get_param("MaxResults", 10000) shards, token = self.kinesis_backend.list_shards( stream_arn=stream_arn, stream_name=stream_name, @@ -165,10 +161,10 @@ class KinesisResponse(BaseResponse): res["NextToken"] = token return json.dumps(res) - def update_shard_count(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - target_shard_count = self.parameters.get("TargetShardCount") + def update_shard_count(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + target_shard_count = self._get_param("TargetShardCount") current_shard_count = self.kinesis_backend.update_shard_count( stream_arn=stream_arn, stream_name=stream_name, @@ -182,52 +178,52 @@ class KinesisResponse(BaseResponse): ) ) - def increase_stream_retention_period(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - retention_period_hours = self.parameters.get("RetentionPeriodHours") + def increase_stream_retention_period(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + retention_period_hours = self._get_param("RetentionPeriodHours") self.kinesis_backend.increase_stream_retention_period( stream_arn, stream_name, retention_period_hours ) return "" - def decrease_stream_retention_period(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - retention_period_hours = self.parameters.get("RetentionPeriodHours") + def decrease_stream_retention_period(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + retention_period_hours = self._get_param("RetentionPeriodHours") self.kinesis_backend.decrease_stream_retention_period( stream_arn, stream_name, retention_period_hours ) return "" - def add_tags_to_stream(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - tags = self.parameters.get("Tags") + def add_tags_to_stream(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + tags = self._get_param("Tags") self.kinesis_backend.add_tags_to_stream(stream_arn, stream_name, tags) return json.dumps({}) - def list_tags_for_stream(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey") - limit = self.parameters.get("Limit") + def list_tags_for_stream(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + exclusive_start_tag_key = self._get_param("ExclusiveStartTagKey") + limit = self._get_param("Limit") response = self.kinesis_backend.list_tags_for_stream( stream_arn, stream_name, exclusive_start_tag_key, limit ) return json.dumps(response) - def remove_tags_from_stream(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - tag_keys = self.parameters.get("TagKeys") + def remove_tags_from_stream(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + tag_keys = self._get_param("TagKeys") self.kinesis_backend.remove_tags_from_stream(stream_arn, stream_name, tag_keys) return json.dumps({}) - def enable_enhanced_monitoring(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - shard_level_metrics = self.parameters.get("ShardLevelMetrics") + def enable_enhanced_monitoring(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + shard_level_metrics = self._get_param("ShardLevelMetrics") arn, name, current, desired = self.kinesis_backend.enable_enhanced_monitoring( stream_arn=stream_arn, stream_name=stream_name, @@ -242,10 +238,10 @@ class KinesisResponse(BaseResponse): ) ) - def disable_enhanced_monitoring(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - shard_level_metrics = self.parameters.get("ShardLevelMetrics") + def disable_enhanced_monitoring(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + shard_level_metrics = self._get_param("ShardLevelMetrics") arn, name, current, desired = self.kinesis_backend.disable_enhanced_monitoring( stream_arn=stream_arn, stream_name=stream_name, @@ -260,23 +256,23 @@ class KinesisResponse(BaseResponse): ) ) - def list_stream_consumers(self): - stream_arn = self.parameters.get("StreamARN") + def list_stream_consumers(self) -> str: + stream_arn = self._get_param("StreamARN") consumers = self.kinesis_backend.list_stream_consumers(stream_arn=stream_arn) return json.dumps(dict(Consumers=[c.to_json() for c in consumers])) - def register_stream_consumer(self): - stream_arn = self.parameters.get("StreamARN") - consumer_name = self.parameters.get("ConsumerName") + def register_stream_consumer(self) -> str: + stream_arn = self._get_param("StreamARN") + consumer_name = self._get_param("ConsumerName") consumer = self.kinesis_backend.register_stream_consumer( stream_arn=stream_arn, consumer_name=consumer_name ) return json.dumps(dict(Consumer=consumer.to_json())) - def describe_stream_consumer(self): - stream_arn = self.parameters.get("StreamARN") - consumer_name = self.parameters.get("ConsumerName") - consumer_arn = self.parameters.get("ConsumerARN") + def describe_stream_consumer(self) -> str: + stream_arn = self._get_param("StreamARN") + consumer_name = self._get_param("ConsumerName") + consumer_arn = self._get_param("ConsumerARN") consumer = self.kinesis_backend.describe_stream_consumer( stream_arn=stream_arn, consumer_name=consumer_name, @@ -286,10 +282,10 @@ class KinesisResponse(BaseResponse): dict(ConsumerDescription=consumer.to_json(include_stream_arn=True)) ) - def deregister_stream_consumer(self): - stream_arn = self.parameters.get("StreamARN") - consumer_name = self.parameters.get("ConsumerName") - consumer_arn = self.parameters.get("ConsumerARN") + def deregister_stream_consumer(self) -> str: + stream_arn = self._get_param("StreamARN") + consumer_name = self._get_param("ConsumerName") + consumer_arn = self._get_param("ConsumerARN") self.kinesis_backend.deregister_stream_consumer( stream_arn=stream_arn, consumer_name=consumer_name, @@ -297,11 +293,11 @@ class KinesisResponse(BaseResponse): ) return json.dumps(dict()) - def start_stream_encryption(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") - encryption_type = self.parameters.get("EncryptionType") - key_id = self.parameters.get("KeyId") + def start_stream_encryption(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") + encryption_type = self._get_param("EncryptionType") + key_id = self._get_param("KeyId") self.kinesis_backend.start_stream_encryption( stream_arn=stream_arn, stream_name=stream_name, @@ -310,16 +306,16 @@ class KinesisResponse(BaseResponse): ) return json.dumps(dict()) - def stop_stream_encryption(self): - stream_arn = self.parameters.get("StreamARN") - stream_name = self.parameters.get("StreamName") + def stop_stream_encryption(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_name = self._get_param("StreamName") self.kinesis_backend.stop_stream_encryption( stream_arn=stream_arn, stream_name=stream_name ) return json.dumps(dict()) - def update_stream_mode(self): - stream_arn = self.parameters.get("StreamARN") - stream_mode = self.parameters.get("StreamModeDetails") + def update_stream_mode(self) -> str: + stream_arn = self._get_param("StreamARN") + stream_mode = self._get_param("StreamModeDetails") self.kinesis_backend.update_stream_mode(stream_arn, stream_mode) return "{}" diff --git a/moto/kinesis/utils.py b/moto/kinesis/utils.py index 849d2a2e8..56b0dfa67 100644 --- a/moto/kinesis/utils.py +++ b/moto/kinesis/utils.py @@ -1,4 +1,6 @@ import base64 +from datetime import datetime +from typing import Any, Optional, List from .exceptions import InvalidArgumentError @@ -30,8 +32,12 @@ PAGINATION_MODEL = { def compose_new_shard_iterator( - stream_name, shard, shard_iterator_type, starting_sequence_number, at_timestamp -): + stream_name: Optional[str], + shard: Any, + shard_iterator_type: str, + starting_sequence_number: int, + at_timestamp: datetime, +) -> str: if shard_iterator_type == "AT_SEQUENCE_NUMBER": last_sequence_id = int(starting_sequence_number) - 1 elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER": @@ -47,11 +53,13 @@ def compose_new_shard_iterator( return compose_shard_iterator(stream_name, shard, last_sequence_id) -def compose_shard_iterator(stream_name, shard, last_sequence_id): +def compose_shard_iterator( + stream_name: Optional[str], shard: Any, last_sequence_id: int +) -> str: return encode_method( f"{stream_name}:{shard.shard_id}:{last_sequence_id}".encode("utf-8") ).decode("utf-8") -def decompose_shard_iterator(shard_iterator): +def decompose_shard_iterator(shard_iterator: str) -> List[str]: return decode_method(shard_iterator.encode("utf-8")).decode("utf-8").split(":") diff --git a/moto/kinesisvideo/exceptions.py b/moto/kinesisvideo/exceptions.py index 423b10556..13429ef9d 100644 --- a/moto/kinesisvideo/exceptions.py +++ b/moto/kinesisvideo/exceptions.py @@ -6,7 +6,7 @@ class KinesisvideoClientError(RESTError): class ResourceNotFoundException(KinesisvideoClientError): - def __init__(self): + def __init__(self) -> None: self.code = 404 super().__init__( "ResourceNotFoundException", @@ -15,6 +15,6 @@ class ResourceNotFoundException(KinesisvideoClientError): class ResourceInUseException(KinesisvideoClientError): - def __init__(self, message): + def __init__(self, message: str): self.code = 400 super().__init__("ResourceInUseException", message) diff --git a/moto/kinesisvideo/models.py b/moto/kinesisvideo/models.py index 8653016e1..7edb15137 100644 --- a/moto/kinesisvideo/models.py +++ b/moto/kinesisvideo/models.py @@ -1,5 +1,6 @@ -from moto.core import BaseBackend, BackendDict, BaseModel from datetime import datetime +from typing import Any, Dict, List +from moto.core import BaseBackend, BackendDict, BaseModel from .exceptions import ResourceNotFoundException, ResourceInUseException from moto.moto_api._internal import mock_random as random @@ -7,14 +8,14 @@ from moto.moto_api._internal import mock_random as random class Stream(BaseModel): def __init__( self, - account_id, - region_name, - device_name, - stream_name, - media_type, - kms_key_id, - data_retention_in_hours, - tags, + account_id: str, + region_name: str, + device_name: str, + stream_name: str, + media_type: str, + kms_key_id: str, + data_retention_in_hours: int, + tags: Dict[str, str], ): self.region_name = region_name self.stream_name = stream_name @@ -30,11 +31,11 @@ class Stream(BaseModel): self.data_endpoint_number = random.get_random_hex() self.arn = stream_arn - def get_data_endpoint(self, api_name): + def get_data_endpoint(self, api_name: str) -> str: data_endpoint_prefix = "s-" if api_name in ("PUT_MEDIA", "GET_MEDIA") else "b-" return f"https://{data_endpoint_prefix}{self.data_endpoint_number}.kinesisvideo.{self.region_name}.amazonaws.com" - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return { "DeviceName": self.device_name, "StreamName": self.stream_name, @@ -49,19 +50,19 @@ class Stream(BaseModel): class KinesisVideoBackend(BaseBackend): - def __init__(self, region_name, account_id): + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) - self.streams = {} + self.streams: Dict[str, Stream] = {} def create_stream( self, - device_name, - stream_name, - media_type, - kms_key_id, - data_retention_in_hours, - tags, - ): + device_name: str, + stream_name: str, + media_type: str, + kms_key_id: str, + data_retention_in_hours: int, + tags: Dict[str, str], + ) -> str: streams = [_ for _ in self.streams.values() if _.stream_name == stream_name] if len(streams) > 0: raise ResourceInUseException(f"The stream {stream_name} already exists.") @@ -78,7 +79,7 @@ class KinesisVideoBackend(BaseBackend): self.streams[stream.arn] = stream return stream.arn - def _get_stream(self, stream_name, stream_arn): + def _get_stream(self, stream_name: str, stream_arn: str) -> Stream: if stream_name: streams = [_ for _ in self.streams.values() if _.stream_name == stream_name] if len(streams) == 0: @@ -90,20 +91,17 @@ class KinesisVideoBackend(BaseBackend): raise ResourceNotFoundException() return stream - def describe_stream(self, stream_name, stream_arn): + def describe_stream(self, stream_name: str, stream_arn: str) -> Dict[str, Any]: stream = self._get_stream(stream_name, stream_arn) - stream_info = stream.to_dict() - return stream_info + return stream.to_dict() - def list_streams(self): + def list_streams(self) -> List[Dict[str, Any]]: """ Pagination and the StreamNameCondition-parameter are not yet implemented """ - stream_info_list = [_.to_dict() for _ in self.streams.values()] - next_token = None - return stream_info_list, next_token + return [_.to_dict() for _ in self.streams.values()] - def delete_stream(self, stream_arn): + def delete_stream(self, stream_arn: str) -> None: """ The CurrentVersion-parameter is not yet implemented """ @@ -112,11 +110,11 @@ class KinesisVideoBackend(BaseBackend): raise ResourceNotFoundException() del self.streams[stream_arn] - def get_data_endpoint(self, stream_name, stream_arn, api_name): + def get_data_endpoint( + self, stream_name: str, stream_arn: str, api_name: str + ) -> str: stream = self._get_stream(stream_name, stream_arn) return stream.get_data_endpoint(api_name) - # add methods from here - kinesisvideo_backends = BackendDict(KinesisVideoBackend, "kinesisvideo") diff --git a/moto/kinesisvideo/responses.py b/moto/kinesisvideo/responses.py index 98a596817..f9cdf25c4 100644 --- a/moto/kinesisvideo/responses.py +++ b/moto/kinesisvideo/responses.py @@ -1,17 +1,17 @@ from moto.core.responses import BaseResponse -from .models import kinesisvideo_backends +from .models import kinesisvideo_backends, KinesisVideoBackend import json class KinesisVideoResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="kinesisvideo") @property - def kinesisvideo_backend(self): + def kinesisvideo_backend(self) -> KinesisVideoBackend: return kinesisvideo_backends[self.current_account][self.region] - def create_stream(self): + def create_stream(self) -> str: device_name = self._get_param("DeviceName") stream_name = self._get_param("StreamName") media_type = self._get_param("MediaType") @@ -28,7 +28,7 @@ class KinesisVideoResponse(BaseResponse): ) return json.dumps(dict(StreamARN=stream_arn)) - def describe_stream(self): + def describe_stream(self) -> str: stream_name = self._get_param("StreamName") stream_arn = self._get_param("StreamARN") stream_info = self.kinesisvideo_backend.describe_stream( @@ -36,16 +36,16 @@ class KinesisVideoResponse(BaseResponse): ) return json.dumps(dict(StreamInfo=stream_info)) - def list_streams(self): - stream_info_list, next_token = self.kinesisvideo_backend.list_streams() - return json.dumps(dict(StreamInfoList=stream_info_list, NextToken=next_token)) + def list_streams(self) -> str: + stream_info_list = self.kinesisvideo_backend.list_streams() + return json.dumps(dict(StreamInfoList=stream_info_list, NextToken=None)) - def delete_stream(self): + def delete_stream(self) -> str: stream_arn = self._get_param("StreamARN") self.kinesisvideo_backend.delete_stream(stream_arn=stream_arn) return json.dumps(dict()) - def get_data_endpoint(self): + def get_data_endpoint(self) -> str: stream_name = self._get_param("StreamName") stream_arn = self._get_param("StreamARN") api_name = self._get_param("APIName") diff --git a/moto/kinesisvideoarchivedmedia/models.py b/moto/kinesisvideoarchivedmedia/models.py index 20ca92770..3d7ade895 100644 --- a/moto/kinesisvideoarchivedmedia/models.py +++ b/moto/kinesisvideoarchivedmedia/models.py @@ -1,14 +1,17 @@ +from typing import Tuple from moto.core import BaseBackend, BackendDict -from moto.kinesisvideo import kinesisvideo_backends +from moto.kinesisvideo.models import kinesisvideo_backends, KinesisVideoBackend from moto.sts.utils import random_session_token class KinesisVideoArchivedMediaBackend(BaseBackend): @property - def backend(self): + def backend(self) -> KinesisVideoBackend: return kinesisvideo_backends[self.account_id][self.region_name] - def _get_streaming_url(self, stream_name, stream_arn, api_name): + def _get_streaming_url( + self, stream_name: str, stream_arn: str, api_name: str + ) -> str: stream = self.backend._get_stream(stream_name, stream_arn) data_endpoint = stream.get_data_endpoint(api_name) session_token = random_session_token() @@ -19,19 +22,17 @@ class KinesisVideoArchivedMediaBackend(BaseBackend): relative_path = api_to_relative_path[api_name] return f"{data_endpoint}{relative_path}?SessionToken={session_token}" - def get_hls_streaming_session_url(self, stream_name, stream_arn): - # Ignore option paramters as the format of hls_url does't depends on them + def get_hls_streaming_session_url(self, stream_name: str, stream_arn: str) -> str: + # Ignore option paramters as the format of hls_url doesn't depend on them api_name = "GET_HLS_STREAMING_SESSION_URL" - url = self._get_streaming_url(stream_name, stream_arn, api_name) - return url + return self._get_streaming_url(stream_name, stream_arn, api_name) - def get_dash_streaming_session_url(self, stream_name, stream_arn): - # Ignore option paramters as the format of hls_url does't depends on them + def get_dash_streaming_session_url(self, stream_name: str, stream_arn: str) -> str: + # Ignore option paramters as the format of hls_url doesn't depend on them api_name = "GET_DASH_STREAMING_SESSION_URL" - url = self._get_streaming_url(stream_name, stream_arn, api_name) - return url + return self._get_streaming_url(stream_name, stream_arn, api_name) - def get_clip(self, stream_name, stream_arn): + def get_clip(self, stream_name: str, stream_arn: str) -> Tuple[str, bytes]: self.backend._get_stream(stream_name, stream_arn) content_type = "video/mp4" # Fixed content_type as it depends on input stream payload = b"sample-mp4-video" diff --git a/moto/kinesisvideoarchivedmedia/responses.py b/moto/kinesisvideoarchivedmedia/responses.py index e86824b46..22db6fe48 100644 --- a/moto/kinesisvideoarchivedmedia/responses.py +++ b/moto/kinesisvideoarchivedmedia/responses.py @@ -1,17 +1,18 @@ +from typing import Dict, Tuple from moto.core.responses import BaseResponse -from .models import kinesisvideoarchivedmedia_backends +from .models import kinesisvideoarchivedmedia_backends, KinesisVideoArchivedMediaBackend import json class KinesisVideoArchivedMediaResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="kinesis-video-archived-media") @property - def kinesisvideoarchivedmedia_backend(self): + def kinesisvideoarchivedmedia_backend(self) -> KinesisVideoArchivedMediaBackend: return kinesisvideoarchivedmedia_backends[self.current_account][self.region] - def get_hls_streaming_session_url(self): + def get_hls_streaming_session_url(self) -> str: stream_name = self._get_param("StreamName") stream_arn = self._get_param("StreamARN") hls_streaming_session_url = ( @@ -21,7 +22,7 @@ class KinesisVideoArchivedMediaResponse(BaseResponse): ) return json.dumps(dict(HLSStreamingSessionURL=hls_streaming_session_url)) - def get_dash_streaming_session_url(self): + def get_dash_streaming_session_url(self) -> str: stream_name = self._get_param("StreamName") stream_arn = self._get_param("StreamARN") dash_streaming_session_url = ( @@ -31,7 +32,7 @@ class KinesisVideoArchivedMediaResponse(BaseResponse): ) return json.dumps(dict(DASHStreamingSessionURL=dash_streaming_session_url)) - def get_clip(self): + def get_clip(self) -> Tuple[bytes, Dict[str, str]]: stream_name = self._get_param("StreamName") stream_arn = self._get_param("StreamARN") content_type, payload = self.kinesisvideoarchivedmedia_backend.get_clip( diff --git a/moto/kms/exceptions.py b/moto/kms/exceptions.py index 8808693a7..1ebf0d123 100644 --- a/moto/kms/exceptions.py +++ b/moto/kms/exceptions.py @@ -4,29 +4,29 @@ from moto.core.exceptions import JsonRESTError class NotFoundException(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("NotFoundException", message) class ValidationException(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("ValidationException", message) class AlreadyExistsException(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("AlreadyExistsException", message) class NotAuthorizedException(JsonRESTError): code = 400 - def __init__(self): - super().__init__("NotAuthorizedException", None) + def __init__(self) -> None: + super().__init__("NotAuthorizedException", "") self.description = '{"__type":"NotAuthorizedException"}' @@ -34,7 +34,7 @@ class NotAuthorizedException(JsonRESTError): class AccessDeniedException(JsonRESTError): code = 400 - def __init__(self, message): + def __init__(self, message: str): super().__init__("AccessDeniedException", message) self.description = '{"__type":"AccessDeniedException"}' @@ -43,7 +43,7 @@ class AccessDeniedException(JsonRESTError): class InvalidCiphertextException(JsonRESTError): code = 400 - def __init__(self): - super().__init__("InvalidCiphertextException", None) + def __init__(self) -> None: + super().__init__("InvalidCiphertextException", "") self.description = '{"__type":"InvalidCiphertextException"}' diff --git a/moto/kms/models.py b/moto/kms/models.py index 0ce417126..0b1f4a2b2 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -5,8 +5,8 @@ from copy import copy from datetime import datetime, timedelta from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes - from cryptography.hazmat.primitives.asymmetric import padding +from typing import Any, Dict, List, Tuple, Optional, Iterable, Set from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core.utils import unix_time @@ -28,12 +28,12 @@ from .utils import ( class Grant(BaseModel): def __init__( self, - key_id, - name, - grantee_principal, - operations, - constraints, - retiring_principal, + key_id: str, + name: str, + grantee_principal: str, + operations: List[str], + constraints: Dict[str, Any], + retiring_principal: str, ): self.key_id = key_id self.name = name @@ -44,7 +44,7 @@ class Grant(BaseModel): self.id = mock_random.get_random_hex() self.token = mock_random.get_random_hex() - def to_json(self): + def to_json(self) -> Dict[str, Any]: return { "KeyId": self.key_id, "GrantId": self.id, @@ -59,13 +59,13 @@ class Grant(BaseModel): class Key(CloudFormationModel): def __init__( self, - policy, - key_usage, - key_spec, - description, - account_id, - region, - multi_region=False, + policy: Optional[str], + key_usage: str, + key_spec: str, + description: str, + account_id: str, + region: str, + multi_region: bool = False, ): self.id = generate_key_id(multi_region) self.creation_date = unix_time() @@ -78,7 +78,7 @@ class Key(CloudFormationModel): self.region = region self.multi_region = multi_region self.key_rotation_status = False - self.deletion_date = None + self.deletion_date: Optional[datetime] = None self.key_material = generate_master_key() self.private_key = generate_private_key() self.origin = "AWS_KMS" @@ -86,10 +86,15 @@ class Key(CloudFormationModel): self.key_spec = key_spec or "SYMMETRIC_DEFAULT" self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}" - self.grants = dict() + self.grants: Dict[str, Grant] = dict() def add_grant( - self, name, grantee_principal, operations, constraints, retiring_principal + self, + name: str, + grantee_principal: str, + operations: List[str], + constraints: Dict[str, Any], + retiring_principal: str, ) -> Grant: grant = Grant( self.id, @@ -102,32 +107,32 @@ class Key(CloudFormationModel): self.grants[grant.id] = grant return grant - def list_grants(self, grant_id) -> [Grant]: + def list_grants(self, grant_id: str) -> List[Grant]: grant_ids = [grant_id] if grant_id else self.grants.keys() return [grant for _id, grant in self.grants.items() if _id in grant_ids] - def list_retirable_grants(self, retiring_principal) -> [Grant]: + def list_retirable_grants(self, retiring_principal: str) -> List[Grant]: return [ grant for grant in self.grants.values() if grant.retiring_principal == retiring_principal ] - def revoke_grant(self, grant_id) -> None: + def revoke_grant(self, grant_id: str) -> None: if not self.grants.pop(grant_id, None): raise JsonRESTError("NotFoundException", f"Grant ID {grant_id} not found") - def retire_grant(self, grant_id) -> None: + def retire_grant(self, grant_id: str) -> None: self.grants.pop(grant_id, None) - def retire_grant_by_token(self, grant_token) -> None: + def retire_grant_by_token(self, grant_token: str) -> None: self.grants = { _id: grant for _id, grant in self.grants.items() if grant.token != grant_token } - def generate_default_policy(self): + def generate_default_policy(self) -> str: return json.dumps( { "Version": "2012-10-17", @@ -145,11 +150,11 @@ class Key(CloudFormationModel): ) @property - def physical_resource_id(self): + def physical_resource_id(self) -> str: return self.id @property - def encryption_algorithms(self): + def encryption_algorithms(self) -> Optional[List[str]]: if self.key_usage == "SIGN_VERIFY": return None elif self.key_spec == "SYMMETRIC_DEFAULT": @@ -158,9 +163,9 @@ class Key(CloudFormationModel): return ["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"] @property - def signing_algorithms(self): + def signing_algorithms(self) -> List[str]: if self.key_usage == "ENCRYPT_DECRYPT": - return None + return None # type: ignore[return-value] elif self.key_spec in ["ECC_NIST_P256", "ECC_SECG_P256K1"]: return ["ECDSA_SHA_256"] elif self.key_spec == "ECC_NIST_P384": @@ -177,7 +182,7 @@ class Key(CloudFormationModel): "RSASSA_PSS_SHA_512", ] - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: key_dict = { "KeyMetadata": { "AWSAccountId": self.account_id, @@ -201,22 +206,27 @@ class Key(CloudFormationModel): key_dict["KeyMetadata"]["DeletionDate"] = unix_time(self.deletion_date) return key_dict - def delete(self, account_id, region_name): + def delete(self, account_id: str, region_name: str) -> None: kms_backends[account_id][region_name].delete_key(self.id) @staticmethod - def cloudformation_name_type(): - return None + def cloudformation_name_type() -> str: + return "" @staticmethod - def cloudformation_type(): + def cloudformation_type() -> str: # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kms-key.html return "AWS::KMS::Key" @classmethod - def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, account_id, region_name, **kwargs - ): + def create_from_cloudformation_json( # type: ignore[misc] + cls, + resource_name: str, + cloudformation_json: Any, + account_id: str, + region_name: str, + **kwargs: Any, + ) -> "Key": kms_backend = kms_backends[account_id][region_name] properties = cloudformation_json["Properties"] @@ -233,10 +243,10 @@ class Key(CloudFormationModel): return key @classmethod - def has_cfn_attr(cls, attr): + def has_cfn_attr(cls, attr: str) -> bool: return attr in ["Arn"] - def get_cfn_attribute(self, attribute_name): + def get_cfn_attribute(self, attribute_name: str) -> str: from moto.cloudformation.exceptions import UnformattedGetAttTemplateException if attribute_name == "Arn": @@ -245,20 +255,22 @@ class Key(CloudFormationModel): class KmsBackend(BaseBackend): - def __init__(self, region_name, account_id=None): - super().__init__(region_name=region_name, account_id=account_id) - self.keys = {} - self.key_to_aliases = defaultdict(set) + def __init__(self, region_name: str, account_id: Optional[str] = None): + super().__init__(region_name=region_name, account_id=account_id) # type: ignore + self.keys: Dict[str, Key] = {} + self.key_to_aliases: Dict[str, Set[str]] = defaultdict(set) self.tagger = TaggingService(key_name="TagKey", value_name="TagValue") @staticmethod - def default_vpc_endpoint_service(service_region, zones): + def default_vpc_endpoint_service( + service_region: str, zones: List[str] + ) -> List[Dict[str, str]]: """Default VPC endpoint service.""" return BaseBackend.default_vpc_endpoint_service_factory( service_region, zones, "kms" ) - def _generate_default_keys(self, alias_name): + def _generate_default_keys(self, alias_name: str) -> Optional[str]: """Creates default kms keys""" if alias_name in RESERVED_ALIASES: key = self.create_key( @@ -270,10 +282,17 @@ class KmsBackend(BaseBackend): ) self.add_alias(key.id, alias_name) return key.id + return None def create_key( - self, policy, key_usage, key_spec, description, tags, multi_region=False - ): + self, + policy: Optional[str], + key_usage: str, + key_spec: str, + description: str, + tags: Optional[List[Dict[str, str]]], + multi_region: bool = False, + ) -> Key: """ The provided Policy currently does not need to be valid. If it is valid, Moto will perform authorization checks on key-related operations, just like AWS does. @@ -303,7 +322,7 @@ class KmsBackend(BaseBackend): # # In our implementation with just create a copy of all the properties once without any protection from change, # as the exact implementation is currently infeasible. - def replicate_key(self, key_id, replica_region): + def replicate_key(self, key_id: str, replica_region: str) -> None: # Using copy() instead of deepcopy(), as the latter results in exception: # TypeError: cannot pickle '_cffi_backend.FFI' object # Since we only update top level properties, copy() should suffice. @@ -312,31 +331,31 @@ class KmsBackend(BaseBackend): to_region_backend = kms_backends[self.account_id][replica_region] to_region_backend.keys[replica_key.id] = replica_key - def update_key_description(self, key_id, description): + def update_key_description(self, key_id: str, description: str) -> None: key = self.keys[self.get_key_id(key_id)] key.description = description - def delete_key(self, key_id): + def delete_key(self, key_id: str) -> None: if key_id in self.keys: if key_id in self.key_to_aliases: self.key_to_aliases.pop(key_id) self.tagger.delete_all_tags_for_resource(key_id) - return self.keys.pop(key_id) + self.keys.pop(key_id) - def describe_key(self, key_id) -> Key: + def describe_key(self, key_id: str) -> Key: # allow the different methods (alias, ARN :key/, keyId, ARN alias) to # describe key not just KeyId key_id = self.get_key_id(key_id) if r"alias/" in str(key_id).lower(): - key_id = self.get_key_id_from_alias(key_id) + key_id = self.get_key_id_from_alias(key_id) # type: ignore[assignment] return self.keys[self.get_key_id(key_id)] - def list_keys(self): + def list_keys(self) -> Iterable[Key]: return self.keys.values() @staticmethod - def get_key_id(key_id): + def get_key_id(key_id: str) -> str: # Allow use of ARN as well as pure KeyId if key_id.startswith("arn:") and ":key/" in key_id: return key_id.split(":key/")[1] @@ -344,14 +363,14 @@ class KmsBackend(BaseBackend): return key_id @staticmethod - def get_alias_name(alias_name): + def get_alias_name(alias_name: str) -> str: # Allow use of ARN as well as alias name if alias_name.startswith("arn:") and ":alias/" in alias_name: return "alias/" + alias_name.split(":alias/")[1] return alias_name - def any_id_to_key_id(self, key_id): + def any_id_to_key_id(self, key_id: str) -> str: """Go from any valid key ID to the raw key ID. Acceptable inputs: @@ -363,66 +382,65 @@ class KmsBackend(BaseBackend): key_id = self.get_alias_name(key_id) key_id = self.get_key_id(key_id) if key_id.startswith("alias/"): - key_id = self.get_key_id(self.get_key_id_from_alias(key_id)) + key_id = self.get_key_id(self.get_key_id_from_alias(key_id)) # type: ignore[arg-type] return key_id - def alias_exists(self, alias_name): + def alias_exists(self, alias_name: str) -> bool: for aliases in self.key_to_aliases.values(): if alias_name in aliases: return True return False - def add_alias(self, target_key_id, alias_name): + def add_alias(self, target_key_id: str, alias_name: str) -> None: raw_key_id = self.get_key_id(target_key_id) self.key_to_aliases[raw_key_id].add(alias_name) - def delete_alias(self, alias_name): + def delete_alias(self, alias_name: str) -> None: """Delete the alias.""" for aliases in self.key_to_aliases.values(): if alias_name in aliases: aliases.remove(alias_name) - def get_all_aliases(self): + def get_all_aliases(self) -> Dict[str, Set[str]]: return self.key_to_aliases - def get_key_id_from_alias(self, alias_name): + def get_key_id_from_alias(self, alias_name: str) -> Optional[str]: for key_id, aliases in dict(self.key_to_aliases).items(): if alias_name in ",".join(aliases): return key_id if alias_name in RESERVED_ALIASES: - key_id = self._generate_default_keys(alias_name) - return key_id + return self._generate_default_keys(alias_name) return None - def enable_key_rotation(self, key_id): + def enable_key_rotation(self, key_id: str) -> None: self.keys[self.get_key_id(key_id)].key_rotation_status = True - def disable_key_rotation(self, key_id): + def disable_key_rotation(self, key_id: str) -> None: self.keys[self.get_key_id(key_id)].key_rotation_status = False - def get_key_rotation_status(self, key_id): + def get_key_rotation_status(self, key_id: str) -> bool: return self.keys[self.get_key_id(key_id)].key_rotation_status - def put_key_policy(self, key_id, policy): + def put_key_policy(self, key_id: str, policy: str) -> None: self.keys[self.get_key_id(key_id)].policy = policy - def get_key_policy(self, key_id): + def get_key_policy(self, key_id: str) -> str: return self.keys[self.get_key_id(key_id)].policy - def disable_key(self, key_id): + def disable_key(self, key_id: str) -> None: self.keys[key_id].enabled = False self.keys[key_id].key_state = "Disabled" - def enable_key(self, key_id): + def enable_key(self, key_id: str) -> None: self.keys[key_id].enabled = True self.keys[key_id].key_state = "Enabled" - def cancel_key_deletion(self, key_id): + def cancel_key_deletion(self, key_id: str) -> None: self.keys[key_id].key_state = "Disabled" self.keys[key_id].deletion_date = None - def schedule_key_deletion(self, key_id, pending_window_in_days): + def schedule_key_deletion(self, key_id: str, pending_window_in_days: int) -> float: # type: ignore[return] if 7 <= pending_window_in_days <= 30: self.keys[key_id].enabled = False self.keys[key_id].key_state = "PendingDeletion" @@ -431,7 +449,9 @@ class KmsBackend(BaseBackend): ) return unix_time(self.keys[key_id].deletion_date) - def encrypt(self, key_id, plaintext, encryption_context): + def encrypt( + self, key_id: str, plaintext: bytes, encryption_context: Dict[str, str] + ) -> Tuple[bytes, str]: key_id = self.any_id_to_key_id(key_id) ciphertext_blob = encrypt( @@ -443,7 +463,9 @@ class KmsBackend(BaseBackend): arn = self.keys[key_id].arn return ciphertext_blob, arn - def decrypt(self, ciphertext_blob, encryption_context): + def decrypt( + self, ciphertext_blob: bytes, encryption_context: Dict[str, str] + ) -> Tuple[bytes, str]: plaintext, key_id = decrypt( master_keys=self.keys, ciphertext_blob=ciphertext_blob, @@ -454,11 +476,11 @@ class KmsBackend(BaseBackend): def re_encrypt( self, - ciphertext_blob, - source_encryption_context, - destination_key_id, - destination_encryption_context, - ): + ciphertext_blob: bytes, + source_encryption_context: Dict[str, str], + destination_key_id: str, + destination_encryption_context: Dict[str, str], + ) -> Tuple[bytes, str, str]: destination_key_id = self.any_id_to_key_id(destination_key_id) plaintext, decrypting_arn = self.decrypt( @@ -472,7 +494,13 @@ class KmsBackend(BaseBackend): ) return new_ciphertext_blob, decrypting_arn, encrypting_arn - def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec): + def generate_data_key( + self, + key_id: str, + encryption_context: Dict[str, str], + number_of_bytes: int, + key_spec: str, + ) -> Tuple[bytes, bytes, str]: key_id = self.any_id_to_key_id(key_id) if key_spec: @@ -492,7 +520,7 @@ class KmsBackend(BaseBackend): return plaintext, ciphertext_blob, arn - def list_resource_tags(self, key_id_or_arn): + def list_resource_tags(self, key_id_or_arn: str) -> Dict[str, List[Dict[str, str]]]: key_id = self.get_key_id(key_id_or_arn) if key_id in self.keys: return self.tagger.list_tags_for_resource(key_id) @@ -501,21 +529,21 @@ class KmsBackend(BaseBackend): "The request was rejected because the specified entity or resource could not be found.", ) - def tag_resource(self, key_id_or_arn, tags): + def tag_resource(self, key_id_or_arn: str, tags: List[Dict[str, str]]) -> None: key_id = self.get_key_id(key_id_or_arn) if key_id in self.keys: self.tagger.tag_resource(key_id, tags) - return {} + return raise JsonRESTError( "NotFoundException", "The request was rejected because the specified entity or resource could not be found.", ) - def untag_resource(self, key_id_or_arn, tag_names): + def untag_resource(self, key_id_or_arn: str, tag_names: List[str]) -> None: key_id = self.get_key_id(key_id_or_arn) if key_id in self.keys: self.tagger.untag_resource_using_names(key_id, tag_names) - return {} + return raise JsonRESTError( "NotFoundException", "The request was rejected because the specified entity or resource could not be found.", @@ -523,13 +551,13 @@ class KmsBackend(BaseBackend): def create_grant( self, - key_id, - grantee_principal, - operations, - name, - constraints, - retiring_principal, - ): + key_id: str, + grantee_principal: str, + operations: List[str], + name: str, + constraints: Dict[str, Any], + retiring_principal: str, + ) -> Tuple[str, str]: key = self.describe_key(key_id) grant = key.add_grant( name, @@ -540,21 +568,21 @@ class KmsBackend(BaseBackend): ) return grant.id, grant.token - def list_grants(self, key_id, grant_id) -> [Grant]: + def list_grants(self, key_id: str, grant_id: str) -> List[Grant]: key = self.describe_key(key_id) return key.list_grants(grant_id) - def list_retirable_grants(self, retiring_principal): + def list_retirable_grants(self, retiring_principal: str) -> List[Grant]: grants = [] for key in self.keys.values(): grants.extend(key.list_retirable_grants(retiring_principal)) return grants - def revoke_grant(self, key_id, grant_id) -> None: + def revoke_grant(self, key_id: str, grant_id: str) -> None: key = self.describe_key(key_id) key.revoke_grant(grant_id) - def retire_grant(self, key_id, grant_id, grant_token) -> None: + def retire_grant(self, key_id: str, grant_id: str, grant_token: str) -> None: if grant_token: for key in self.keys.values(): key.retire_grant_by_token(grant_token) @@ -562,7 +590,7 @@ class KmsBackend(BaseBackend): key = self.describe_key(key_id) key.retire_grant(grant_id) - def __ensure_valid_sign_and_verify_key(self, key: Key): + def __ensure_valid_sign_and_verify_key(self, key: Key) -> None: if key.key_usage != "SIGN_VERIFY": raise ValidationException( ( @@ -571,7 +599,9 @@ class KmsBackend(BaseBackend): ).format(key_id=key.id) ) - def __ensure_valid_signing_augorithm(self, key: Key, signing_algorithm): + def __ensure_valid_signing_augorithm( + self, key: Key, signing_algorithm: str + ) -> None: if signing_algorithm not in key.signing_algorithms: raise ValidationException( ( @@ -584,7 +614,9 @@ class KmsBackend(BaseBackend): ) ) - def sign(self, key_id, message, signing_algorithm): + def sign( + self, key_id: str, message: bytes, signing_algorithm: str + ) -> Tuple[str, bytes, str]: """Sign message using generated private key. - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256 @@ -607,7 +639,9 @@ class KmsBackend(BaseBackend): return key.arn, signature, signing_algorithm - def verify(self, key_id, message, signature, signing_algorithm): + def verify( + self, key_id: str, message: bytes, signature: bytes, signing_algorithm: str + ) -> Tuple[str, bool, str]: """Verify message using public key from generated private key. - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256 diff --git a/moto/kms/policy_validator.py b/moto/kms/policy_validator.py index f278dbc9b..72392e8fc 100644 --- a/moto/kms/policy_validator.py +++ b/moto/kms/policy_validator.py @@ -1,4 +1,5 @@ from collections import defaultdict +from typing import Any, Dict, List import json from .models import Key from .exceptions import AccessDeniedException @@ -8,7 +9,7 @@ ALTERNATIVE_ACTIONS = defaultdict(list) ALTERNATIVE_ACTIONS["kms:DescribeKey"] = ["kms:*", "kms:Describe*", "kms:DescribeKey"] -def validate_policy(key: Key, action: str): +def validate_policy(key: Key, action: str) -> None: """ Relevant docs: - https://docs.aws.amazon.com/kms/latest/developerguide/key-policy-default.html @@ -29,20 +30,22 @@ def validate_policy(key: Key, action: str): ) -def check_statement(statement, resource, action): +def check_statement(statement: Dict[str, Any], resource: str, action: str) -> bool: return action_matches(statement.get("Action", []), action) and resource_matches( statement.get("Resource", ""), resource ) -def action_matches(applicable_actions, action): +def action_matches(applicable_actions: List[str], action: str) -> bool: alternatives = ALTERNATIVE_ACTIONS[action] if any(alt in applicable_actions for alt in alternatives): return True return False -def resource_matches(applicable_resources, resource): # pylint: disable=unused-argument +def resource_matches( + applicable_resources: str, resource: str # pylint: disable=unused-argument +) -> bool: if applicable_resources == "*": return True return False diff --git a/moto/kms/responses.py b/moto/kms/responses.py index b02c5cb27..e5a359d00 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -3,6 +3,7 @@ import json import os import re import warnings +from typing import Any, Dict from moto.core.responses import BaseResponse from moto.kms.utils import RESERVED_ALIASES, RESERVED_ALIASE_TARGET_KEY_IDS @@ -17,24 +18,23 @@ from .exceptions import ( class KmsResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="kms") - @property - def parameters(self): + def _get_param(self, param_name: str, if_none: Any = None) -> Any: # type: ignore params = json.loads(self.body) for key in ("Plaintext", "CiphertextBlob"): if key in params: params[key] = base64.b64decode(params[key].encode("utf-8")) - return params + return params.get(param_name, if_none) @property def kms_backend(self) -> KmsBackend: return kms_backends[self.current_account][self.region] - def _display_arn(self, key_id): + def _display_arn(self, key_id: str) -> str: if key_id.startswith("arn:"): return key_id @@ -45,7 +45,7 @@ class KmsResponse(BaseResponse): return f"arn:aws:kms:{self.region}:{self.current_account}:{id_type}{key_id}" - def _validate_cmk_id(self, key_id): + def _validate_cmk_id(self, key_id: str) -> None: """Determine whether a CMK ID exists. - raw key ID @@ -69,7 +69,7 @@ class KmsResponse(BaseResponse): if cmk_id not in self.kms_backend.keys: raise NotFoundException(f"Key '{self._display_arn(key_id)}' does not exist") - def _validate_alias(self, key_id): + def _validate_alias(self, key_id: str) -> None: """Determine whether an alias exists. - alias name @@ -88,8 +88,8 @@ class KmsResponse(BaseResponse): if cmk_id is None: raise error - def _validate_key_id(self, key_id): - """Determine whether or not a key ID exists. + def _validate_key_id(self, key_id: str) -> None: + """Determine whether a key ID exists. - raw key ID - key ARN @@ -105,77 +105,77 @@ class KmsResponse(BaseResponse): self._validate_cmk_id(key_id) - def _validate_key_policy(self, key_id, action): + def _validate_key_policy(self, key_id: str, action: str) -> None: """ Validate whether the specified action is allowed, given the key policy """ key = self.kms_backend.describe_key(self.kms_backend.get_key_id(key_id)) validate_policy(key, action) - def create_key(self): + def create_key(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html""" - policy = self.parameters.get("Policy") - key_usage = self.parameters.get("KeyUsage") - key_spec = self.parameters.get("KeySpec") or self.parameters.get( + policy = self._get_param("Policy") + key_usage = self._get_param("KeyUsage") + key_spec = self._get_param("KeySpec") or self._get_param( "CustomerMasterKeySpec" ) - description = self.parameters.get("Description") - tags = self.parameters.get("Tags") - multi_region = self.parameters.get("MultiRegion") + description = self._get_param("Description") + tags = self._get_param("Tags") + multi_region = self._get_param("MultiRegion") key = self.kms_backend.create_key( policy, key_usage, key_spec, description, tags, multi_region ) return json.dumps(key.to_dict()) - def replicate_key(self): - key_id = self.parameters.get("KeyId") + def replicate_key(self) -> None: + key_id = self._get_param("KeyId") self._validate_key_id(key_id) - replica_region = self.parameters.get("ReplicaRegion") + replica_region = self._get_param("ReplicaRegion") self.kms_backend.replicate_key(key_id, replica_region) - def update_key_description(self): + def update_key_description(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html""" - key_id = self.parameters.get("KeyId") - description = self.parameters.get("Description") + key_id = self._get_param("KeyId") + description = self._get_param("Description") self._validate_cmk_id(key_id) self.kms_backend.update_key_description(key_id, description) return json.dumps(None) - def tag_resource(self): + def tag_resource(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html""" - key_id = self.parameters.get("KeyId") - tags = self.parameters.get("Tags") + key_id = self._get_param("KeyId") + tags = self._get_param("Tags") self._validate_cmk_id(key_id) - result = self.kms_backend.tag_resource(key_id, tags) - return json.dumps(result) + self.kms_backend.tag_resource(key_id, tags) + return "{}" - def untag_resource(self): + def untag_resource(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_UntagResource.html""" - key_id = self.parameters.get("KeyId") - tag_names = self.parameters.get("TagKeys") + key_id = self._get_param("KeyId") + tag_names = self._get_param("TagKeys") self._validate_cmk_id(key_id) - result = self.kms_backend.untag_resource(key_id, tag_names) - return json.dumps(result) + self.kms_backend.untag_resource(key_id, tag_names) + return "{}" - def list_resource_tags(self): + def list_resource_tags(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") self._validate_cmk_id(key_id) - tags = self.kms_backend.list_resource_tags(key_id) + tags: Dict[str, Any] = self.kms_backend.list_resource_tags(key_id) tags.update({"NextMarker": None, "Truncated": False}) return json.dumps(tags) - def describe_key(self): + def describe_key(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") self._validate_key_id(key_id) self._validate_key_policy(key_id, "kms:DescribeKey") @@ -184,7 +184,7 @@ class KmsResponse(BaseResponse): return json.dumps(key.to_dict()) - def list_keys(self): + def list_keys(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html""" keys = self.kms_backend.list_keys() @@ -196,17 +196,17 @@ class KmsResponse(BaseResponse): } ) - def create_alias(self): + def create_alias(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html""" return self._set_alias() - def update_alias(self): + def update_alias(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateAlias.html""" return self._set_alias(update=True) - def _set_alias(self, update=False): - alias_name = self.parameters["AliasName"] - target_key_id = self.parameters["TargetKeyId"] + def _set_alias(self, update: bool = False) -> str: + alias_name = self._get_param("AliasName") + target_key_id = self._get_param("TargetKeyId") if not alias_name.startswith("alias/"): raise ValidationException("Invalid identifier") @@ -243,9 +243,9 @@ class KmsResponse(BaseResponse): return json.dumps(None) - def delete_alias(self): + def delete_alias(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html""" - alias_name = self.parameters["AliasName"] + alias_name = self._get_param("AliasName") if not alias_name.startswith("alias/"): raise ValidationException("Invalid identifier") @@ -256,10 +256,10 @@ class KmsResponse(BaseResponse): return json.dumps(None) - def list_aliases(self): + def list_aliases(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html""" region = self.region - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") if key_id is not None: self._validate_key_id(key_id) key_id = self.kms_backend.get_key_id(key_id) @@ -298,13 +298,13 @@ class KmsResponse(BaseResponse): return json.dumps({"Truncated": False, "Aliases": response_aliases}) - def create_grant(self): - key_id = self.parameters.get("KeyId") - grantee_principal = self.parameters.get("GranteePrincipal") - retiring_principal = self.parameters.get("RetiringPrincipal") - operations = self.parameters.get("Operations") - name = self.parameters.get("Name") - constraints = self.parameters.get("Constraints") + def create_grant(self) -> str: + key_id = self._get_param("KeyId") + grantee_principal = self._get_param("GranteePrincipal") + retiring_principal = self._get_param("RetiringPrincipal") + operations = self._get_param("Operations") + name = self._get_param("Name") + constraints = self._get_param("Constraints") grant_id, grant_token = self.kms_backend.create_grant( key_id, @@ -316,9 +316,9 @@ class KmsResponse(BaseResponse): ) return json.dumps({"GrantId": grant_id, "GrantToken": grant_token}) - def list_grants(self): - key_id = self.parameters.get("KeyId") - grant_id = self.parameters.get("GrantId") + def list_grants(self) -> str: + key_id = self._get_param("KeyId") + grant_id = self._get_param("GrantId") grants = self.kms_backend.list_grants(key_id=key_id, grant_id=grant_id) return json.dumps( @@ -329,8 +329,8 @@ class KmsResponse(BaseResponse): } ) - def list_retirable_grants(self): - retiring_principal = self.parameters.get("RetiringPrincipal") + def list_retirable_grants(self) -> str: + retiring_principal = self._get_param("RetiringPrincipal") grants = self.kms_backend.list_retirable_grants(retiring_principal) return json.dumps( @@ -341,24 +341,24 @@ class KmsResponse(BaseResponse): } ) - def revoke_grant(self): - key_id = self.parameters.get("KeyId") - grant_id = self.parameters.get("GrantId") + def revoke_grant(self) -> str: + key_id = self._get_param("KeyId") + grant_id = self._get_param("GrantId") self.kms_backend.revoke_grant(key_id, grant_id) return "{}" - def retire_grant(self): - key_id = self.parameters.get("KeyId") - grant_id = self.parameters.get("GrantId") - grant_token = self.parameters.get("GrantToken") + def retire_grant(self) -> str: + key_id = self._get_param("KeyId") + grant_id = self._get_param("GrantId") + grant_token = self._get_param("GrantToken") self.kms_backend.retire_grant(key_id, grant_id, grant_token) return "{}" - def enable_key_rotation(self): + def enable_key_rotation(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") self._validate_cmk_id(key_id) @@ -366,9 +366,9 @@ class KmsResponse(BaseResponse): return json.dumps(None) - def disable_key_rotation(self): + def disable_key_rotation(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") self._validate_cmk_id(key_id) @@ -376,9 +376,9 @@ class KmsResponse(BaseResponse): return json.dumps(None) - def get_key_rotation_status(self): + def get_key_rotation_status(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html""" - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") self._validate_cmk_id(key_id) @@ -386,11 +386,11 @@ class KmsResponse(BaseResponse): return json.dumps({"KeyRotationEnabled": rotation_enabled}) - def put_key_policy(self): + def put_key_policy(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html""" - key_id = self.parameters.get("KeyId") - policy_name = self.parameters.get("PolicyName") - policy = self.parameters.get("Policy") + key_id = self._get_param("KeyId") + policy_name = self._get_param("PolicyName") + policy = self._get_param("Policy") _assert_default_policy(policy_name) self._validate_cmk_id(key_id) @@ -399,10 +399,10 @@ class KmsResponse(BaseResponse): return json.dumps(None) - def get_key_policy(self): + def get_key_policy(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html""" - key_id = self.parameters.get("KeyId") - policy_name = self.parameters.get("PolicyName") + key_id = self._get_param("KeyId") + policy_name = self._get_param("PolicyName") _assert_default_policy(policy_name) self._validate_cmk_id(key_id) @@ -410,9 +410,9 @@ class KmsResponse(BaseResponse): policy = self.kms_backend.get_key_policy(key_id) or "{}" return json.dumps({"Policy": policy}) - def list_key_policies(self): + def list_key_policies(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html""" - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") self._validate_cmk_id(key_id) @@ -420,11 +420,11 @@ class KmsResponse(BaseResponse): return json.dumps({"Truncated": False, "PolicyNames": ["default"]}) - def encrypt(self): + def encrypt(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html""" - key_id = self.parameters.get("KeyId") - encryption_context = self.parameters.get("EncryptionContext", {}) - plaintext = self.parameters.get("Plaintext") + key_id = self._get_param("KeyId") + encryption_context = self._get_param("EncryptionContext", {}) + plaintext = self._get_param("Plaintext") self._validate_key_id(key_id) @@ -438,10 +438,10 @@ class KmsResponse(BaseResponse): return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn}) - def decrypt(self): + def decrypt(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html""" - ciphertext_blob = self.parameters.get("CiphertextBlob") - encryption_context = self.parameters.get("EncryptionContext", {}) + ciphertext_blob = self._get_param("CiphertextBlob") + encryption_context = self._get_param("EncryptionContext", {}) plaintext, arn = self.kms_backend.decrypt( ciphertext_blob=ciphertext_blob, encryption_context=encryption_context @@ -451,12 +451,12 @@ class KmsResponse(BaseResponse): return json.dumps({"Plaintext": plaintext_response, "KeyId": arn}) - def re_encrypt(self): + def re_encrypt(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html""" - ciphertext_blob = self.parameters.get("CiphertextBlob") - source_encryption_context = self.parameters.get("SourceEncryptionContext", {}) - destination_key_id = self.parameters.get("DestinationKeyId") - destination_encryption_context = self.parameters.get( + ciphertext_blob = self._get_param("CiphertextBlob") + source_encryption_context = self._get_param("SourceEncryptionContext", {}) + destination_key_id = self._get_param("DestinationKeyId") + destination_encryption_context = self._get_param( "DestinationEncryptionContext", {} ) @@ -483,9 +483,9 @@ class KmsResponse(BaseResponse): } ) - def disable_key(self): + def disable_key(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html""" - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") self._validate_cmk_id(key_id) @@ -493,9 +493,9 @@ class KmsResponse(BaseResponse): return json.dumps(None) - def enable_key(self): + def enable_key(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html""" - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") self._validate_cmk_id(key_id) @@ -503,9 +503,9 @@ class KmsResponse(BaseResponse): return json.dumps(None) - def cancel_key_deletion(self): + def cancel_key_deletion(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html""" - key_id = self.parameters.get("KeyId") + key_id = self._get_param("KeyId") self._validate_cmk_id(key_id) @@ -513,13 +513,13 @@ class KmsResponse(BaseResponse): return json.dumps({"KeyId": key_id}) - def schedule_key_deletion(self): + def schedule_key_deletion(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html""" - key_id = self.parameters.get("KeyId") - if self.parameters.get("PendingWindowInDays") is None: + key_id = self._get_param("KeyId") + if self._get_param("PendingWindowInDays") is None: pending_window_in_days = 30 else: - pending_window_in_days = self.parameters.get("PendingWindowInDays") + pending_window_in_days = self._get_param("PendingWindowInDays") self._validate_cmk_id(key_id) @@ -532,12 +532,12 @@ class KmsResponse(BaseResponse): } ) - def generate_data_key(self): + def generate_data_key(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html""" - key_id = self.parameters.get("KeyId") - encryption_context = self.parameters.get("EncryptionContext", {}) - number_of_bytes = self.parameters.get("NumberOfBytes") - key_spec = self.parameters.get("KeySpec") + key_id = self._get_param("KeyId") + encryption_context = self._get_param("EncryptionContext", {}) + number_of_bytes = self._get_param("NumberOfBytes") + key_spec = self._get_param("KeySpec") # Param validation self._validate_key_id(key_id) @@ -587,16 +587,16 @@ class KmsResponse(BaseResponse): } ) - def generate_data_key_without_plaintext(self): + def generate_data_key_without_plaintext(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html""" result = json.loads(self.generate_data_key()) del result["Plaintext"] return json.dumps(result) - def generate_random(self): + def generate_random(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html""" - number_of_bytes = self.parameters.get("NumberOfBytes") + number_of_bytes = self._get_param("NumberOfBytes") if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): raise ValidationException( @@ -613,13 +613,13 @@ class KmsResponse(BaseResponse): return json.dumps({"Plaintext": response_entropy}) - def sign(self): + def sign(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_Sign.html""" - key_id = self.parameters.get("KeyId") - message = self.parameters.get("Message") - message_type = self.parameters.get("MessageType") - grant_tokens = self.parameters.get("GrantTokens") - signing_algorithm = self.parameters.get("SigningAlgorithm") + key_id = self._get_param("KeyId") + message = self._get_param("Message") + message_type = self._get_param("MessageType") + grant_tokens = self._get_param("GrantTokens") + signing_algorithm = self._get_param("SigningAlgorithm") self._validate_key_id(key_id) @@ -660,14 +660,14 @@ class KmsResponse(BaseResponse): } ) - def verify(self): + def verify(self) -> str: """https://docs.aws.amazon.com/kms/latest/APIReference/API_Verify.html""" - key_id = self.parameters.get("KeyId") - message = self.parameters.get("Message") - message_type = self.parameters.get("MessageType") - signature = self.parameters.get("Signature") - signing_algorithm = self.parameters.get("SigningAlgorithm") - grant_tokens = self.parameters.get("GrantTokens") + key_id = self._get_param("KeyId") + message = self._get_param("Message") + message_type = self._get_param("MessageType") + signature = self._get_param("Signature") + signing_algorithm = self._get_param("SigningAlgorithm") + grant_tokens = self._get_param("GrantTokens") self._validate_key_id(key_id) @@ -722,6 +722,6 @@ class KmsResponse(BaseResponse): ) -def _assert_default_policy(policy_name): +def _assert_default_policy(policy_name: str) -> None: if policy_name != "default": raise NotFoundException("No such policy exists") diff --git a/moto/kms/utils.py b/moto/kms/utils.py index 15da5d9e9..2639f5ca4 100644 --- a/moto/kms/utils.py +++ b/moto/kms/utils.py @@ -1,4 +1,5 @@ from collections import namedtuple +from typing import Any, Dict, Tuple import io import os import struct @@ -47,7 +48,7 @@ RESERVED_ALIASE_TARGET_KEY_IDS = { RESERVED_ALIASES = list(RESERVED_ALIASE_TARGET_KEY_IDS.keys()) -def generate_key_id(multi_region=False): +def generate_key_id(multi_region: bool = False) -> str: key = str(mock_random.uuid4()) # https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html # "Notice that multi-Region keys have a distinctive key ID that begins with mrk-. You can use the mrk- prefix to @@ -58,17 +59,17 @@ def generate_key_id(multi_region=False): return key -def generate_data_key(number_of_bytes): +def generate_data_key(number_of_bytes: int) -> bytes: """Generate a data key.""" return os.urandom(number_of_bytes) -def generate_master_key(): +def generate_master_key() -> bytes: """Generate a master key.""" return generate_data_key(MASTER_KEY_LEN) -def generate_private_key(): +def generate_private_key() -> rsa.RSAPrivateKey: """Generate a private key to be used on asymmetric sign/verify. NOTE: KeySpec is not taken into consideration and the key is always RSA_2048 @@ -80,7 +81,7 @@ def generate_private_key(): ) -def _serialize_ciphertext_blob(ciphertext): +def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes: """Serialize Ciphertext object into a ciphertext blob. NOTE: This is just a simple binary format. It is not what KMS actually does. @@ -94,7 +95,7 @@ def _serialize_ciphertext_blob(ciphertext): return header + ciphertext.ciphertext -def _deserialize_ciphertext_blob(ciphertext_blob): +def _deserialize_ciphertext_blob(ciphertext_blob: bytes) -> Ciphertext: """Deserialize ciphertext blob into a Ciphertext object. NOTE: This is just a simple binary format. It is not what KMS actually does. @@ -107,7 +108,7 @@ def _deserialize_ciphertext_blob(ciphertext_blob): ) -def _serialize_encryption_context(encryption_context): +def _serialize_encryption_context(encryption_context: Dict[str, str]) -> bytes: """Serialize encryption context for use a AAD. NOTE: This is not necessarily what KMS does, but it retains the same properties. @@ -119,7 +120,12 @@ def _serialize_encryption_context(encryption_context): return aad.getvalue() -def encrypt(master_keys, key_id, plaintext, encryption_context): +def encrypt( + master_keys: Dict[str, Any], + key_id: str, + plaintext: bytes, + encryption_context: Dict[str, str], +) -> bytes: """Encrypt data using a master key material. NOTE: This is not necessarily what KMS does, but it retains the same properties. @@ -159,7 +165,11 @@ def encrypt(master_keys, key_id, plaintext, encryption_context): ) -def decrypt(master_keys, ciphertext_blob, encryption_context): +def decrypt( + master_keys: Dict[str, Any], + ciphertext_blob: bytes, + encryption_context: Dict[str, str], +) -> Tuple[bytes, str]: """Decrypt a ciphertext blob using a master key material. NOTE: This is not necessarily what KMS does, but it retains the same properties. diff --git a/moto/sts/utils.py b/moto/sts/utils.py index eda8ed851..afd095796 100644 --- a/moto/sts/utils.py +++ b/moto/sts/utils.py @@ -9,7 +9,7 @@ SESSION_TOKEN_PREFIX = "FQoGZXIvYXdzEBYaD" DEFAULT_STS_SESSION_DURATION = 3600 -def random_session_token(): +def random_session_token() -> str: return ( SESSION_TOKEN_PREFIX + base64.b64encode(os.urandom(266))[len(SESSION_TOKEN_PREFIX) :].decode() diff --git a/setup.cfg b/setup.cfg index 0b75c56c0..3adab73d7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -235,7 +235,7 @@ disable = W,C,R,E enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import [mypy] -files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/moto_api,moto/neptune +files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/moto_api,moto/neptune show_column_numbers=True show_error_codes = True disable_error_code=abstract