Techdebt: MyPy K (#6111)

This commit is contained in:
Bert Blommers 2023-03-23 09:17:40 -01:00 committed by GitHub
parent 64abff588f
commit c3460b8a1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 680 additions and 565 deletions

View File

@ -1,46 +1,47 @@
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from typing import Optional
class ResourceNotFoundError(JsonRESTError): class ResourceNotFoundError(JsonRESTError):
def __init__(self, message): def __init__(self, message: str):
super().__init__(error_type="ResourceNotFoundException", message=message) super().__init__(error_type="ResourceNotFoundException", message=message)
class ResourceInUseError(JsonRESTError): class ResourceInUseError(JsonRESTError):
def __init__(self, message): def __init__(self, message: str):
super().__init__(error_type="ResourceInUseException", message=message) super().__init__(error_type="ResourceInUseException", message=message)
class StreamNotFoundError(ResourceNotFoundError): class StreamNotFoundError(ResourceNotFoundError):
def __init__(self, stream_name, account_id): def __init__(self, stream_name: str, account_id: str):
super().__init__(f"Stream {stream_name} under account {account_id} not found.") super().__init__(f"Stream {stream_name} under account {account_id} not found.")
class StreamCannotBeUpdatedError(JsonRESTError): class StreamCannotBeUpdatedError(JsonRESTError):
def __init__(self, stream_name, account_id): def __init__(self, stream_name: str, account_id: str):
message = f"Request is invalid. Stream {stream_name} under account {account_id} is in On-Demand mode." message = f"Request is invalid. Stream {stream_name} under account {account_id} is in On-Demand mode."
super().__init__(error_type="ValidationException", message=message) super().__init__(error_type="ValidationException", message=message)
class ShardNotFoundError(ResourceNotFoundError): class ShardNotFoundError(ResourceNotFoundError):
def __init__(self, shard_id, stream, account_id): def __init__(self, shard_id: str, stream: str, account_id: str):
super().__init__( super().__init__(
f"Could not find shard {shard_id} in stream {stream} under account {account_id}." f"Could not find shard {shard_id} in stream {stream} under account {account_id}."
) )
class ConsumerNotFound(ResourceNotFoundError): class ConsumerNotFound(ResourceNotFoundError):
def __init__(self, consumer, account_id): def __init__(self, consumer: str, account_id: str):
super().__init__(f"Consumer {consumer}, account {account_id} not found.") super().__init__(f"Consumer {consumer}, account {account_id} not found.")
class InvalidArgumentError(JsonRESTError): class InvalidArgumentError(JsonRESTError):
def __init__(self, message): def __init__(self, message: str):
super().__init__(error_type="InvalidArgumentException", message=message) super().__init__(error_type="InvalidArgumentException", message=message)
class InvalidRetentionPeriod(InvalidArgumentError): class InvalidRetentionPeriod(InvalidArgumentError):
def __init__(self, hours, too_short): def __init__(self, hours: int, too_short: bool):
if too_short: if too_short:
msg = f"Minimum allowed retention period is 24 hours. Requested retention period ({hours} hours) is too short." msg = f"Minimum allowed retention period is 24 hours. Requested retention period ({hours} hours) is too short."
else: else:
@ -49,31 +50,31 @@ class InvalidRetentionPeriod(InvalidArgumentError):
class InvalidDecreaseRetention(InvalidArgumentError): class InvalidDecreaseRetention(InvalidArgumentError):
def __init__(self, name, requested, existing): def __init__(self, name: Optional[str], requested: int, existing: int):
msg = f"Requested retention period ({requested} hours) for stream {name} can not be longer than existing retention period ({existing} hours). Use IncreaseRetentionPeriod API." msg = f"Requested retention period ({requested} hours) for stream {name} can not be longer than existing retention period ({existing} hours). Use IncreaseRetentionPeriod API."
super().__init__(msg) super().__init__(msg)
class InvalidIncreaseRetention(InvalidArgumentError): class InvalidIncreaseRetention(InvalidArgumentError):
def __init__(self, name, requested, existing): def __init__(self, name: Optional[str], requested: int, existing: int):
msg = f"Requested retention period ({requested} hours) for stream {name} can not be shorter than existing retention period ({existing} hours). Use DecreaseRetentionPeriod API." msg = f"Requested retention period ({requested} hours) for stream {name} can not be shorter than existing retention period ({existing} hours). Use DecreaseRetentionPeriod API."
super().__init__(msg) super().__init__(msg)
class ValidationException(JsonRESTError): class ValidationException(JsonRESTError):
def __init__(self, value, position, regex_to_match): def __init__(self, value: str, position: str, regex_to_match: str):
msg = f"1 validation error detected: Value '{value}' at '{position}' failed to satisfy constraint: Member must satisfy regular expression pattern: {regex_to_match}" msg = f"1 validation error detected: Value '{value}' at '{position}' failed to satisfy constraint: Member must satisfy regular expression pattern: {regex_to_match}"
super().__init__(error_type="ValidationException", message=msg) super().__init__(error_type="ValidationException", message=msg)
class RecordSizeExceedsLimit(JsonRESTError): class RecordSizeExceedsLimit(JsonRESTError):
def __init__(self, position): def __init__(self, position: int):
msg = f"1 validation error detected: Value at 'records.{position}.member.data' failed to satisfy constraint: Member must have length less than or equal to 1048576" msg = f"1 validation error detected: Value at 'records.{position}.member.data' failed to satisfy constraint: Member must have length less than or equal to 1048576"
super().__init__(error_type="ValidationException", message=msg) super().__init__(error_type="ValidationException", message=msg)
class TotalRecordsSizeExceedsLimit(JsonRESTError): class TotalRecordsSizeExceedsLimit(JsonRESTError):
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
error_type="InvalidArgumentException", error_type="InvalidArgumentException",
message="Records size exceeds 5 MB limit", message="Records size exceeds 5 MB limit",
@ -81,6 +82,6 @@ class TotalRecordsSizeExceedsLimit(JsonRESTError):
class TooManyRecords(JsonRESTError): class TooManyRecords(JsonRESTError):
def __init__(self): def __init__(self) -> None:
msg = "1 validation error detected: Value at 'records' failed to satisfy constraint: Member must have length less than or equal to 500" msg = "1 validation error detected: Value at 'records' failed to satisfy constraint: Member must have length less than or equal to 500"
super().__init__(error_type="ValidationException", message=msg) super().__init__(error_type="ValidationException", message=msg)

View File

@ -4,7 +4,7 @@ import re
import itertools import itertools
from operator import attrgetter from operator import attrgetter
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Iterable
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
@ -35,14 +35,16 @@ from .utils import (
class Consumer(BaseModel): class Consumer(BaseModel):
def __init__(self, consumer_name, account_id, region_name, stream_arn): def __init__(
self, consumer_name: str, account_id: str, region_name: str, stream_arn: str
):
self.consumer_name = consumer_name self.consumer_name = consumer_name
self.created = unix_time() self.created = unix_time()
self.stream_arn = stream_arn self.stream_arn = stream_arn
stream_name = stream_arn.split("/")[-1] stream_name = stream_arn.split("/")[-1]
self.consumer_arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}/consumer/{consumer_name}" self.consumer_arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}/consumer/{consumer_name}"
def to_json(self, include_stream_arn=False): def to_json(self, include_stream_arn: bool = False) -> Dict[str, Any]:
resp = { resp = {
"ConsumerName": self.consumer_name, "ConsumerName": self.consumer_name,
"ConsumerARN": self.consumer_arn, "ConsumerARN": self.consumer_arn,
@ -55,7 +57,13 @@ class Consumer(BaseModel):
class Record(BaseModel): class Record(BaseModel):
def __init__(self, partition_key, data, sequence_number, explicit_hash_key): def __init__(
self,
partition_key: str,
data: str,
sequence_number: int,
explicit_hash_key: str,
):
self.partition_key = partition_key self.partition_key = partition_key
self.data = data self.data = data
self.sequence_number = sequence_number self.sequence_number = sequence_number
@ -63,7 +71,7 @@ class Record(BaseModel):
self.created_at_datetime = datetime.datetime.utcnow() self.created_at_datetime = datetime.datetime.utcnow()
self.created_at = unix_time(self.created_at_datetime) self.created_at = unix_time(self.created_at_datetime)
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"Data": self.data, "Data": self.data,
"PartitionKey": self.partition_key, "PartitionKey": self.partition_key,
@ -74,29 +82,36 @@ class Record(BaseModel):
class Shard(BaseModel): class Shard(BaseModel):
def __init__( def __init__(
self, shard_id, starting_hash, ending_hash, parent=None, adjacent_parent=None self,
shard_id: int,
starting_hash: int,
ending_hash: int,
parent: Optional[str] = None,
adjacent_parent: Optional[str] = None,
): ):
self._shard_id = shard_id self._shard_id = shard_id
self.starting_hash = starting_hash self.starting_hash = starting_hash
self.ending_hash = ending_hash self.ending_hash = ending_hash
self.records = OrderedDict() self.records: Dict[int, Record] = OrderedDict()
self.is_open = True self.is_open = True
self.parent = parent self.parent = parent
self.adjacent_parent = adjacent_parent self.adjacent_parent = adjacent_parent
@property @property
def shard_id(self): def shard_id(self) -> str:
return f"shardId-{str(self._shard_id).zfill(12)}" return f"shardId-{str(self._shard_id).zfill(12)}"
def get_records(self, last_sequence_id, limit): def get_records(
last_sequence_id = int(last_sequence_id) self, last_sequence_id: str, limit: Optional[int]
) -> Tuple[List[Record], int, int]:
last_sequence_int = int(last_sequence_id)
results = [] results = []
secs_behind_latest = 0 secs_behind_latest = 0.0
for sequence_number, record in self.records.items(): for sequence_number, record in self.records.items():
if sequence_number > last_sequence_id: if sequence_number > last_sequence_int:
results.append(record) results.append(record)
last_sequence_id = sequence_number last_sequence_int = sequence_number
very_last_record = self.records[next(reversed(self.records))] very_last_record = self.records[next(reversed(self.records))]
secs_behind_latest = very_last_record.created_at - record.created_at secs_behind_latest = very_last_record.created_at - record.created_at
@ -105,9 +120,9 @@ class Shard(BaseModel):
break break
millis_behind_latest = int(secs_behind_latest * 1000) millis_behind_latest = int(secs_behind_latest * 1000)
return results, last_sequence_id, millis_behind_latest return results, last_sequence_int, millis_behind_latest
def put_record(self, partition_key, data, explicit_hash_key): def put_record(self, partition_key: str, data: str, explicit_hash_key: str) -> str:
# Note: this function is not safe for concurrency # Note: this function is not safe for concurrency
if self.records: if self.records:
last_sequence_number = self.get_max_sequence_number() last_sequence_number = self.get_max_sequence_number()
@ -119,17 +134,17 @@ class Shard(BaseModel):
) )
return str(sequence_number) return str(sequence_number)
def get_min_sequence_number(self): def get_min_sequence_number(self) -> int:
if self.records: if self.records:
return list(self.records.keys())[0] return list(self.records.keys())[0]
return 0 return 0
def get_max_sequence_number(self): def get_max_sequence_number(self) -> int:
if self.records: if self.records:
return list(self.records.keys())[-1] return list(self.records.keys())[-1]
return 0 return 0
def get_sequence_number_at(self, at_timestamp): def get_sequence_number_at(self, at_timestamp: float) -> int:
if not self.records or at_timestamp < list(self.records.values())[0].created_at: if not self.records or at_timestamp < list(self.records.values())[0].created_at:
return 0 return 0
else: else:
@ -143,10 +158,10 @@ class Shard(BaseModel):
), ),
None, None,
) )
return r.sequence_number return r.sequence_number # type: ignore
def to_json(self): def to_json(self) -> Dict[str, Any]:
response = { response: Dict[str, Any] = {
"HashKeyRange": { "HashKeyRange": {
"EndingHashKey": str(self.ending_hash), "EndingHashKey": str(self.ending_hash),
"StartingHashKey": str(self.starting_hash), "StartingHashKey": str(self.starting_hash),
@ -170,12 +185,12 @@ class Shard(BaseModel):
class Stream(CloudFormationModel): class Stream(CloudFormationModel):
def __init__( def __init__(
self, self,
stream_name, stream_name: str,
shard_count, shard_count: int,
stream_mode, stream_mode: Optional[Dict[str, str]],
retention_period_hours, retention_period_hours: Optional[int],
account_id, account_id: str,
region_name, region_name: str,
): ):
self.stream_name = stream_name self.stream_name = stream_name
self.creation_datetime = datetime.datetime.now().strftime( self.creation_datetime = datetime.datetime.now().strftime(
@ -184,27 +199,27 @@ class Stream(CloudFormationModel):
self.region = region_name self.region = region_name
self.account_id = account_id self.account_id = account_id
self.arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}" self.arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}"
self.shards = {} self.shards: Dict[str, Shard] = {}
self.tags = {} self.tags: Dict[str, str] = {}
self.status = "ACTIVE" self.status = "ACTIVE"
self.shard_count = None self.shard_count: Optional[int] = None
self.stream_mode = stream_mode or {"StreamMode": "PROVISIONED"} self.stream_mode = stream_mode or {"StreamMode": "PROVISIONED"}
if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": if self.stream_mode.get("StreamMode", "") == "ON_DEMAND":
shard_count = 4 shard_count = 4
self.init_shards(shard_count) self.init_shards(shard_count)
self.retention_period_hours = retention_period_hours or 24 self.retention_period_hours = retention_period_hours or 24
self.shard_level_metrics = [] self.shard_level_metrics: List[str] = []
self.encryption_type = "NONE" self.encryption_type = "NONE"
self.key_id = None self.key_id: Optional[str] = None
self.consumers = [] self.consumers: List[Consumer] = []
def delete_consumer(self, consumer_arn): def delete_consumer(self, consumer_arn: str) -> None:
self.consumers = [c for c in self.consumers if c.consumer_arn != consumer_arn] self.consumers = [c for c in self.consumers if c.consumer_arn != consumer_arn]
def get_consumer_by_arn(self, consumer_arn): def get_consumer_by_arn(self, consumer_arn: str) -> Optional[Consumer]:
return next((c for c in self.consumers if c.consumer_arn == consumer_arn), None) return next((c for c in self.consumers if c.consumer_arn == consumer_arn), None)
def init_shards(self, shard_count): def init_shards(self, shard_count: int) -> None:
self.shard_count = shard_count self.shard_count = shard_count
step = 2**128 // shard_count step = 2**128 // shard_count
@ -216,16 +231,16 @@ class Stream(CloudFormationModel):
shard = Shard(index, start, end) shard = Shard(index, start, end)
self.shards[shard.shard_id] = shard self.shards[shard.shard_id] = shard
def split_shard(self, shard_to_split, new_starting_hash_key): def split_shard(self, shard_to_split: str, new_starting_hash_key: str) -> None:
new_starting_hash_key = int(new_starting_hash_key) new_starting_hash_int = int(new_starting_hash_key)
shard = self.shards[shard_to_split] shard = self.shards[shard_to_split]
if shard.starting_hash < new_starting_hash_key < shard.ending_hash: if shard.starting_hash < new_starting_hash_int < shard.ending_hash:
pass pass
else: else:
raise InvalidArgumentError( raise InvalidArgumentError(
message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}." message=f"NewStartingHashKey {new_starting_hash_int} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}."
) )
if not shard.is_open: if not shard.is_open:
@ -241,12 +256,12 @@ class Stream(CloudFormationModel):
new_shard_1 = Shard( new_shard_1 = Shard(
last_id + 1, last_id + 1,
starting_hash=shard.starting_hash, starting_hash=shard.starting_hash,
ending_hash=new_starting_hash_key - 1, ending_hash=new_starting_hash_int - 1,
parent=shard.shard_id, parent=shard.shard_id,
) )
new_shard_2 = Shard( new_shard_2 = Shard(
last_id + 2, last_id + 2,
starting_hash=new_starting_hash_key, starting_hash=new_starting_hash_int,
ending_hash=shard.ending_hash, ending_hash=shard.ending_hash,
parent=shard.shard_id, parent=shard.shard_id,
) )
@ -261,7 +276,7 @@ class Stream(CloudFormationModel):
record = records[index] record = records[index]
self.put_record(record.partition_key, record.explicit_hash_key, record.data) self.put_record(record.partition_key, record.explicit_hash_key, record.data)
def merge_shards(self, shard_to_merge, adjacent_shard_to_merge): def merge_shards(self, shard_to_merge: str, adjacent_shard_to_merge: str) -> None:
shard1 = self.shards[shard_to_merge] shard1 = self.shards[shard_to_merge]
shard2 = self.shards[adjacent_shard_to_merge] shard2 = self.shards[adjacent_shard_to_merge]
@ -300,8 +315,8 @@ class Stream(CloudFormationModel):
record.partition_key, record.data, record.explicit_hash_key record.partition_key, record.data, record.explicit_hash_key
) )
def update_shard_count(self, target_shard_count): def update_shard_count(self, target_shard_count: int) -> None:
if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": # type: ignore
raise StreamCannotBeUpdatedError( raise StreamCannotBeUpdatedError(
stream_name=self.stream_name, account_id=self.account_id stream_name=self.stream_name, account_id=self.account_id
) )
@ -351,13 +366,15 @@ class Stream(CloudFormationModel):
self.shard_count = target_shard_count self.shard_count = target_shard_count
def get_shard(self, shard_id): def get_shard(self, shard_id: str) -> Shard:
if shard_id in self.shards: if shard_id in self.shards:
return self.shards[shard_id] return self.shards[shard_id]
else: else:
raise ShardNotFoundError(shard_id, stream="", account_id=self.account_id) raise ShardNotFoundError(shard_id, stream="", account_id=self.account_id)
def get_shard_for_key(self, partition_key, explicit_hash_key): def get_shard_for_key(
self, partition_key: str, explicit_hash_key: str
) -> Optional[Shard]:
if not isinstance(partition_key, str): if not isinstance(partition_key, str):
raise InvalidArgumentError("partition_key") raise InvalidArgumentError("partition_key")
if len(partition_key) > 256: if len(partition_key) > 256:
@ -367,25 +384,28 @@ class Stream(CloudFormationModel):
if not isinstance(explicit_hash_key, str): if not isinstance(explicit_hash_key, str):
raise InvalidArgumentError("explicit_hash_key") raise InvalidArgumentError("explicit_hash_key")
key = int(explicit_hash_key) int_key = int(explicit_hash_key)
if key >= 2**128: if int_key >= 2**128:
raise InvalidArgumentError("explicit_hash_key") raise InvalidArgumentError("explicit_hash_key")
else: else:
key = int(md5_hash(partition_key.encode("utf-8")).hexdigest(), 16) int_key = int(md5_hash(partition_key.encode("utf-8")).hexdigest(), 16)
for shard in self.shards.values(): for shard in self.shards.values():
if shard.starting_hash <= key < shard.ending_hash: if shard.starting_hash <= int_key < shard.ending_hash:
return shard return shard
return None
def put_record(self, partition_key, explicit_hash_key, data): def put_record(
self, partition_key: str, explicit_hash_key: str, data: str
) -> Tuple[str, str]:
shard = self.get_shard_for_key(partition_key, explicit_hash_key) shard = self.get_shard_for_key(partition_key, explicit_hash_key)
sequence_number = shard.put_record(partition_key, data, explicit_hash_key) sequence_number = shard.put_record(partition_key, data, explicit_hash_key) # type: ignore
return sequence_number, shard.shard_id return sequence_number, shard.shard_id # type: ignore
def to_json(self, shard_limit=None): def to_json(self, shard_limit: Optional[int] = None) -> Dict[str, Any]:
all_shards = list(self.shards.values()) all_shards = list(self.shards.values())
requested_shards = all_shards[0 : shard_limit or len(all_shards)] requested_shards = all_shards[0 : shard_limit or len(all_shards)]
return { return {
@ -403,7 +423,7 @@ class Stream(CloudFormationModel):
} }
} }
def to_json_summary(self): def to_json_summary(self) -> Dict[str, Any]:
return { return {
"StreamDescriptionSummary": { "StreamDescriptionSummary": {
"StreamARN": self.arn, "StreamARN": self.arn,
@ -420,18 +440,23 @@ class Stream(CloudFormationModel):
} }
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return "Name" return "Name"
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kinesis-stream.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kinesis-stream.html
return "AWS::Kinesis::Stream" return "AWS::Kinesis::Stream"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Stream":
properties = cloudformation_json.get("Properties", {}) properties = cloudformation_json.get("Properties", {})
shard_count = properties.get("ShardCount", 1) shard_count = properties.get("ShardCount", 1)
retention_period_hours = properties.get("RetentionPeriodHours", resource_name) retention_period_hours = properties.get("RetentionPeriodHours", resource_name)
@ -451,14 +476,14 @@ class Stream(CloudFormationModel):
return stream return stream
@classmethod @classmethod
def update_from_cloudformation_json( def update_from_cloudformation_json( # type: ignore[misc]
cls, cls,
original_resource, original_resource: Any,
new_resource_name, new_resource_name: str,
cloudformation_json, cloudformation_json: Any,
account_id, account_id: str,
region_name, region_name: str,
): ) -> "Stream":
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
if Stream.is_replacement_update(properties): if Stream.is_replacement_update(properties):
@ -466,11 +491,17 @@ class Stream(CloudFormationModel):
if resource_name_property not in properties: if resource_name_property not in properties:
properties[resource_name_property] = new_resource_name properties[resource_name_property] = new_resource_name
new_resource = cls.create_from_cloudformation_json( new_resource = cls.create_from_cloudformation_json(
properties[resource_name_property], cloudformation_json, region_name resource_name=properties[resource_name_property],
cloudformation_json=cloudformation_json,
account_id=account_id,
region_name=region_name,
) )
properties[resource_name_property] = original_resource.name properties[resource_name_property] = original_resource.name
cls.delete_from_cloudformation_json( cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name resource_name=original_resource.name,
cloudformation_json=cloudformation_json,
account_id=account_id,
region_name=region_name,
) )
return new_resource return new_resource
@ -489,14 +520,18 @@ class Stream(CloudFormationModel):
return original_resource return original_resource
@classmethod @classmethod
def delete_from_cloudformation_json( def delete_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> None:
backend: KinesisBackend = kinesis_backends[account_id][region_name] backend: KinesisBackend = kinesis_backends[account_id][region_name]
backend.delete_stream(stream_arn=None, stream_name=resource_name) backend.delete_stream(stream_arn=None, stream_name=resource_name)
@staticmethod @staticmethod
def is_replacement_update(properties): def is_replacement_update(properties: List[str]) -> bool:
properties_requiring_replacement_update = ["BucketName", "ObjectLockEnabled"] properties_requiring_replacement_update = ["BucketName", "ObjectLockEnabled"]
return any( return any(
[ [
@ -506,10 +541,10 @@ class Stream(CloudFormationModel):
) )
@classmethod @classmethod
def has_cfn_attr(cls, attr): def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["Arn"] return attr in ["Arn"]
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn": if attribute_name == "Arn":
@ -517,25 +552,31 @@ class Stream(CloudFormationModel):
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.stream_name return self.stream_name
class KinesisBackend(BaseBackend): class KinesisBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.streams: Dict[str, Stream] = OrderedDict() self.streams: Dict[str, Stream] = OrderedDict()
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "kinesis", special_service_name="kinesis-streams" service_region, zones, "kinesis", special_service_name="kinesis-streams"
) )
def create_stream( def create_stream(
self, stream_name, shard_count, stream_mode=None, retention_period_hours=None self,
): stream_name: str,
shard_count: int,
stream_mode: Optional[Dict[str, str]] = None,
retention_period_hours: Optional[int] = None,
) -> Stream:
if stream_name in self.streams: if stream_name in self.streams:
raise ResourceInUseError(stream_name) raise ResourceInUseError(stream_name)
stream = Stream( stream = Stream(
@ -560,14 +601,14 @@ class KinesisBackend(BaseBackend):
return stream return stream
if stream_arn: if stream_arn:
stream_name = stream_arn.split("/")[1] stream_name = stream_arn.split("/")[1]
raise StreamNotFoundError(stream_name, self.account_id) raise StreamNotFoundError(stream_name, self.account_id) # type: ignore
def describe_stream_summary( def describe_stream_summary(
self, stream_arn: Optional[str], stream_name: Optional[str] self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream: ) -> Stream:
return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
def list_streams(self): def list_streams(self) -> Iterable[Stream]:
return self.streams.values() return self.streams.values()
def delete_stream( def delete_stream(
@ -582,9 +623,9 @@ class KinesisBackend(BaseBackend):
stream_name: Optional[str], stream_name: Optional[str],
shard_id: str, shard_id: str,
shard_iterator_type: str, shard_iterator_type: str,
starting_sequence_number: str, starting_sequence_number: int,
at_timestamp: str, at_timestamp: datetime.datetime,
): ) -> str:
# Validate params # Validate params
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
try: try:
@ -605,31 +646,31 @@ class KinesisBackend(BaseBackend):
def get_records( def get_records(
self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int] self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int]
): ) -> Tuple[str, List[Record], int]:
decomposed = decompose_shard_iterator(shard_iterator) decomposed = decompose_shard_iterator(shard_iterator)
stream_name, shard_id, last_sequence_id = decomposed stream_name, shard_id, last_sequence_id = decomposed
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shard = stream.get_shard(shard_id) shard = stream.get_shard(shard_id)
records, last_sequence_id, millis_behind_latest = shard.get_records( records, last_sequence_id, millis_behind_latest = shard.get_records( # type: ignore
last_sequence_id, limit last_sequence_id, limit
) )
next_shard_iterator = compose_shard_iterator( next_shard_iterator = compose_shard_iterator(
stream_name, shard, last_sequence_id stream_name, shard, last_sequence_id # type: ignore
) )
return next_shard_iterator, records, millis_behind_latest return next_shard_iterator, records, millis_behind_latest
def put_record( def put_record(
self, self,
stream_arn, stream_arn: str,
stream_name, stream_name: str,
partition_key, partition_key: str,
explicit_hash_key, explicit_hash_key: str,
data, data: str,
): ) -> Tuple[str, str]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
sequence_number, shard_id = stream.put_record( sequence_number, shard_id = stream.put_record(
@ -638,10 +679,12 @@ class KinesisBackend(BaseBackend):
return sequence_number, shard_id return sequence_number, shard_id
def put_records(self, stream_arn, stream_name, records): def put_records(
self, stream_arn: str, stream_name: str, records: List[Dict[str, Any]]
) -> Dict[str, Any]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
response = {"FailedRecordCount": 0, "Records": []} response: Dict[str, Any] = {"FailedRecordCount": 0, "Records": []}
if len(records) > 500: if len(records) > 500:
raise TooManyRecords raise TooManyRecords
@ -660,7 +703,7 @@ class KinesisBackend(BaseBackend):
data = record.get("Data") data = record.get("Data")
sequence_number, shard_id = stream.put_record( sequence_number, shard_id = stream.put_record(
partition_key, explicit_hash_key, data partition_key, explicit_hash_key, data # type: ignore[arg-type]
) )
response["Records"].append( response["Records"].append(
{"SequenceNumber": sequence_number, "ShardId": shard_id} {"SequenceNumber": sequence_number, "ShardId": shard_id}
@ -669,8 +712,12 @@ class KinesisBackend(BaseBackend):
return response return response
def split_shard( def split_shard(
self, stream_arn, stream_name, shard_to_split, new_starting_hash_key self,
): stream_arn: str,
stream_name: str,
shard_to_split: str,
new_starting_hash_key: str,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if not re.match("[a-zA-Z0-9_.-]+", shard_to_split): if not re.match("[a-zA-Z0-9_.-]+", shard_to_split):
@ -695,8 +742,12 @@ class KinesisBackend(BaseBackend):
stream.split_shard(shard_to_split, new_starting_hash_key) stream.split_shard(shard_to_split, new_starting_hash_key)
def merge_shards( def merge_shards(
self, stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge self,
): stream_arn: str,
stream_name: str,
shard_to_merge: str,
adjacent_shard_to_merge: str,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if shard_to_merge not in stream.shards: if shard_to_merge not in stream.shards:
@ -713,7 +764,9 @@ class KinesisBackend(BaseBackend):
stream.merge_shards(shard_to_merge, adjacent_shard_to_merge) stream.merge_shards(shard_to_merge, adjacent_shard_to_merge)
def update_shard_count(self, stream_arn, stream_name, target_shard_count): def update_shard_count(
self, stream_arn: str, stream_name: str, target_shard_count: int
) -> int:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_count = len([s for s in stream.shards.values() if s.is_open]) current_shard_count = len([s for s in stream.shards.values() if s.is_open])
@ -722,7 +775,7 @@ class KinesisBackend(BaseBackend):
return current_shard_count return current_shard_count
@paginate(pagination_model=PAGINATION_MODEL) @paginate(pagination_model=PAGINATION_MODEL)
def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]): def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]) -> List[Dict[str, Any]]: # type: ignore
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shards = sorted(stream.shards.values(), key=lambda x: x.shard_id) shards = sorted(stream.shards.values(), key=lambda x: x.shard_id)
return [shard.to_json() for shard in shards] return [shard.to_json() for shard in shards]
@ -766,12 +819,16 @@ class KinesisBackend(BaseBackend):
stream.retention_period_hours = retention_period_hours stream.retention_period_hours = retention_period_hours
def list_tags_for_stream( def list_tags_for_stream(
self, stream_arn, stream_name, exclusive_start_tag_key=None, limit=None self,
): stream_arn: str,
stream_name: str,
exclusive_start_tag_key: Optional[str] = None,
limit: Optional[int] = None,
) -> Dict[str, Any]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
tags = [] tags: List[Dict[str, str]] = []
result = {"HasMoreTags": False, "Tags": tags} result: Dict[str, Any] = {"HasMoreTags": False, "Tags": tags}
for key, val in sorted(stream.tags.items(), key=lambda x: x[0]): for key, val in sorted(stream.tags.items(), key=lambda x: x[0]):
if limit and len(tags) >= limit: if limit and len(tags) >= limit:
result["HasMoreTags"] = True result["HasMoreTags"] = True
@ -805,7 +862,7 @@ class KinesisBackend(BaseBackend):
stream_arn: Optional[str], stream_arn: Optional[str],
stream_name: Optional[str], stream_name: Optional[str],
shard_level_metrics: List[str], shard_level_metrics: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]: ) -> Tuple[str, str, List[str], List[str]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_level_metrics = stream.shard_level_metrics current_shard_level_metrics = stream.shard_level_metrics
desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics)) desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics))
@ -822,7 +879,7 @@ class KinesisBackend(BaseBackend):
stream_arn: Optional[str], stream_arn: Optional[str],
stream_name: Optional[str], stream_name: Optional[str],
to_be_disabled: List[str], to_be_disabled: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]: ) -> Tuple[str, str, List[str], List[str]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_metrics = stream.shard_level_metrics current_metrics = stream.shard_level_metrics
if "ALL" in to_be_disabled: if "ALL" in to_be_disabled:
@ -834,19 +891,19 @@ class KinesisBackend(BaseBackend):
stream.shard_level_metrics = desired_metrics stream.shard_level_metrics = desired_metrics
return stream.arn, stream.stream_name, current_metrics, desired_metrics return stream.arn, stream.stream_name, current_metrics, desired_metrics
def _find_stream_by_arn(self, stream_arn): def _find_stream_by_arn(self, stream_arn: str) -> Stream: # type: ignore[return]
for stream in self.streams.values(): for stream in self.streams.values():
if stream.arn == stream_arn: if stream.arn == stream_arn:
return stream return stream
def list_stream_consumers(self, stream_arn): def list_stream_consumers(self, stream_arn: str) -> List[Consumer]:
""" """
Pagination is not yet implemented Pagination is not yet implemented
""" """
stream = self._find_stream_by_arn(stream_arn) stream = self._find_stream_by_arn(stream_arn)
return stream.consumers return stream.consumers
def register_stream_consumer(self, stream_arn, consumer_name): def register_stream_consumer(self, stream_arn: str, consumer_name: str) -> Consumer:
consumer = Consumer( consumer = Consumer(
consumer_name, self.account_id, self.region_name, stream_arn consumer_name, self.account_id, self.region_name, stream_arn
) )
@ -854,7 +911,9 @@ class KinesisBackend(BaseBackend):
stream.consumers.append(consumer) stream.consumers.append(consumer)
return consumer return consumer
def describe_stream_consumer(self, stream_arn, consumer_name, consumer_arn): def describe_stream_consumer(
self, stream_arn: str, consumer_name: str, consumer_arn: str
) -> Consumer:
if stream_arn: if stream_arn:
stream = self._find_stream_by_arn(stream_arn) stream = self._find_stream_by_arn(stream_arn)
for consumer in stream.consumers: for consumer in stream.consumers:
@ -862,14 +921,16 @@ class KinesisBackend(BaseBackend):
return consumer return consumer
if consumer_arn: if consumer_arn:
for stream in self.streams.values(): for stream in self.streams.values():
consumer = stream.get_consumer_by_arn(consumer_arn) _consumer = stream.get_consumer_by_arn(consumer_arn)
if consumer: if _consumer:
return consumer return _consumer
raise ConsumerNotFound( raise ConsumerNotFound(
consumer=consumer_name or consumer_arn, account_id=self.account_id consumer=consumer_name or consumer_arn, account_id=self.account_id
) )
def deregister_stream_consumer(self, stream_arn, consumer_name, consumer_arn): def deregister_stream_consumer(
self, stream_arn: str, consumer_name: str, consumer_arn: str
) -> None:
if stream_arn: if stream_arn:
stream = self._find_stream_by_arn(stream_arn) stream = self._find_stream_by_arn(stream_arn)
stream.consumers = [ stream.consumers = [
@ -881,17 +942,19 @@ class KinesisBackend(BaseBackend):
# It will be a noop for other streams # It will be a noop for other streams
stream.delete_consumer(consumer_arn) stream.delete_consumer(consumer_arn)
def start_stream_encryption(self, stream_arn, stream_name, encryption_type, key_id): def start_stream_encryption(
self, stream_arn: str, stream_name: str, encryption_type: str, key_id: str
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = encryption_type stream.encryption_type = encryption_type
stream.key_id = key_id stream.key_id = key_id
def stop_stream_encryption(self, stream_arn, stream_name): def stop_stream_encryption(self, stream_arn: str, stream_name: str) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name) stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = "NONE" stream.encryption_type = "NONE"
stream.key_id = None stream.key_id = None
def update_stream_mode(self, stream_arn, stream_mode): def update_stream_mode(self, stream_arn: str, stream_mode: Dict[str, str]) -> None:
stream = self._find_stream_by_arn(stream_arn) stream = self._find_stream_by_arn(stream_arn)
stream.stream_mode = stream_mode stream.stream_mode = stream_mode

View File

@ -5,45 +5,41 @@ from .models import kinesis_backends, KinesisBackend
class KinesisResponse(BaseResponse): class KinesisResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="kinesis") super().__init__(service_name="kinesis")
@property
def parameters(self):
return json.loads(self.body)
@property @property
def kinesis_backend(self) -> KinesisBackend: def kinesis_backend(self) -> KinesisBackend:
return kinesis_backends[self.current_account][self.region] return kinesis_backends[self.current_account][self.region]
def create_stream(self): def create_stream(self) -> str:
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
shard_count = self.parameters.get("ShardCount") shard_count = self._get_param("ShardCount")
stream_mode = self.parameters.get("StreamModeDetails") stream_mode = self._get_param("StreamModeDetails")
self.kinesis_backend.create_stream( self.kinesis_backend.create_stream(
stream_name, shard_count, stream_mode=stream_mode stream_name, shard_count, stream_mode=stream_mode
) )
return "" return ""
def describe_stream(self): def describe_stream(self) -> str:
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
limit = self.parameters.get("Limit") limit = self._get_param("Limit")
stream = self.kinesis_backend.describe_stream(stream_arn, stream_name) stream = self.kinesis_backend.describe_stream(stream_arn, stream_name)
return json.dumps(stream.to_json(shard_limit=limit)) return json.dumps(stream.to_json(shard_limit=limit))
def describe_stream_summary(self): def describe_stream_summary(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
stream = self.kinesis_backend.describe_stream_summary(stream_arn, stream_name) stream = self.kinesis_backend.describe_stream_summary(stream_arn, stream_name)
return json.dumps(stream.to_json_summary()) return json.dumps(stream.to_json_summary())
def list_streams(self): def list_streams(self) -> str:
streams = self.kinesis_backend.list_streams() streams = self.kinesis_backend.list_streams()
stream_names = [stream.stream_name for stream in streams] stream_names = [stream.stream_name for stream in streams]
max_streams = self._get_param("Limit", 10) max_streams = self._get_param("Limit", 10)
try: try:
token = self.parameters.get("ExclusiveStartStreamName") token = self._get_param("ExclusiveStartStreamName")
except ValueError: except ValueError:
token = self._get_param("ExclusiveStartStreamName") token = self._get_param("ExclusiveStartStreamName")
if token: if token:
@ -59,19 +55,19 @@ class KinesisResponse(BaseResponse):
{"HasMoreStreams": has_more_streams, "StreamNames": streams_resp} {"HasMoreStreams": has_more_streams, "StreamNames": streams_resp}
) )
def delete_stream(self): def delete_stream(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
self.kinesis_backend.delete_stream(stream_arn, stream_name) self.kinesis_backend.delete_stream(stream_arn, stream_name)
return "" return ""
def get_shard_iterator(self): def get_shard_iterator(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
shard_id = self.parameters.get("ShardId") shard_id = self._get_param("ShardId")
shard_iterator_type = self.parameters.get("ShardIteratorType") shard_iterator_type = self._get_param("ShardIteratorType")
starting_sequence_number = self.parameters.get("StartingSequenceNumber") starting_sequence_number = self._get_param("StartingSequenceNumber")
at_timestamp = self.parameters.get("Timestamp") at_timestamp = self._get_param("Timestamp")
shard_iterator = self.kinesis_backend.get_shard_iterator( shard_iterator = self.kinesis_backend.get_shard_iterator(
stream_arn, stream_arn,
@ -84,10 +80,10 @@ class KinesisResponse(BaseResponse):
return json.dumps({"ShardIterator": shard_iterator}) return json.dumps({"ShardIterator": shard_iterator})
def get_records(self): def get_records(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
shard_iterator = self.parameters.get("ShardIterator") shard_iterator = self._get_param("ShardIterator")
limit = self.parameters.get("Limit") limit = self._get_param("Limit")
( (
next_shard_iterator, next_shard_iterator,
@ -103,12 +99,12 @@ class KinesisResponse(BaseResponse):
} }
) )
def put_record(self): def put_record(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
partition_key = self.parameters.get("PartitionKey") partition_key = self._get_param("PartitionKey")
explicit_hash_key = self.parameters.get("ExplicitHashKey") explicit_hash_key = self._get_param("ExplicitHashKey")
data = self.parameters.get("Data") data = self._get_param("Data")
sequence_number, shard_id = self.kinesis_backend.put_record( sequence_number, shard_id = self.kinesis_backend.put_record(
stream_arn, stream_arn,
@ -120,40 +116,40 @@ class KinesisResponse(BaseResponse):
return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id}) return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id})
def put_records(self): def put_records(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
records = self.parameters.get("Records") records = self._get_param("Records")
response = self.kinesis_backend.put_records(stream_arn, stream_name, records) response = self.kinesis_backend.put_records(stream_arn, stream_name, records)
return json.dumps(response) return json.dumps(response)
def split_shard(self): def split_shard(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
shard_to_split = self.parameters.get("ShardToSplit") shard_to_split = self._get_param("ShardToSplit")
new_starting_hash_key = self.parameters.get("NewStartingHashKey") new_starting_hash_key = self._get_param("NewStartingHashKey")
self.kinesis_backend.split_shard( self.kinesis_backend.split_shard(
stream_arn, stream_name, shard_to_split, new_starting_hash_key stream_arn, stream_name, shard_to_split, new_starting_hash_key
) )
return "" return ""
def merge_shards(self): def merge_shards(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
shard_to_merge = self.parameters.get("ShardToMerge") shard_to_merge = self._get_param("ShardToMerge")
adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge") adjacent_shard_to_merge = self._get_param("AdjacentShardToMerge")
self.kinesis_backend.merge_shards( self.kinesis_backend.merge_shards(
stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
) )
return "" return ""
def list_shards(self): def list_shards(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
next_token = self.parameters.get("NextToken") next_token = self._get_param("NextToken")
max_results = self.parameters.get("MaxResults", 10000) max_results = self._get_param("MaxResults", 10000)
shards, token = self.kinesis_backend.list_shards( shards, token = self.kinesis_backend.list_shards(
stream_arn=stream_arn, stream_arn=stream_arn,
stream_name=stream_name, stream_name=stream_name,
@ -165,10 +161,10 @@ class KinesisResponse(BaseResponse):
res["NextToken"] = token res["NextToken"] = token
return json.dumps(res) return json.dumps(res)
def update_shard_count(self): def update_shard_count(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
target_shard_count = self.parameters.get("TargetShardCount") target_shard_count = self._get_param("TargetShardCount")
current_shard_count = self.kinesis_backend.update_shard_count( current_shard_count = self.kinesis_backend.update_shard_count(
stream_arn=stream_arn, stream_arn=stream_arn,
stream_name=stream_name, stream_name=stream_name,
@ -182,52 +178,52 @@ class KinesisResponse(BaseResponse):
) )
) )
def increase_stream_retention_period(self): def increase_stream_retention_period(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
retention_period_hours = self.parameters.get("RetentionPeriodHours") retention_period_hours = self._get_param("RetentionPeriodHours")
self.kinesis_backend.increase_stream_retention_period( self.kinesis_backend.increase_stream_retention_period(
stream_arn, stream_name, retention_period_hours stream_arn, stream_name, retention_period_hours
) )
return "" return ""
def decrease_stream_retention_period(self): def decrease_stream_retention_period(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
retention_period_hours = self.parameters.get("RetentionPeriodHours") retention_period_hours = self._get_param("RetentionPeriodHours")
self.kinesis_backend.decrease_stream_retention_period( self.kinesis_backend.decrease_stream_retention_period(
stream_arn, stream_name, retention_period_hours stream_arn, stream_name, retention_period_hours
) )
return "" return ""
def add_tags_to_stream(self): def add_tags_to_stream(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
tags = self.parameters.get("Tags") tags = self._get_param("Tags")
self.kinesis_backend.add_tags_to_stream(stream_arn, stream_name, tags) self.kinesis_backend.add_tags_to_stream(stream_arn, stream_name, tags)
return json.dumps({}) return json.dumps({})
def list_tags_for_stream(self): def list_tags_for_stream(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey") exclusive_start_tag_key = self._get_param("ExclusiveStartTagKey")
limit = self.parameters.get("Limit") limit = self._get_param("Limit")
response = self.kinesis_backend.list_tags_for_stream( response = self.kinesis_backend.list_tags_for_stream(
stream_arn, stream_name, exclusive_start_tag_key, limit stream_arn, stream_name, exclusive_start_tag_key, limit
) )
return json.dumps(response) return json.dumps(response)
def remove_tags_from_stream(self): def remove_tags_from_stream(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
tag_keys = self.parameters.get("TagKeys") tag_keys = self._get_param("TagKeys")
self.kinesis_backend.remove_tags_from_stream(stream_arn, stream_name, tag_keys) self.kinesis_backend.remove_tags_from_stream(stream_arn, stream_name, tag_keys)
return json.dumps({}) return json.dumps({})
def enable_enhanced_monitoring(self): def enable_enhanced_monitoring(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics") shard_level_metrics = self._get_param("ShardLevelMetrics")
arn, name, current, desired = self.kinesis_backend.enable_enhanced_monitoring( arn, name, current, desired = self.kinesis_backend.enable_enhanced_monitoring(
stream_arn=stream_arn, stream_arn=stream_arn,
stream_name=stream_name, stream_name=stream_name,
@ -242,10 +238,10 @@ class KinesisResponse(BaseResponse):
) )
) )
def disable_enhanced_monitoring(self): def disable_enhanced_monitoring(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics") shard_level_metrics = self._get_param("ShardLevelMetrics")
arn, name, current, desired = self.kinesis_backend.disable_enhanced_monitoring( arn, name, current, desired = self.kinesis_backend.disable_enhanced_monitoring(
stream_arn=stream_arn, stream_arn=stream_arn,
stream_name=stream_name, stream_name=stream_name,
@ -260,23 +256,23 @@ class KinesisResponse(BaseResponse):
) )
) )
def list_stream_consumers(self): def list_stream_consumers(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
consumers = self.kinesis_backend.list_stream_consumers(stream_arn=stream_arn) consumers = self.kinesis_backend.list_stream_consumers(stream_arn=stream_arn)
return json.dumps(dict(Consumers=[c.to_json() for c in consumers])) return json.dumps(dict(Consumers=[c.to_json() for c in consumers]))
def register_stream_consumer(self): def register_stream_consumer(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
consumer_name = self.parameters.get("ConsumerName") consumer_name = self._get_param("ConsumerName")
consumer = self.kinesis_backend.register_stream_consumer( consumer = self.kinesis_backend.register_stream_consumer(
stream_arn=stream_arn, consumer_name=consumer_name stream_arn=stream_arn, consumer_name=consumer_name
) )
return json.dumps(dict(Consumer=consumer.to_json())) return json.dumps(dict(Consumer=consumer.to_json()))
def describe_stream_consumer(self): def describe_stream_consumer(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
consumer_name = self.parameters.get("ConsumerName") consumer_name = self._get_param("ConsumerName")
consumer_arn = self.parameters.get("ConsumerARN") consumer_arn = self._get_param("ConsumerARN")
consumer = self.kinesis_backend.describe_stream_consumer( consumer = self.kinesis_backend.describe_stream_consumer(
stream_arn=stream_arn, stream_arn=stream_arn,
consumer_name=consumer_name, consumer_name=consumer_name,
@ -286,10 +282,10 @@ class KinesisResponse(BaseResponse):
dict(ConsumerDescription=consumer.to_json(include_stream_arn=True)) dict(ConsumerDescription=consumer.to_json(include_stream_arn=True))
) )
def deregister_stream_consumer(self): def deregister_stream_consumer(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
consumer_name = self.parameters.get("ConsumerName") consumer_name = self._get_param("ConsumerName")
consumer_arn = self.parameters.get("ConsumerARN") consumer_arn = self._get_param("ConsumerARN")
self.kinesis_backend.deregister_stream_consumer( self.kinesis_backend.deregister_stream_consumer(
stream_arn=stream_arn, stream_arn=stream_arn,
consumer_name=consumer_name, consumer_name=consumer_name,
@ -297,11 +293,11 @@ class KinesisResponse(BaseResponse):
) )
return json.dumps(dict()) return json.dumps(dict())
def start_stream_encryption(self): def start_stream_encryption(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
encryption_type = self.parameters.get("EncryptionType") encryption_type = self._get_param("EncryptionType")
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self.kinesis_backend.start_stream_encryption( self.kinesis_backend.start_stream_encryption(
stream_arn=stream_arn, stream_arn=stream_arn,
stream_name=stream_name, stream_name=stream_name,
@ -310,16 +306,16 @@ class KinesisResponse(BaseResponse):
) )
return json.dumps(dict()) return json.dumps(dict())
def stop_stream_encryption(self): def stop_stream_encryption(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_name = self.parameters.get("StreamName") stream_name = self._get_param("StreamName")
self.kinesis_backend.stop_stream_encryption( self.kinesis_backend.stop_stream_encryption(
stream_arn=stream_arn, stream_name=stream_name stream_arn=stream_arn, stream_name=stream_name
) )
return json.dumps(dict()) return json.dumps(dict())
def update_stream_mode(self): def update_stream_mode(self) -> str:
stream_arn = self.parameters.get("StreamARN") stream_arn = self._get_param("StreamARN")
stream_mode = self.parameters.get("StreamModeDetails") stream_mode = self._get_param("StreamModeDetails")
self.kinesis_backend.update_stream_mode(stream_arn, stream_mode) self.kinesis_backend.update_stream_mode(stream_arn, stream_mode)
return "{}" return "{}"

View File

@ -1,4 +1,6 @@
import base64 import base64
from datetime import datetime
from typing import Any, Optional, List
from .exceptions import InvalidArgumentError from .exceptions import InvalidArgumentError
@ -30,8 +32,12 @@ PAGINATION_MODEL = {
def compose_new_shard_iterator( def compose_new_shard_iterator(
stream_name, shard, shard_iterator_type, starting_sequence_number, at_timestamp stream_name: Optional[str],
): shard: Any,
shard_iterator_type: str,
starting_sequence_number: int,
at_timestamp: datetime,
) -> str:
if shard_iterator_type == "AT_SEQUENCE_NUMBER": if shard_iterator_type == "AT_SEQUENCE_NUMBER":
last_sequence_id = int(starting_sequence_number) - 1 last_sequence_id = int(starting_sequence_number) - 1
elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER": elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER":
@ -47,11 +53,13 @@ def compose_new_shard_iterator(
return compose_shard_iterator(stream_name, shard, last_sequence_id) return compose_shard_iterator(stream_name, shard, last_sequence_id)
def compose_shard_iterator(stream_name, shard, last_sequence_id): def compose_shard_iterator(
stream_name: Optional[str], shard: Any, last_sequence_id: int
) -> str:
return encode_method( return encode_method(
f"{stream_name}:{shard.shard_id}:{last_sequence_id}".encode("utf-8") f"{stream_name}:{shard.shard_id}:{last_sequence_id}".encode("utf-8")
).decode("utf-8") ).decode("utf-8")
def decompose_shard_iterator(shard_iterator): def decompose_shard_iterator(shard_iterator: str) -> List[str]:
return decode_method(shard_iterator.encode("utf-8")).decode("utf-8").split(":") return decode_method(shard_iterator.encode("utf-8")).decode("utf-8").split(":")

View File

@ -6,7 +6,7 @@ class KinesisvideoClientError(RESTError):
class ResourceNotFoundException(KinesisvideoClientError): class ResourceNotFoundException(KinesisvideoClientError):
def __init__(self): def __init__(self) -> None:
self.code = 404 self.code = 404
super().__init__( super().__init__(
"ResourceNotFoundException", "ResourceNotFoundException",
@ -15,6 +15,6 @@ class ResourceNotFoundException(KinesisvideoClientError):
class ResourceInUseException(KinesisvideoClientError): class ResourceInUseException(KinesisvideoClientError):
def __init__(self, message): def __init__(self, message: str):
self.code = 400 self.code = 400
super().__init__("ResourceInUseException", message) super().__init__("ResourceInUseException", message)

View File

@ -1,5 +1,6 @@
from moto.core import BaseBackend, BackendDict, BaseModel
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import ResourceNotFoundException, ResourceInUseException from .exceptions import ResourceNotFoundException, ResourceInUseException
from moto.moto_api._internal import mock_random as random from moto.moto_api._internal import mock_random as random
@ -7,14 +8,14 @@ from moto.moto_api._internal import mock_random as random
class Stream(BaseModel): class Stream(BaseModel):
def __init__( def __init__(
self, self,
account_id, account_id: str,
region_name, region_name: str,
device_name, device_name: str,
stream_name, stream_name: str,
media_type, media_type: str,
kms_key_id, kms_key_id: str,
data_retention_in_hours, data_retention_in_hours: int,
tags, tags: Dict[str, str],
): ):
self.region_name = region_name self.region_name = region_name
self.stream_name = stream_name self.stream_name = stream_name
@ -30,11 +31,11 @@ class Stream(BaseModel):
self.data_endpoint_number = random.get_random_hex() self.data_endpoint_number = random.get_random_hex()
self.arn = stream_arn self.arn = stream_arn
def get_data_endpoint(self, api_name): def get_data_endpoint(self, api_name: str) -> str:
data_endpoint_prefix = "s-" if api_name in ("PUT_MEDIA", "GET_MEDIA") else "b-" data_endpoint_prefix = "s-" if api_name in ("PUT_MEDIA", "GET_MEDIA") else "b-"
return f"https://{data_endpoint_prefix}{self.data_endpoint_number}.kinesisvideo.{self.region_name}.amazonaws.com" return f"https://{data_endpoint_prefix}{self.data_endpoint_number}.kinesisvideo.{self.region_name}.amazonaws.com"
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
return { return {
"DeviceName": self.device_name, "DeviceName": self.device_name,
"StreamName": self.stream_name, "StreamName": self.stream_name,
@ -49,19 +50,19 @@ class Stream(BaseModel):
class KinesisVideoBackend(BaseBackend): class KinesisVideoBackend(BaseBackend):
def __init__(self, region_name, account_id): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.streams = {} self.streams: Dict[str, Stream] = {}
def create_stream( def create_stream(
self, self,
device_name, device_name: str,
stream_name, stream_name: str,
media_type, media_type: str,
kms_key_id, kms_key_id: str,
data_retention_in_hours, data_retention_in_hours: int,
tags, tags: Dict[str, str],
): ) -> str:
streams = [_ for _ in self.streams.values() if _.stream_name == stream_name] streams = [_ for _ in self.streams.values() if _.stream_name == stream_name]
if len(streams) > 0: if len(streams) > 0:
raise ResourceInUseException(f"The stream {stream_name} already exists.") raise ResourceInUseException(f"The stream {stream_name} already exists.")
@ -78,7 +79,7 @@ class KinesisVideoBackend(BaseBackend):
self.streams[stream.arn] = stream self.streams[stream.arn] = stream
return stream.arn return stream.arn
def _get_stream(self, stream_name, stream_arn): def _get_stream(self, stream_name: str, stream_arn: str) -> Stream:
if stream_name: if stream_name:
streams = [_ for _ in self.streams.values() if _.stream_name == stream_name] streams = [_ for _ in self.streams.values() if _.stream_name == stream_name]
if len(streams) == 0: if len(streams) == 0:
@ -90,20 +91,17 @@ class KinesisVideoBackend(BaseBackend):
raise ResourceNotFoundException() raise ResourceNotFoundException()
return stream return stream
def describe_stream(self, stream_name, stream_arn): def describe_stream(self, stream_name: str, stream_arn: str) -> Dict[str, Any]:
stream = self._get_stream(stream_name, stream_arn) stream = self._get_stream(stream_name, stream_arn)
stream_info = stream.to_dict() return stream.to_dict()
return stream_info
def list_streams(self): def list_streams(self) -> List[Dict[str, Any]]:
""" """
Pagination and the StreamNameCondition-parameter are not yet implemented Pagination and the StreamNameCondition-parameter are not yet implemented
""" """
stream_info_list = [_.to_dict() for _ in self.streams.values()] return [_.to_dict() for _ in self.streams.values()]
next_token = None
return stream_info_list, next_token
def delete_stream(self, stream_arn): def delete_stream(self, stream_arn: str) -> None:
""" """
The CurrentVersion-parameter is not yet implemented The CurrentVersion-parameter is not yet implemented
""" """
@ -112,11 +110,11 @@ class KinesisVideoBackend(BaseBackend):
raise ResourceNotFoundException() raise ResourceNotFoundException()
del self.streams[stream_arn] del self.streams[stream_arn]
def get_data_endpoint(self, stream_name, stream_arn, api_name): def get_data_endpoint(
self, stream_name: str, stream_arn: str, api_name: str
) -> str:
stream = self._get_stream(stream_name, stream_arn) stream = self._get_stream(stream_name, stream_arn)
return stream.get_data_endpoint(api_name) return stream.get_data_endpoint(api_name)
# add methods from here
kinesisvideo_backends = BackendDict(KinesisVideoBackend, "kinesisvideo") kinesisvideo_backends = BackendDict(KinesisVideoBackend, "kinesisvideo")

View File

@ -1,17 +1,17 @@
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import kinesisvideo_backends from .models import kinesisvideo_backends, KinesisVideoBackend
import json import json
class KinesisVideoResponse(BaseResponse): class KinesisVideoResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="kinesisvideo") super().__init__(service_name="kinesisvideo")
@property @property
def kinesisvideo_backend(self): def kinesisvideo_backend(self) -> KinesisVideoBackend:
return kinesisvideo_backends[self.current_account][self.region] return kinesisvideo_backends[self.current_account][self.region]
def create_stream(self): def create_stream(self) -> str:
device_name = self._get_param("DeviceName") device_name = self._get_param("DeviceName")
stream_name = self._get_param("StreamName") stream_name = self._get_param("StreamName")
media_type = self._get_param("MediaType") media_type = self._get_param("MediaType")
@ -28,7 +28,7 @@ class KinesisVideoResponse(BaseResponse):
) )
return json.dumps(dict(StreamARN=stream_arn)) return json.dumps(dict(StreamARN=stream_arn))
def describe_stream(self): def describe_stream(self) -> str:
stream_name = self._get_param("StreamName") stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN") stream_arn = self._get_param("StreamARN")
stream_info = self.kinesisvideo_backend.describe_stream( stream_info = self.kinesisvideo_backend.describe_stream(
@ -36,16 +36,16 @@ class KinesisVideoResponse(BaseResponse):
) )
return json.dumps(dict(StreamInfo=stream_info)) return json.dumps(dict(StreamInfo=stream_info))
def list_streams(self): def list_streams(self) -> str:
stream_info_list, next_token = self.kinesisvideo_backend.list_streams() stream_info_list = self.kinesisvideo_backend.list_streams()
return json.dumps(dict(StreamInfoList=stream_info_list, NextToken=next_token)) return json.dumps(dict(StreamInfoList=stream_info_list, NextToken=None))
def delete_stream(self): def delete_stream(self) -> str:
stream_arn = self._get_param("StreamARN") stream_arn = self._get_param("StreamARN")
self.kinesisvideo_backend.delete_stream(stream_arn=stream_arn) self.kinesisvideo_backend.delete_stream(stream_arn=stream_arn)
return json.dumps(dict()) return json.dumps(dict())
def get_data_endpoint(self): def get_data_endpoint(self) -> str:
stream_name = self._get_param("StreamName") stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN") stream_arn = self._get_param("StreamARN")
api_name = self._get_param("APIName") api_name = self._get_param("APIName")

View File

@ -1,14 +1,17 @@
from typing import Tuple
from moto.core import BaseBackend, BackendDict from moto.core import BaseBackend, BackendDict
from moto.kinesisvideo import kinesisvideo_backends from moto.kinesisvideo.models import kinesisvideo_backends, KinesisVideoBackend
from moto.sts.utils import random_session_token from moto.sts.utils import random_session_token
class KinesisVideoArchivedMediaBackend(BaseBackend): class KinesisVideoArchivedMediaBackend(BaseBackend):
@property @property
def backend(self): def backend(self) -> KinesisVideoBackend:
return kinesisvideo_backends[self.account_id][self.region_name] return kinesisvideo_backends[self.account_id][self.region_name]
def _get_streaming_url(self, stream_name, stream_arn, api_name): def _get_streaming_url(
self, stream_name: str, stream_arn: str, api_name: str
) -> str:
stream = self.backend._get_stream(stream_name, stream_arn) stream = self.backend._get_stream(stream_name, stream_arn)
data_endpoint = stream.get_data_endpoint(api_name) data_endpoint = stream.get_data_endpoint(api_name)
session_token = random_session_token() session_token = random_session_token()
@ -19,19 +22,17 @@ class KinesisVideoArchivedMediaBackend(BaseBackend):
relative_path = api_to_relative_path[api_name] relative_path = api_to_relative_path[api_name]
return f"{data_endpoint}{relative_path}?SessionToken={session_token}" return f"{data_endpoint}{relative_path}?SessionToken={session_token}"
def get_hls_streaming_session_url(self, stream_name, stream_arn): def get_hls_streaming_session_url(self, stream_name: str, stream_arn: str) -> str:
# Ignore option paramters as the format of hls_url does't depends on them # Ignore option paramters as the format of hls_url doesn't depend on them
api_name = "GET_HLS_STREAMING_SESSION_URL" api_name = "GET_HLS_STREAMING_SESSION_URL"
url = self._get_streaming_url(stream_name, stream_arn, api_name) return self._get_streaming_url(stream_name, stream_arn, api_name)
return url
def get_dash_streaming_session_url(self, stream_name, stream_arn): def get_dash_streaming_session_url(self, stream_name: str, stream_arn: str) -> str:
# Ignore option paramters as the format of hls_url does't depends on them # Ignore option paramters as the format of hls_url doesn't depend on them
api_name = "GET_DASH_STREAMING_SESSION_URL" api_name = "GET_DASH_STREAMING_SESSION_URL"
url = self._get_streaming_url(stream_name, stream_arn, api_name) return self._get_streaming_url(stream_name, stream_arn, api_name)
return url
def get_clip(self, stream_name, stream_arn): def get_clip(self, stream_name: str, stream_arn: str) -> Tuple[str, bytes]:
self.backend._get_stream(stream_name, stream_arn) self.backend._get_stream(stream_name, stream_arn)
content_type = "video/mp4" # Fixed content_type as it depends on input stream content_type = "video/mp4" # Fixed content_type as it depends on input stream
payload = b"sample-mp4-video" payload = b"sample-mp4-video"

View File

@ -1,17 +1,18 @@
from typing import Dict, Tuple
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import kinesisvideoarchivedmedia_backends from .models import kinesisvideoarchivedmedia_backends, KinesisVideoArchivedMediaBackend
import json import json
class KinesisVideoArchivedMediaResponse(BaseResponse): class KinesisVideoArchivedMediaResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="kinesis-video-archived-media") super().__init__(service_name="kinesis-video-archived-media")
@property @property
def kinesisvideoarchivedmedia_backend(self): def kinesisvideoarchivedmedia_backend(self) -> KinesisVideoArchivedMediaBackend:
return kinesisvideoarchivedmedia_backends[self.current_account][self.region] return kinesisvideoarchivedmedia_backends[self.current_account][self.region]
def get_hls_streaming_session_url(self): def get_hls_streaming_session_url(self) -> str:
stream_name = self._get_param("StreamName") stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN") stream_arn = self._get_param("StreamARN")
hls_streaming_session_url = ( hls_streaming_session_url = (
@ -21,7 +22,7 @@ class KinesisVideoArchivedMediaResponse(BaseResponse):
) )
return json.dumps(dict(HLSStreamingSessionURL=hls_streaming_session_url)) return json.dumps(dict(HLSStreamingSessionURL=hls_streaming_session_url))
def get_dash_streaming_session_url(self): def get_dash_streaming_session_url(self) -> str:
stream_name = self._get_param("StreamName") stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN") stream_arn = self._get_param("StreamARN")
dash_streaming_session_url = ( dash_streaming_session_url = (
@ -31,7 +32,7 @@ class KinesisVideoArchivedMediaResponse(BaseResponse):
) )
return json.dumps(dict(DASHStreamingSessionURL=dash_streaming_session_url)) return json.dumps(dict(DASHStreamingSessionURL=dash_streaming_session_url))
def get_clip(self): def get_clip(self) -> Tuple[bytes, Dict[str, str]]:
stream_name = self._get_param("StreamName") stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN") stream_arn = self._get_param("StreamARN")
content_type, payload = self.kinesisvideoarchivedmedia_backend.get_clip( content_type, payload = self.kinesisvideoarchivedmedia_backend.get_clip(

View File

@ -4,29 +4,29 @@ from moto.core.exceptions import JsonRESTError
class NotFoundException(JsonRESTError): class NotFoundException(JsonRESTError):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("NotFoundException", message) super().__init__("NotFoundException", message)
class ValidationException(JsonRESTError): class ValidationException(JsonRESTError):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("ValidationException", message) super().__init__("ValidationException", message)
class AlreadyExistsException(JsonRESTError): class AlreadyExistsException(JsonRESTError):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("AlreadyExistsException", message) super().__init__("AlreadyExistsException", message)
class NotAuthorizedException(JsonRESTError): class NotAuthorizedException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__("NotAuthorizedException", None) super().__init__("NotAuthorizedException", "")
self.description = '{"__type":"NotAuthorizedException"}' self.description = '{"__type":"NotAuthorizedException"}'
@ -34,7 +34,7 @@ class NotAuthorizedException(JsonRESTError):
class AccessDeniedException(JsonRESTError): class AccessDeniedException(JsonRESTError):
code = 400 code = 400
def __init__(self, message): def __init__(self, message: str):
super().__init__("AccessDeniedException", message) super().__init__("AccessDeniedException", message)
self.description = '{"__type":"AccessDeniedException"}' self.description = '{"__type":"AccessDeniedException"}'
@ -43,7 +43,7 @@ class AccessDeniedException(JsonRESTError):
class InvalidCiphertextException(JsonRESTError): class InvalidCiphertextException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self) -> None:
super().__init__("InvalidCiphertextException", None) super().__init__("InvalidCiphertextException", "")
self.description = '{"__type":"InvalidCiphertextException"}' self.description = '{"__type":"InvalidCiphertextException"}'

View File

@ -5,8 +5,8 @@ from copy import copy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric import padding
from typing import Any, Dict, List, Tuple, Optional, Iterable, Set
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
@ -28,12 +28,12 @@ from .utils import (
class Grant(BaseModel): class Grant(BaseModel):
def __init__( def __init__(
self, self,
key_id, key_id: str,
name, name: str,
grantee_principal, grantee_principal: str,
operations, operations: List[str],
constraints, constraints: Dict[str, Any],
retiring_principal, retiring_principal: str,
): ):
self.key_id = key_id self.key_id = key_id
self.name = name self.name = name
@ -44,7 +44,7 @@ class Grant(BaseModel):
self.id = mock_random.get_random_hex() self.id = mock_random.get_random_hex()
self.token = mock_random.get_random_hex() self.token = mock_random.get_random_hex()
def to_json(self): def to_json(self) -> Dict[str, Any]:
return { return {
"KeyId": self.key_id, "KeyId": self.key_id,
"GrantId": self.id, "GrantId": self.id,
@ -59,13 +59,13 @@ class Grant(BaseModel):
class Key(CloudFormationModel): class Key(CloudFormationModel):
def __init__( def __init__(
self, self,
policy, policy: Optional[str],
key_usage, key_usage: str,
key_spec, key_spec: str,
description, description: str,
account_id, account_id: str,
region, region: str,
multi_region=False, multi_region: bool = False,
): ):
self.id = generate_key_id(multi_region) self.id = generate_key_id(multi_region)
self.creation_date = unix_time() self.creation_date = unix_time()
@ -78,7 +78,7 @@ class Key(CloudFormationModel):
self.region = region self.region = region
self.multi_region = multi_region self.multi_region = multi_region
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date: Optional[datetime] = None
self.key_material = generate_master_key() self.key_material = generate_master_key()
self.private_key = generate_private_key() self.private_key = generate_private_key()
self.origin = "AWS_KMS" self.origin = "AWS_KMS"
@ -86,10 +86,15 @@ class Key(CloudFormationModel):
self.key_spec = key_spec or "SYMMETRIC_DEFAULT" self.key_spec = key_spec or "SYMMETRIC_DEFAULT"
self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}" self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}"
self.grants = dict() self.grants: Dict[str, Grant] = dict()
def add_grant( def add_grant(
self, name, grantee_principal, operations, constraints, retiring_principal self,
name: str,
grantee_principal: str,
operations: List[str],
constraints: Dict[str, Any],
retiring_principal: str,
) -> Grant: ) -> Grant:
grant = Grant( grant = Grant(
self.id, self.id,
@ -102,32 +107,32 @@ class Key(CloudFormationModel):
self.grants[grant.id] = grant self.grants[grant.id] = grant
return grant return grant
def list_grants(self, grant_id) -> [Grant]: def list_grants(self, grant_id: str) -> List[Grant]:
grant_ids = [grant_id] if grant_id else self.grants.keys() grant_ids = [grant_id] if grant_id else self.grants.keys()
return [grant for _id, grant in self.grants.items() if _id in grant_ids] return [grant for _id, grant in self.grants.items() if _id in grant_ids]
def list_retirable_grants(self, retiring_principal) -> [Grant]: def list_retirable_grants(self, retiring_principal: str) -> List[Grant]:
return [ return [
grant grant
for grant in self.grants.values() for grant in self.grants.values()
if grant.retiring_principal == retiring_principal if grant.retiring_principal == retiring_principal
] ]
def revoke_grant(self, grant_id) -> None: def revoke_grant(self, grant_id: str) -> None:
if not self.grants.pop(grant_id, None): if not self.grants.pop(grant_id, None):
raise JsonRESTError("NotFoundException", f"Grant ID {grant_id} not found") raise JsonRESTError("NotFoundException", f"Grant ID {grant_id} not found")
def retire_grant(self, grant_id) -> None: def retire_grant(self, grant_id: str) -> None:
self.grants.pop(grant_id, None) self.grants.pop(grant_id, None)
def retire_grant_by_token(self, grant_token) -> None: def retire_grant_by_token(self, grant_token: str) -> None:
self.grants = { self.grants = {
_id: grant _id: grant
for _id, grant in self.grants.items() for _id, grant in self.grants.items()
if grant.token != grant_token if grant.token != grant_token
} }
def generate_default_policy(self): def generate_default_policy(self) -> str:
return json.dumps( return json.dumps(
{ {
"Version": "2012-10-17", "Version": "2012-10-17",
@ -145,11 +150,11 @@ class Key(CloudFormationModel):
) )
@property @property
def physical_resource_id(self): def physical_resource_id(self) -> str:
return self.id return self.id
@property @property
def encryption_algorithms(self): def encryption_algorithms(self) -> Optional[List[str]]:
if self.key_usage == "SIGN_VERIFY": if self.key_usage == "SIGN_VERIFY":
return None return None
elif self.key_spec == "SYMMETRIC_DEFAULT": elif self.key_spec == "SYMMETRIC_DEFAULT":
@ -158,9 +163,9 @@ class Key(CloudFormationModel):
return ["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"] return ["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"]
@property @property
def signing_algorithms(self): def signing_algorithms(self) -> List[str]:
if self.key_usage == "ENCRYPT_DECRYPT": if self.key_usage == "ENCRYPT_DECRYPT":
return None return None # type: ignore[return-value]
elif self.key_spec in ["ECC_NIST_P256", "ECC_SECG_P256K1"]: elif self.key_spec in ["ECC_NIST_P256", "ECC_SECG_P256K1"]:
return ["ECDSA_SHA_256"] return ["ECDSA_SHA_256"]
elif self.key_spec == "ECC_NIST_P384": elif self.key_spec == "ECC_NIST_P384":
@ -177,7 +182,7 @@ class Key(CloudFormationModel):
"RSASSA_PSS_SHA_512", "RSASSA_PSS_SHA_512",
] ]
def to_dict(self): def to_dict(self) -> Dict[str, Any]:
key_dict = { key_dict = {
"KeyMetadata": { "KeyMetadata": {
"AWSAccountId": self.account_id, "AWSAccountId": self.account_id,
@ -201,22 +206,27 @@ class Key(CloudFormationModel):
key_dict["KeyMetadata"]["DeletionDate"] = unix_time(self.deletion_date) key_dict["KeyMetadata"]["DeletionDate"] = unix_time(self.deletion_date)
return key_dict return key_dict
def delete(self, account_id, region_name): def delete(self, account_id: str, region_name: str) -> None:
kms_backends[account_id][region_name].delete_key(self.id) kms_backends[account_id][region_name].delete_key(self.id)
@staticmethod @staticmethod
def cloudformation_name_type(): def cloudformation_name_type() -> str:
return None return ""
@staticmethod @staticmethod
def cloudformation_type(): def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kms-key.html # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kms-key.html
return "AWS::KMS::Key" return "AWS::KMS::Key"
@classmethod @classmethod
def create_from_cloudformation_json( def create_from_cloudformation_json( # type: ignore[misc]
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs cls,
): resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Key":
kms_backend = kms_backends[account_id][region_name] kms_backend = kms_backends[account_id][region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
@ -233,10 +243,10 @@ class Key(CloudFormationModel):
return key return key
@classmethod @classmethod
def has_cfn_attr(cls, attr): def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["Arn"] return attr in ["Arn"]
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn": if attribute_name == "Arn":
@ -245,20 +255,22 @@ class Key(CloudFormationModel):
class KmsBackend(BaseBackend): class KmsBackend(BaseBackend):
def __init__(self, region_name, account_id=None): def __init__(self, region_name: str, account_id: Optional[str] = None):
super().__init__(region_name=region_name, account_id=account_id) super().__init__(region_name=region_name, account_id=account_id) # type: ignore
self.keys = {} self.keys: Dict[str, Key] = {}
self.key_to_aliases = defaultdict(set) self.key_to_aliases: Dict[str, Set[str]] = defaultdict(set)
self.tagger = TaggingService(key_name="TagKey", value_name="TagValue") self.tagger = TaggingService(key_name="TagKey", value_name="TagValue")
@staticmethod @staticmethod
def default_vpc_endpoint_service(service_region, zones): def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service.""" """Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory( return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "kms" service_region, zones, "kms"
) )
def _generate_default_keys(self, alias_name): def _generate_default_keys(self, alias_name: str) -> Optional[str]:
"""Creates default kms keys""" """Creates default kms keys"""
if alias_name in RESERVED_ALIASES: if alias_name in RESERVED_ALIASES:
key = self.create_key( key = self.create_key(
@ -270,10 +282,17 @@ class KmsBackend(BaseBackend):
) )
self.add_alias(key.id, alias_name) self.add_alias(key.id, alias_name)
return key.id return key.id
return None
def create_key( def create_key(
self, policy, key_usage, key_spec, description, tags, multi_region=False self,
): policy: Optional[str],
key_usage: str,
key_spec: str,
description: str,
tags: Optional[List[Dict[str, str]]],
multi_region: bool = False,
) -> Key:
""" """
The provided Policy currently does not need to be valid. If it is valid, Moto will perform authorization checks on key-related operations, just like AWS does. The provided Policy currently does not need to be valid. If it is valid, Moto will perform authorization checks on key-related operations, just like AWS does.
@ -303,7 +322,7 @@ class KmsBackend(BaseBackend):
# #
# In our implementation with just create a copy of all the properties once without any protection from change, # In our implementation with just create a copy of all the properties once without any protection from change,
# as the exact implementation is currently infeasible. # as the exact implementation is currently infeasible.
def replicate_key(self, key_id, replica_region): def replicate_key(self, key_id: str, replica_region: str) -> None:
# Using copy() instead of deepcopy(), as the latter results in exception: # Using copy() instead of deepcopy(), as the latter results in exception:
# TypeError: cannot pickle '_cffi_backend.FFI' object # TypeError: cannot pickle '_cffi_backend.FFI' object
# Since we only update top level properties, copy() should suffice. # Since we only update top level properties, copy() should suffice.
@ -312,31 +331,31 @@ class KmsBackend(BaseBackend):
to_region_backend = kms_backends[self.account_id][replica_region] to_region_backend = kms_backends[self.account_id][replica_region]
to_region_backend.keys[replica_key.id] = replica_key to_region_backend.keys[replica_key.id] = replica_key
def update_key_description(self, key_id, description): def update_key_description(self, key_id: str, description: str) -> None:
key = self.keys[self.get_key_id(key_id)] key = self.keys[self.get_key_id(key_id)]
key.description = description key.description = description
def delete_key(self, key_id): def delete_key(self, key_id: str) -> None:
if key_id in self.keys: if key_id in self.keys:
if key_id in self.key_to_aliases: if key_id in self.key_to_aliases:
self.key_to_aliases.pop(key_id) self.key_to_aliases.pop(key_id)
self.tagger.delete_all_tags_for_resource(key_id) self.tagger.delete_all_tags_for_resource(key_id)
return self.keys.pop(key_id) self.keys.pop(key_id)
def describe_key(self, key_id) -> Key: def describe_key(self, key_id: str) -> Key:
# allow the different methods (alias, ARN :key/, keyId, ARN alias) to # allow the different methods (alias, ARN :key/, keyId, ARN alias) to
# describe key not just KeyId # describe key not just KeyId
key_id = self.get_key_id(key_id) key_id = self.get_key_id(key_id)
if r"alias/" in str(key_id).lower(): if r"alias/" in str(key_id).lower():
key_id = self.get_key_id_from_alias(key_id) key_id = self.get_key_id_from_alias(key_id) # type: ignore[assignment]
return self.keys[self.get_key_id(key_id)] return self.keys[self.get_key_id(key_id)]
def list_keys(self): def list_keys(self) -> Iterable[Key]:
return self.keys.values() return self.keys.values()
@staticmethod @staticmethod
def get_key_id(key_id): def get_key_id(key_id: str) -> str:
# Allow use of ARN as well as pure KeyId # Allow use of ARN as well as pure KeyId
if key_id.startswith("arn:") and ":key/" in key_id: if key_id.startswith("arn:") and ":key/" in key_id:
return key_id.split(":key/")[1] return key_id.split(":key/")[1]
@ -344,14 +363,14 @@ class KmsBackend(BaseBackend):
return key_id return key_id
@staticmethod @staticmethod
def get_alias_name(alias_name): def get_alias_name(alias_name: str) -> str:
# Allow use of ARN as well as alias name # Allow use of ARN as well as alias name
if alias_name.startswith("arn:") and ":alias/" in alias_name: if alias_name.startswith("arn:") and ":alias/" in alias_name:
return "alias/" + alias_name.split(":alias/")[1] return "alias/" + alias_name.split(":alias/")[1]
return alias_name return alias_name
def any_id_to_key_id(self, key_id): def any_id_to_key_id(self, key_id: str) -> str:
"""Go from any valid key ID to the raw key ID. """Go from any valid key ID to the raw key ID.
Acceptable inputs: Acceptable inputs:
@ -363,66 +382,65 @@ class KmsBackend(BaseBackend):
key_id = self.get_alias_name(key_id) key_id = self.get_alias_name(key_id)
key_id = self.get_key_id(key_id) key_id = self.get_key_id(key_id)
if key_id.startswith("alias/"): if key_id.startswith("alias/"):
key_id = self.get_key_id(self.get_key_id_from_alias(key_id)) key_id = self.get_key_id(self.get_key_id_from_alias(key_id)) # type: ignore[arg-type]
return key_id return key_id
def alias_exists(self, alias_name): def alias_exists(self, alias_name: str) -> bool:
for aliases in self.key_to_aliases.values(): for aliases in self.key_to_aliases.values():
if alias_name in aliases: if alias_name in aliases:
return True return True
return False return False
def add_alias(self, target_key_id, alias_name): def add_alias(self, target_key_id: str, alias_name: str) -> None:
raw_key_id = self.get_key_id(target_key_id) raw_key_id = self.get_key_id(target_key_id)
self.key_to_aliases[raw_key_id].add(alias_name) self.key_to_aliases[raw_key_id].add(alias_name)
def delete_alias(self, alias_name): def delete_alias(self, alias_name: str) -> None:
"""Delete the alias.""" """Delete the alias."""
for aliases in self.key_to_aliases.values(): for aliases in self.key_to_aliases.values():
if alias_name in aliases: if alias_name in aliases:
aliases.remove(alias_name) aliases.remove(alias_name)
def get_all_aliases(self): def get_all_aliases(self) -> Dict[str, Set[str]]:
return self.key_to_aliases return self.key_to_aliases
def get_key_id_from_alias(self, alias_name): def get_key_id_from_alias(self, alias_name: str) -> Optional[str]:
for key_id, aliases in dict(self.key_to_aliases).items(): for key_id, aliases in dict(self.key_to_aliases).items():
if alias_name in ",".join(aliases): if alias_name in ",".join(aliases):
return key_id return key_id
if alias_name in RESERVED_ALIASES: if alias_name in RESERVED_ALIASES:
key_id = self._generate_default_keys(alias_name) return self._generate_default_keys(alias_name)
return key_id
return None return None
def enable_key_rotation(self, key_id): def enable_key_rotation(self, key_id: str) -> None:
self.keys[self.get_key_id(key_id)].key_rotation_status = True self.keys[self.get_key_id(key_id)].key_rotation_status = True
def disable_key_rotation(self, key_id): def disable_key_rotation(self, key_id: str) -> None:
self.keys[self.get_key_id(key_id)].key_rotation_status = False self.keys[self.get_key_id(key_id)].key_rotation_status = False
def get_key_rotation_status(self, key_id): def get_key_rotation_status(self, key_id: str) -> bool:
return self.keys[self.get_key_id(key_id)].key_rotation_status return self.keys[self.get_key_id(key_id)].key_rotation_status
def put_key_policy(self, key_id, policy): def put_key_policy(self, key_id: str, policy: str) -> None:
self.keys[self.get_key_id(key_id)].policy = policy self.keys[self.get_key_id(key_id)].policy = policy
def get_key_policy(self, key_id): def get_key_policy(self, key_id: str) -> str:
return self.keys[self.get_key_id(key_id)].policy return self.keys[self.get_key_id(key_id)].policy
def disable_key(self, key_id): def disable_key(self, key_id: str) -> None:
self.keys[key_id].enabled = False self.keys[key_id].enabled = False
self.keys[key_id].key_state = "Disabled" self.keys[key_id].key_state = "Disabled"
def enable_key(self, key_id): def enable_key(self, key_id: str) -> None:
self.keys[key_id].enabled = True self.keys[key_id].enabled = True
self.keys[key_id].key_state = "Enabled" self.keys[key_id].key_state = "Enabled"
def cancel_key_deletion(self, key_id): def cancel_key_deletion(self, key_id: str) -> None:
self.keys[key_id].key_state = "Disabled" self.keys[key_id].key_state = "Disabled"
self.keys[key_id].deletion_date = None self.keys[key_id].deletion_date = None
def schedule_key_deletion(self, key_id, pending_window_in_days): def schedule_key_deletion(self, key_id: str, pending_window_in_days: int) -> float: # type: ignore[return]
if 7 <= pending_window_in_days <= 30: if 7 <= pending_window_in_days <= 30:
self.keys[key_id].enabled = False self.keys[key_id].enabled = False
self.keys[key_id].key_state = "PendingDeletion" self.keys[key_id].key_state = "PendingDeletion"
@ -431,7 +449,9 @@ class KmsBackend(BaseBackend):
) )
return unix_time(self.keys[key_id].deletion_date) return unix_time(self.keys[key_id].deletion_date)
def encrypt(self, key_id, plaintext, encryption_context): def encrypt(
self, key_id: str, plaintext: bytes, encryption_context: Dict[str, str]
) -> Tuple[bytes, str]:
key_id = self.any_id_to_key_id(key_id) key_id = self.any_id_to_key_id(key_id)
ciphertext_blob = encrypt( ciphertext_blob = encrypt(
@ -443,7 +463,9 @@ class KmsBackend(BaseBackend):
arn = self.keys[key_id].arn arn = self.keys[key_id].arn
return ciphertext_blob, arn return ciphertext_blob, arn
def decrypt(self, ciphertext_blob, encryption_context): def decrypt(
self, ciphertext_blob: bytes, encryption_context: Dict[str, str]
) -> Tuple[bytes, str]:
plaintext, key_id = decrypt( plaintext, key_id = decrypt(
master_keys=self.keys, master_keys=self.keys,
ciphertext_blob=ciphertext_blob, ciphertext_blob=ciphertext_blob,
@ -454,11 +476,11 @@ class KmsBackend(BaseBackend):
def re_encrypt( def re_encrypt(
self, self,
ciphertext_blob, ciphertext_blob: bytes,
source_encryption_context, source_encryption_context: Dict[str, str],
destination_key_id, destination_key_id: str,
destination_encryption_context, destination_encryption_context: Dict[str, str],
): ) -> Tuple[bytes, str, str]:
destination_key_id = self.any_id_to_key_id(destination_key_id) destination_key_id = self.any_id_to_key_id(destination_key_id)
plaintext, decrypting_arn = self.decrypt( plaintext, decrypting_arn = self.decrypt(
@ -472,7 +494,13 @@ class KmsBackend(BaseBackend):
) )
return new_ciphertext_blob, decrypting_arn, encrypting_arn return new_ciphertext_blob, decrypting_arn, encrypting_arn
def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec): def generate_data_key(
self,
key_id: str,
encryption_context: Dict[str, str],
number_of_bytes: int,
key_spec: str,
) -> Tuple[bytes, bytes, str]:
key_id = self.any_id_to_key_id(key_id) key_id = self.any_id_to_key_id(key_id)
if key_spec: if key_spec:
@ -492,7 +520,7 @@ class KmsBackend(BaseBackend):
return plaintext, ciphertext_blob, arn return plaintext, ciphertext_blob, arn
def list_resource_tags(self, key_id_or_arn): def list_resource_tags(self, key_id_or_arn: str) -> Dict[str, List[Dict[str, str]]]:
key_id = self.get_key_id(key_id_or_arn) key_id = self.get_key_id(key_id_or_arn)
if key_id in self.keys: if key_id in self.keys:
return self.tagger.list_tags_for_resource(key_id) return self.tagger.list_tags_for_resource(key_id)
@ -501,21 +529,21 @@ class KmsBackend(BaseBackend):
"The request was rejected because the specified entity or resource could not be found.", "The request was rejected because the specified entity or resource could not be found.",
) )
def tag_resource(self, key_id_or_arn, tags): def tag_resource(self, key_id_or_arn: str, tags: List[Dict[str, str]]) -> None:
key_id = self.get_key_id(key_id_or_arn) key_id = self.get_key_id(key_id_or_arn)
if key_id in self.keys: if key_id in self.keys:
self.tagger.tag_resource(key_id, tags) self.tagger.tag_resource(key_id, tags)
return {} return
raise JsonRESTError( raise JsonRESTError(
"NotFoundException", "NotFoundException",
"The request was rejected because the specified entity or resource could not be found.", "The request was rejected because the specified entity or resource could not be found.",
) )
def untag_resource(self, key_id_or_arn, tag_names): def untag_resource(self, key_id_or_arn: str, tag_names: List[str]) -> None:
key_id = self.get_key_id(key_id_or_arn) key_id = self.get_key_id(key_id_or_arn)
if key_id in self.keys: if key_id in self.keys:
self.tagger.untag_resource_using_names(key_id, tag_names) self.tagger.untag_resource_using_names(key_id, tag_names)
return {} return
raise JsonRESTError( raise JsonRESTError(
"NotFoundException", "NotFoundException",
"The request was rejected because the specified entity or resource could not be found.", "The request was rejected because the specified entity or resource could not be found.",
@ -523,13 +551,13 @@ class KmsBackend(BaseBackend):
def create_grant( def create_grant(
self, self,
key_id, key_id: str,
grantee_principal, grantee_principal: str,
operations, operations: List[str],
name, name: str,
constraints, constraints: Dict[str, Any],
retiring_principal, retiring_principal: str,
): ) -> Tuple[str, str]:
key = self.describe_key(key_id) key = self.describe_key(key_id)
grant = key.add_grant( grant = key.add_grant(
name, name,
@ -540,21 +568,21 @@ class KmsBackend(BaseBackend):
) )
return grant.id, grant.token return grant.id, grant.token
def list_grants(self, key_id, grant_id) -> [Grant]: def list_grants(self, key_id: str, grant_id: str) -> List[Grant]:
key = self.describe_key(key_id) key = self.describe_key(key_id)
return key.list_grants(grant_id) return key.list_grants(grant_id)
def list_retirable_grants(self, retiring_principal): def list_retirable_grants(self, retiring_principal: str) -> List[Grant]:
grants = [] grants = []
for key in self.keys.values(): for key in self.keys.values():
grants.extend(key.list_retirable_grants(retiring_principal)) grants.extend(key.list_retirable_grants(retiring_principal))
return grants return grants
def revoke_grant(self, key_id, grant_id) -> None: def revoke_grant(self, key_id: str, grant_id: str) -> None:
key = self.describe_key(key_id) key = self.describe_key(key_id)
key.revoke_grant(grant_id) key.revoke_grant(grant_id)
def retire_grant(self, key_id, grant_id, grant_token) -> None: def retire_grant(self, key_id: str, grant_id: str, grant_token: str) -> None:
if grant_token: if grant_token:
for key in self.keys.values(): for key in self.keys.values():
key.retire_grant_by_token(grant_token) key.retire_grant_by_token(grant_token)
@ -562,7 +590,7 @@ class KmsBackend(BaseBackend):
key = self.describe_key(key_id) key = self.describe_key(key_id)
key.retire_grant(grant_id) key.retire_grant(grant_id)
def __ensure_valid_sign_and_verify_key(self, key: Key): def __ensure_valid_sign_and_verify_key(self, key: Key) -> None:
if key.key_usage != "SIGN_VERIFY": if key.key_usage != "SIGN_VERIFY":
raise ValidationException( raise ValidationException(
( (
@ -571,7 +599,9 @@ class KmsBackend(BaseBackend):
).format(key_id=key.id) ).format(key_id=key.id)
) )
def __ensure_valid_signing_augorithm(self, key: Key, signing_algorithm): def __ensure_valid_signing_augorithm(
self, key: Key, signing_algorithm: str
) -> None:
if signing_algorithm not in key.signing_algorithms: if signing_algorithm not in key.signing_algorithms:
raise ValidationException( raise ValidationException(
( (
@ -584,7 +614,9 @@ class KmsBackend(BaseBackend):
) )
) )
def sign(self, key_id, message, signing_algorithm): def sign(
self, key_id: str, message: bytes, signing_algorithm: str
) -> Tuple[str, bytes, str]:
"""Sign message using generated private key. """Sign message using generated private key.
- signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256 - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256
@ -607,7 +639,9 @@ class KmsBackend(BaseBackend):
return key.arn, signature, signing_algorithm return key.arn, signature, signing_algorithm
def verify(self, key_id, message, signature, signing_algorithm): def verify(
self, key_id: str, message: bytes, signature: bytes, signing_algorithm: str
) -> Tuple[str, bool, str]:
"""Verify message using public key from generated private key. """Verify message using public key from generated private key.
- signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256 - signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256

View File

@ -1,4 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List
import json import json
from .models import Key from .models import Key
from .exceptions import AccessDeniedException from .exceptions import AccessDeniedException
@ -8,7 +9,7 @@ ALTERNATIVE_ACTIONS = defaultdict(list)
ALTERNATIVE_ACTIONS["kms:DescribeKey"] = ["kms:*", "kms:Describe*", "kms:DescribeKey"] ALTERNATIVE_ACTIONS["kms:DescribeKey"] = ["kms:*", "kms:Describe*", "kms:DescribeKey"]
def validate_policy(key: Key, action: str): def validate_policy(key: Key, action: str) -> None:
""" """
Relevant docs: Relevant docs:
- https://docs.aws.amazon.com/kms/latest/developerguide/key-policy-default.html - https://docs.aws.amazon.com/kms/latest/developerguide/key-policy-default.html
@ -29,20 +30,22 @@ def validate_policy(key: Key, action: str):
) )
def check_statement(statement, resource, action): def check_statement(statement: Dict[str, Any], resource: str, action: str) -> bool:
return action_matches(statement.get("Action", []), action) and resource_matches( return action_matches(statement.get("Action", []), action) and resource_matches(
statement.get("Resource", ""), resource statement.get("Resource", ""), resource
) )
def action_matches(applicable_actions, action): def action_matches(applicable_actions: List[str], action: str) -> bool:
alternatives = ALTERNATIVE_ACTIONS[action] alternatives = ALTERNATIVE_ACTIONS[action]
if any(alt in applicable_actions for alt in alternatives): if any(alt in applicable_actions for alt in alternatives):
return True return True
return False return False
def resource_matches(applicable_resources, resource): # pylint: disable=unused-argument def resource_matches(
applicable_resources: str, resource: str # pylint: disable=unused-argument
) -> bool:
if applicable_resources == "*": if applicable_resources == "*":
return True return True
return False return False

View File

@ -3,6 +3,7 @@ import json
import os import os
import re import re
import warnings import warnings
from typing import Any, Dict
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.kms.utils import RESERVED_ALIASES, RESERVED_ALIASE_TARGET_KEY_IDS from moto.kms.utils import RESERVED_ALIASES, RESERVED_ALIASE_TARGET_KEY_IDS
@ -17,24 +18,23 @@ from .exceptions import (
class KmsResponse(BaseResponse): class KmsResponse(BaseResponse):
def __init__(self): def __init__(self) -> None:
super().__init__(service_name="kms") super().__init__(service_name="kms")
@property def _get_param(self, param_name: str, if_none: Any = None) -> Any: # type: ignore
def parameters(self):
params = json.loads(self.body) params = json.loads(self.body)
for key in ("Plaintext", "CiphertextBlob"): for key in ("Plaintext", "CiphertextBlob"):
if key in params: if key in params:
params[key] = base64.b64decode(params[key].encode("utf-8")) params[key] = base64.b64decode(params[key].encode("utf-8"))
return params return params.get(param_name, if_none)
@property @property
def kms_backend(self) -> KmsBackend: def kms_backend(self) -> KmsBackend:
return kms_backends[self.current_account][self.region] return kms_backends[self.current_account][self.region]
def _display_arn(self, key_id): def _display_arn(self, key_id: str) -> str:
if key_id.startswith("arn:"): if key_id.startswith("arn:"):
return key_id return key_id
@ -45,7 +45,7 @@ class KmsResponse(BaseResponse):
return f"arn:aws:kms:{self.region}:{self.current_account}:{id_type}{key_id}" return f"arn:aws:kms:{self.region}:{self.current_account}:{id_type}{key_id}"
def _validate_cmk_id(self, key_id): def _validate_cmk_id(self, key_id: str) -> None:
"""Determine whether a CMK ID exists. """Determine whether a CMK ID exists.
- raw key ID - raw key ID
@ -69,7 +69,7 @@ class KmsResponse(BaseResponse):
if cmk_id not in self.kms_backend.keys: if cmk_id not in self.kms_backend.keys:
raise NotFoundException(f"Key '{self._display_arn(key_id)}' does not exist") raise NotFoundException(f"Key '{self._display_arn(key_id)}' does not exist")
def _validate_alias(self, key_id): def _validate_alias(self, key_id: str) -> None:
"""Determine whether an alias exists. """Determine whether an alias exists.
- alias name - alias name
@ -88,8 +88,8 @@ class KmsResponse(BaseResponse):
if cmk_id is None: if cmk_id is None:
raise error raise error
def _validate_key_id(self, key_id): def _validate_key_id(self, key_id: str) -> None:
"""Determine whether or not a key ID exists. """Determine whether a key ID exists.
- raw key ID - raw key ID
- key ARN - key ARN
@ -105,77 +105,77 @@ class KmsResponse(BaseResponse):
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
def _validate_key_policy(self, key_id, action): def _validate_key_policy(self, key_id: str, action: str) -> None:
""" """
Validate whether the specified action is allowed, given the key policy Validate whether the specified action is allowed, given the key policy
""" """
key = self.kms_backend.describe_key(self.kms_backend.get_key_id(key_id)) key = self.kms_backend.describe_key(self.kms_backend.get_key_id(key_id))
validate_policy(key, action) validate_policy(key, action)
def create_key(self): def create_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html"""
policy = self.parameters.get("Policy") policy = self._get_param("Policy")
key_usage = self.parameters.get("KeyUsage") key_usage = self._get_param("KeyUsage")
key_spec = self.parameters.get("KeySpec") or self.parameters.get( key_spec = self._get_param("KeySpec") or self._get_param(
"CustomerMasterKeySpec" "CustomerMasterKeySpec"
) )
description = self.parameters.get("Description") description = self._get_param("Description")
tags = self.parameters.get("Tags") tags = self._get_param("Tags")
multi_region = self.parameters.get("MultiRegion") multi_region = self._get_param("MultiRegion")
key = self.kms_backend.create_key( key = self.kms_backend.create_key(
policy, key_usage, key_spec, description, tags, multi_region policy, key_usage, key_spec, description, tags, multi_region
) )
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def replicate_key(self): def replicate_key(self) -> None:
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_key_id(key_id) self._validate_key_id(key_id)
replica_region = self.parameters.get("ReplicaRegion") replica_region = self._get_param("ReplicaRegion")
self.kms_backend.replicate_key(key_id, replica_region) self.kms_backend.replicate_key(key_id, replica_region)
def update_key_description(self): def update_key_description(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
description = self.parameters.get("Description") description = self._get_param("Description")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
self.kms_backend.update_key_description(key_id, description) self.kms_backend.update_key_description(key_id, description)
return json.dumps(None) return json.dumps(None)
def tag_resource(self): def tag_resource(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
tags = self.parameters.get("Tags") tags = self._get_param("Tags")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
result = self.kms_backend.tag_resource(key_id, tags) self.kms_backend.tag_resource(key_id, tags)
return json.dumps(result) return "{}"
def untag_resource(self): def untag_resource(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UntagResource.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_UntagResource.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
tag_names = self.parameters.get("TagKeys") tag_names = self._get_param("TagKeys")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
result = self.kms_backend.untag_resource(key_id, tag_names) self.kms_backend.untag_resource(key_id, tag_names)
return json.dumps(result) return "{}"
def list_resource_tags(self): def list_resource_tags(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
tags = self.kms_backend.list_resource_tags(key_id) tags: Dict[str, Any] = self.kms_backend.list_resource_tags(key_id)
tags.update({"NextMarker": None, "Truncated": False}) tags.update({"NextMarker": None, "Truncated": False})
return json.dumps(tags) return json.dumps(tags)
def describe_key(self): def describe_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_key_id(key_id) self._validate_key_id(key_id)
self._validate_key_policy(key_id, "kms:DescribeKey") self._validate_key_policy(key_id, "kms:DescribeKey")
@ -184,7 +184,7 @@ class KmsResponse(BaseResponse):
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def list_keys(self): def list_keys(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html"""
keys = self.kms_backend.list_keys() keys = self.kms_backend.list_keys()
@ -196,17 +196,17 @@ class KmsResponse(BaseResponse):
} }
) )
def create_alias(self): def create_alias(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html"""
return self._set_alias() return self._set_alias()
def update_alias(self): def update_alias(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateAlias.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateAlias.html"""
return self._set_alias(update=True) return self._set_alias(update=True)
def _set_alias(self, update=False): def _set_alias(self, update: bool = False) -> str:
alias_name = self.parameters["AliasName"] alias_name = self._get_param("AliasName")
target_key_id = self.parameters["TargetKeyId"] target_key_id = self._get_param("TargetKeyId")
if not alias_name.startswith("alias/"): if not alias_name.startswith("alias/"):
raise ValidationException("Invalid identifier") raise ValidationException("Invalid identifier")
@ -243,9 +243,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None) return json.dumps(None)
def delete_alias(self): def delete_alias(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html"""
alias_name = self.parameters["AliasName"] alias_name = self._get_param("AliasName")
if not alias_name.startswith("alias/"): if not alias_name.startswith("alias/"):
raise ValidationException("Invalid identifier") raise ValidationException("Invalid identifier")
@ -256,10 +256,10 @@ class KmsResponse(BaseResponse):
return json.dumps(None) return json.dumps(None)
def list_aliases(self): def list_aliases(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html"""
region = self.region region = self.region
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
if key_id is not None: if key_id is not None:
self._validate_key_id(key_id) self._validate_key_id(key_id)
key_id = self.kms_backend.get_key_id(key_id) key_id = self.kms_backend.get_key_id(key_id)
@ -298,13 +298,13 @@ class KmsResponse(BaseResponse):
return json.dumps({"Truncated": False, "Aliases": response_aliases}) return json.dumps({"Truncated": False, "Aliases": response_aliases})
def create_grant(self): def create_grant(self) -> str:
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
grantee_principal = self.parameters.get("GranteePrincipal") grantee_principal = self._get_param("GranteePrincipal")
retiring_principal = self.parameters.get("RetiringPrincipal") retiring_principal = self._get_param("RetiringPrincipal")
operations = self.parameters.get("Operations") operations = self._get_param("Operations")
name = self.parameters.get("Name") name = self._get_param("Name")
constraints = self.parameters.get("Constraints") constraints = self._get_param("Constraints")
grant_id, grant_token = self.kms_backend.create_grant( grant_id, grant_token = self.kms_backend.create_grant(
key_id, key_id,
@ -316,9 +316,9 @@ class KmsResponse(BaseResponse):
) )
return json.dumps({"GrantId": grant_id, "GrantToken": grant_token}) return json.dumps({"GrantId": grant_id, "GrantToken": grant_token})
def list_grants(self): def list_grants(self) -> str:
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
grant_id = self.parameters.get("GrantId") grant_id = self._get_param("GrantId")
grants = self.kms_backend.list_grants(key_id=key_id, grant_id=grant_id) grants = self.kms_backend.list_grants(key_id=key_id, grant_id=grant_id)
return json.dumps( return json.dumps(
@ -329,8 +329,8 @@ class KmsResponse(BaseResponse):
} }
) )
def list_retirable_grants(self): def list_retirable_grants(self) -> str:
retiring_principal = self.parameters.get("RetiringPrincipal") retiring_principal = self._get_param("RetiringPrincipal")
grants = self.kms_backend.list_retirable_grants(retiring_principal) grants = self.kms_backend.list_retirable_grants(retiring_principal)
return json.dumps( return json.dumps(
@ -341,24 +341,24 @@ class KmsResponse(BaseResponse):
} }
) )
def revoke_grant(self): def revoke_grant(self) -> str:
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
grant_id = self.parameters.get("GrantId") grant_id = self._get_param("GrantId")
self.kms_backend.revoke_grant(key_id, grant_id) self.kms_backend.revoke_grant(key_id, grant_id)
return "{}" return "{}"
def retire_grant(self): def retire_grant(self) -> str:
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
grant_id = self.parameters.get("GrantId") grant_id = self._get_param("GrantId")
grant_token = self.parameters.get("GrantToken") grant_token = self._get_param("GrantToken")
self.kms_backend.retire_grant(key_id, grant_id, grant_token) self.kms_backend.retire_grant(key_id, grant_id, grant_token)
return "{}" return "{}"
def enable_key_rotation(self): def enable_key_rotation(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -366,9 +366,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None) return json.dumps(None)
def disable_key_rotation(self): def disable_key_rotation(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -376,9 +376,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None) return json.dumps(None)
def get_key_rotation_status(self): def get_key_rotation_status(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -386,11 +386,11 @@ class KmsResponse(BaseResponse):
return json.dumps({"KeyRotationEnabled": rotation_enabled}) return json.dumps({"KeyRotationEnabled": rotation_enabled})
def put_key_policy(self): def put_key_policy(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
policy_name = self.parameters.get("PolicyName") policy_name = self._get_param("PolicyName")
policy = self.parameters.get("Policy") policy = self._get_param("Policy")
_assert_default_policy(policy_name) _assert_default_policy(policy_name)
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -399,10 +399,10 @@ class KmsResponse(BaseResponse):
return json.dumps(None) return json.dumps(None)
def get_key_policy(self): def get_key_policy(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
policy_name = self.parameters.get("PolicyName") policy_name = self._get_param("PolicyName")
_assert_default_policy(policy_name) _assert_default_policy(policy_name)
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -410,9 +410,9 @@ class KmsResponse(BaseResponse):
policy = self.kms_backend.get_key_policy(key_id) or "{}" policy = self.kms_backend.get_key_policy(key_id) or "{}"
return json.dumps({"Policy": policy}) return json.dumps({"Policy": policy})
def list_key_policies(self): def list_key_policies(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -420,11 +420,11 @@ class KmsResponse(BaseResponse):
return json.dumps({"Truncated": False, "PolicyNames": ["default"]}) return json.dumps({"Truncated": False, "PolicyNames": ["default"]})
def encrypt(self): def encrypt(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
encryption_context = self.parameters.get("EncryptionContext", {}) encryption_context = self._get_param("EncryptionContext", {})
plaintext = self.parameters.get("Plaintext") plaintext = self._get_param("Plaintext")
self._validate_key_id(key_id) self._validate_key_id(key_id)
@ -438,10 +438,10 @@ class KmsResponse(BaseResponse):
return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn}) return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn})
def decrypt(self): def decrypt(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob") ciphertext_blob = self._get_param("CiphertextBlob")
encryption_context = self.parameters.get("EncryptionContext", {}) encryption_context = self._get_param("EncryptionContext", {})
plaintext, arn = self.kms_backend.decrypt( plaintext, arn = self.kms_backend.decrypt(
ciphertext_blob=ciphertext_blob, encryption_context=encryption_context ciphertext_blob=ciphertext_blob, encryption_context=encryption_context
@ -451,12 +451,12 @@ class KmsResponse(BaseResponse):
return json.dumps({"Plaintext": plaintext_response, "KeyId": arn}) return json.dumps({"Plaintext": plaintext_response, "KeyId": arn})
def re_encrypt(self): def re_encrypt(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob") ciphertext_blob = self._get_param("CiphertextBlob")
source_encryption_context = self.parameters.get("SourceEncryptionContext", {}) source_encryption_context = self._get_param("SourceEncryptionContext", {})
destination_key_id = self.parameters.get("DestinationKeyId") destination_key_id = self._get_param("DestinationKeyId")
destination_encryption_context = self.parameters.get( destination_encryption_context = self._get_param(
"DestinationEncryptionContext", {} "DestinationEncryptionContext", {}
) )
@ -483,9 +483,9 @@ class KmsResponse(BaseResponse):
} }
) )
def disable_key(self): def disable_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -493,9 +493,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None) return json.dumps(None)
def enable_key(self): def enable_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -503,9 +503,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None) return json.dumps(None)
def cancel_key_deletion(self): def cancel_key_deletion(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -513,13 +513,13 @@ class KmsResponse(BaseResponse):
return json.dumps({"KeyId": key_id}) return json.dumps({"KeyId": key_id})
def schedule_key_deletion(self): def schedule_key_deletion(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
if self.parameters.get("PendingWindowInDays") is None: if self._get_param("PendingWindowInDays") is None:
pending_window_in_days = 30 pending_window_in_days = 30
else: else:
pending_window_in_days = self.parameters.get("PendingWindowInDays") pending_window_in_days = self._get_param("PendingWindowInDays")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
@ -532,12 +532,12 @@ class KmsResponse(BaseResponse):
} }
) )
def generate_data_key(self): def generate_data_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
encryption_context = self.parameters.get("EncryptionContext", {}) encryption_context = self._get_param("EncryptionContext", {})
number_of_bytes = self.parameters.get("NumberOfBytes") number_of_bytes = self._get_param("NumberOfBytes")
key_spec = self.parameters.get("KeySpec") key_spec = self._get_param("KeySpec")
# Param validation # Param validation
self._validate_key_id(key_id) self._validate_key_id(key_id)
@ -587,16 +587,16 @@ class KmsResponse(BaseResponse):
} }
) )
def generate_data_key_without_plaintext(self): def generate_data_key_without_plaintext(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html"""
result = json.loads(self.generate_data_key()) result = json.loads(self.generate_data_key())
del result["Plaintext"] del result["Plaintext"]
return json.dumps(result) return json.dumps(result)
def generate_random(self): def generate_random(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html"""
number_of_bytes = self.parameters.get("NumberOfBytes") number_of_bytes = self._get_param("NumberOfBytes")
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1):
raise ValidationException( raise ValidationException(
@ -613,13 +613,13 @@ class KmsResponse(BaseResponse):
return json.dumps({"Plaintext": response_entropy}) return json.dumps({"Plaintext": response_entropy})
def sign(self): def sign(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Sign.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_Sign.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
message = self.parameters.get("Message") message = self._get_param("Message")
message_type = self.parameters.get("MessageType") message_type = self._get_param("MessageType")
grant_tokens = self.parameters.get("GrantTokens") grant_tokens = self._get_param("GrantTokens")
signing_algorithm = self.parameters.get("SigningAlgorithm") signing_algorithm = self._get_param("SigningAlgorithm")
self._validate_key_id(key_id) self._validate_key_id(key_id)
@ -660,14 +660,14 @@ class KmsResponse(BaseResponse):
} }
) )
def verify(self): def verify(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Verify.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_Verify.html"""
key_id = self.parameters.get("KeyId") key_id = self._get_param("KeyId")
message = self.parameters.get("Message") message = self._get_param("Message")
message_type = self.parameters.get("MessageType") message_type = self._get_param("MessageType")
signature = self.parameters.get("Signature") signature = self._get_param("Signature")
signing_algorithm = self.parameters.get("SigningAlgorithm") signing_algorithm = self._get_param("SigningAlgorithm")
grant_tokens = self.parameters.get("GrantTokens") grant_tokens = self._get_param("GrantTokens")
self._validate_key_id(key_id) self._validate_key_id(key_id)
@ -722,6 +722,6 @@ class KmsResponse(BaseResponse):
) )
def _assert_default_policy(policy_name): def _assert_default_policy(policy_name: str) -> None:
if policy_name != "default": if policy_name != "default":
raise NotFoundException("No such policy exists") raise NotFoundException("No such policy exists")

View File

@ -1,4 +1,5 @@
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, Tuple
import io import io
import os import os
import struct import struct
@ -47,7 +48,7 @@ RESERVED_ALIASE_TARGET_KEY_IDS = {
RESERVED_ALIASES = list(RESERVED_ALIASE_TARGET_KEY_IDS.keys()) RESERVED_ALIASES = list(RESERVED_ALIASE_TARGET_KEY_IDS.keys())
def generate_key_id(multi_region=False): def generate_key_id(multi_region: bool = False) -> str:
key = str(mock_random.uuid4()) key = str(mock_random.uuid4())
# https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html # https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html
# "Notice that multi-Region keys have a distinctive key ID that begins with mrk-. You can use the mrk- prefix to # "Notice that multi-Region keys have a distinctive key ID that begins with mrk-. You can use the mrk- prefix to
@ -58,17 +59,17 @@ def generate_key_id(multi_region=False):
return key return key
def generate_data_key(number_of_bytes): def generate_data_key(number_of_bytes: int) -> bytes:
"""Generate a data key.""" """Generate a data key."""
return os.urandom(number_of_bytes) return os.urandom(number_of_bytes)
def generate_master_key(): def generate_master_key() -> bytes:
"""Generate a master key.""" """Generate a master key."""
return generate_data_key(MASTER_KEY_LEN) return generate_data_key(MASTER_KEY_LEN)
def generate_private_key(): def generate_private_key() -> rsa.RSAPrivateKey:
"""Generate a private key to be used on asymmetric sign/verify. """Generate a private key to be used on asymmetric sign/verify.
NOTE: KeySpec is not taken into consideration and the key is always RSA_2048 NOTE: KeySpec is not taken into consideration and the key is always RSA_2048
@ -80,7 +81,7 @@ def generate_private_key():
) )
def _serialize_ciphertext_blob(ciphertext): def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes:
"""Serialize Ciphertext object into a ciphertext blob. """Serialize Ciphertext object into a ciphertext blob.
NOTE: This is just a simple binary format. It is not what KMS actually does. NOTE: This is just a simple binary format. It is not what KMS actually does.
@ -94,7 +95,7 @@ def _serialize_ciphertext_blob(ciphertext):
return header + ciphertext.ciphertext return header + ciphertext.ciphertext
def _deserialize_ciphertext_blob(ciphertext_blob): def _deserialize_ciphertext_blob(ciphertext_blob: bytes) -> Ciphertext:
"""Deserialize ciphertext blob into a Ciphertext object. """Deserialize ciphertext blob into a Ciphertext object.
NOTE: This is just a simple binary format. It is not what KMS actually does. NOTE: This is just a simple binary format. It is not what KMS actually does.
@ -107,7 +108,7 @@ def _deserialize_ciphertext_blob(ciphertext_blob):
) )
def _serialize_encryption_context(encryption_context): def _serialize_encryption_context(encryption_context: Dict[str, str]) -> bytes:
"""Serialize encryption context for use a AAD. """Serialize encryption context for use a AAD.
NOTE: This is not necessarily what KMS does, but it retains the same properties. NOTE: This is not necessarily what KMS does, but it retains the same properties.
@ -119,7 +120,12 @@ def _serialize_encryption_context(encryption_context):
return aad.getvalue() return aad.getvalue()
def encrypt(master_keys, key_id, plaintext, encryption_context): def encrypt(
master_keys: Dict[str, Any],
key_id: str,
plaintext: bytes,
encryption_context: Dict[str, str],
) -> bytes:
"""Encrypt data using a master key material. """Encrypt data using a master key material.
NOTE: This is not necessarily what KMS does, but it retains the same properties. NOTE: This is not necessarily what KMS does, but it retains the same properties.
@ -159,7 +165,11 @@ def encrypt(master_keys, key_id, plaintext, encryption_context):
) )
def decrypt(master_keys, ciphertext_blob, encryption_context): def decrypt(
master_keys: Dict[str, Any],
ciphertext_blob: bytes,
encryption_context: Dict[str, str],
) -> Tuple[bytes, str]:
"""Decrypt a ciphertext blob using a master key material. """Decrypt a ciphertext blob using a master key material.
NOTE: This is not necessarily what KMS does, but it retains the same properties. NOTE: This is not necessarily what KMS does, but it retains the same properties.

View File

@ -9,7 +9,7 @@ SESSION_TOKEN_PREFIX = "FQoGZXIvYXdzEBYaD"
DEFAULT_STS_SESSION_DURATION = 3600 DEFAULT_STS_SESSION_DURATION = 3600
def random_session_token(): def random_session_token() -> str:
return ( return (
SESSION_TOKEN_PREFIX SESSION_TOKEN_PREFIX
+ base64.b64encode(os.urandom(266))[len(SESSION_TOKEN_PREFIX) :].decode() + base64.b64encode(os.urandom(266))[len(SESSION_TOKEN_PREFIX) :].decode()

View File

@ -235,7 +235,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy] [mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/moto_api,moto/neptune files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/moto_api,moto/neptune
show_column_numbers=True show_column_numbers=True
show_error_codes = True show_error_codes = True
disable_error_code=abstract disable_error_code=abstract