Techdebt: MyPy K (#6111)

Bert Blommers 2023-03-23 09:17:40 -01:00 committed by GitHub
parent 64abff588f
commit c3460b8a1a
16 changed files with 680 additions and 565 deletions

@ -1,46 +1,47 @@
from moto.core.exceptions import JsonRESTError
from typing import Optional
class ResourceNotFoundError(JsonRESTError):
def __init__(self, message):
def __init__(self, message: str):
super().__init__(error_type="ResourceNotFoundException", message=message)
class ResourceInUseError(JsonRESTError):
def __init__(self, message):
def __init__(self, message: str):
super().__init__(error_type="ResourceInUseException", message=message)
class StreamNotFoundError(ResourceNotFoundError):
def __init__(self, stream_name, account_id):
def __init__(self, stream_name: str, account_id: str):
super().__init__(f"Stream {stream_name} under account {account_id} not found.")
class StreamCannotBeUpdatedError(JsonRESTError):
def __init__(self, stream_name, account_id):
def __init__(self, stream_name: str, account_id: str):
message = f"Request is invalid. Stream {stream_name} under account {account_id} is in On-Demand mode."
super().__init__(error_type="ValidationException", message=message)
class ShardNotFoundError(ResourceNotFoundError):
def __init__(self, shard_id, stream, account_id):
def __init__(self, shard_id: str, stream: str, account_id: str):
super().__init__(
f"Could not find shard {shard_id} in stream {stream} under account {account_id}."
)
class ConsumerNotFound(ResourceNotFoundError):
def __init__(self, consumer, account_id):
def __init__(self, consumer: str, account_id: str):
super().__init__(f"Consumer {consumer}, account {account_id} not found.")
class InvalidArgumentError(JsonRESTError):
def __init__(self, message):
def __init__(self, message: str):
super().__init__(error_type="InvalidArgumentException", message=message)
class InvalidRetentionPeriod(InvalidArgumentError):
def __init__(self, hours, too_short):
def __init__(self, hours: int, too_short: bool):
if too_short:
msg = f"Minimum allowed retention period is 24 hours. Requested retention period ({hours} hours) is too short."
else:
@ -49,31 +50,31 @@ class InvalidRetentionPeriod(InvalidArgumentError):
class InvalidDecreaseRetention(InvalidArgumentError):
def __init__(self, name, requested, existing):
def __init__(self, name: Optional[str], requested: int, existing: int):
msg = f"Requested retention period ({requested} hours) for stream {name} can not be longer than existing retention period ({existing} hours). Use IncreaseRetentionPeriod API."
super().__init__(msg)
class InvalidIncreaseRetention(InvalidArgumentError):
def __init__(self, name, requested, existing):
def __init__(self, name: Optional[str], requested: int, existing: int):
msg = f"Requested retention period ({requested} hours) for stream {name} can not be shorter than existing retention period ({existing} hours). Use DecreaseRetentionPeriod API."
super().__init__(msg)
class ValidationException(JsonRESTError):
def __init__(self, value, position, regex_to_match):
def __init__(self, value: str, position: str, regex_to_match: str):
msg = f"1 validation error detected: Value '{value}' at '{position}' failed to satisfy constraint: Member must satisfy regular expression pattern: {regex_to_match}"
super().__init__(error_type="ValidationException", message=msg)
class RecordSizeExceedsLimit(JsonRESTError):
def __init__(self, position):
def __init__(self, position: int):
msg = f"1 validation error detected: Value at 'records.{position}.member.data' failed to satisfy constraint: Member must have length less than or equal to 1048576"
super().__init__(error_type="ValidationException", message=msg)
class TotalRecordsSizeExceedsLimit(JsonRESTError):
def __init__(self):
def __init__(self) -> None:
super().__init__(
error_type="InvalidArgumentException",
message="Records size exceeds 5 MB limit",
@ -81,6 +82,6 @@ class TotalRecordsSizeExceedsLimit(JsonRESTError):
class TooManyRecords(JsonRESTError):
def __init__(self):
def __init__(self) -> None:
msg = "1 validation error detected: Value at 'records' failed to satisfy constraint: Member must have length less than or equal to 500"
super().__init__(error_type="ValidationException", message=msg)

@ -4,7 +4,7 @@ import re
import itertools
from operator import attrgetter
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Iterable
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import unix_time
@ -35,14 +35,16 @@ from .utils import (
class Consumer(BaseModel):
def __init__(self, consumer_name, account_id, region_name, stream_arn):
def __init__(
self, consumer_name: str, account_id: str, region_name: str, stream_arn: str
):
self.consumer_name = consumer_name
self.created = unix_time()
self.stream_arn = stream_arn
stream_name = stream_arn.split("/")[-1]
self.consumer_arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}/consumer/{consumer_name}"
def to_json(self, include_stream_arn=False):
def to_json(self, include_stream_arn: bool = False) -> Dict[str, Any]:
resp = {
"ConsumerName": self.consumer_name,
"ConsumerARN": self.consumer_arn,
@ -55,7 +57,13 @@ class Consumer(BaseModel):
class Record(BaseModel):
def __init__(self, partition_key, data, sequence_number, explicit_hash_key):
def __init__(
self,
partition_key: str,
data: str,
sequence_number: int,
explicit_hash_key: str,
):
self.partition_key = partition_key
self.data = data
self.sequence_number = sequence_number
@ -63,7 +71,7 @@ class Record(BaseModel):
self.created_at_datetime = datetime.datetime.utcnow()
self.created_at = unix_time(self.created_at_datetime)
def to_json(self):
def to_json(self) -> Dict[str, Any]:
return {
"Data": self.data,
"PartitionKey": self.partition_key,
@ -74,29 +82,36 @@ class Record(BaseModel):
class Shard(BaseModel):
def __init__(
self, shard_id, starting_hash, ending_hash, parent=None, adjacent_parent=None
self,
shard_id: int,
starting_hash: int,
ending_hash: int,
parent: Optional[str] = None,
adjacent_parent: Optional[str] = None,
):
self._shard_id = shard_id
self.starting_hash = starting_hash
self.ending_hash = ending_hash
self.records = OrderedDict()
self.records: Dict[int, Record] = OrderedDict()
self.is_open = True
self.parent = parent
self.adjacent_parent = adjacent_parent
@property
def shard_id(self):
def shard_id(self) -> str:
return f"shardId-{str(self._shard_id).zfill(12)}"
def get_records(self, last_sequence_id, limit):
last_sequence_id = int(last_sequence_id)
def get_records(
self, last_sequence_id: str, limit: Optional[int]
) -> Tuple[List[Record], int, int]:
last_sequence_int = int(last_sequence_id)
results = []
secs_behind_latest = 0
secs_behind_latest = 0.0
for sequence_number, record in self.records.items():
if sequence_number > last_sequence_id:
if sequence_number > last_sequence_int:
results.append(record)
last_sequence_id = sequence_number
last_sequence_int = sequence_number
very_last_record = self.records[next(reversed(self.records))]
secs_behind_latest = very_last_record.created_at - record.created_at
@ -105,9 +120,9 @@ class Shard(BaseModel):
break
millis_behind_latest = int(secs_behind_latest * 1000)
return results, last_sequence_id, millis_behind_latest
return results, last_sequence_int, millis_behind_latest
def put_record(self, partition_key, data, explicit_hash_key):
def put_record(self, partition_key: str, data: str, explicit_hash_key: str) -> str:
# Note: this function is not safe for concurrency
if self.records:
last_sequence_number = self.get_max_sequence_number()
@ -119,17 +134,17 @@ class Shard(BaseModel):
)
return str(sequence_number)
def get_min_sequence_number(self):
def get_min_sequence_number(self) -> int:
if self.records:
return list(self.records.keys())[0]
return 0
def get_max_sequence_number(self):
def get_max_sequence_number(self) -> int:
if self.records:
return list(self.records.keys())[-1]
return 0
def get_sequence_number_at(self, at_timestamp):
def get_sequence_number_at(self, at_timestamp: float) -> int:
if not self.records or at_timestamp < list(self.records.values())[0].created_at:
return 0
else:
@ -143,10 +158,10 @@ class Shard(BaseModel):
),
None,
)
return r.sequence_number
return r.sequence_number # type: ignore
def to_json(self):
response = {
def to_json(self) -> Dict[str, Any]:
response: Dict[str, Any] = {
"HashKeyRange": {
"EndingHashKey": str(self.ending_hash),
"StartingHashKey": str(self.starting_hash),
@ -170,12 +185,12 @@ class Shard(BaseModel):
class Stream(CloudFormationModel):
def __init__(
self,
stream_name,
shard_count,
stream_mode,
retention_period_hours,
account_id,
region_name,
stream_name: str,
shard_count: int,
stream_mode: Optional[Dict[str, str]],
retention_period_hours: Optional[int],
account_id: str,
region_name: str,
):
self.stream_name = stream_name
self.creation_datetime = datetime.datetime.now().strftime(
@ -184,27 +199,27 @@ class Stream(CloudFormationModel):
self.region = region_name
self.account_id = account_id
self.arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}"
self.shards = {}
self.tags = {}
self.shards: Dict[str, Shard] = {}
self.tags: Dict[str, str] = {}
self.status = "ACTIVE"
self.shard_count = None
self.shard_count: Optional[int] = None
self.stream_mode = stream_mode or {"StreamMode": "PROVISIONED"}
if self.stream_mode.get("StreamMode", "") == "ON_DEMAND":
shard_count = 4
self.init_shards(shard_count)
self.retention_period_hours = retention_period_hours or 24
self.shard_level_metrics = []
self.shard_level_metrics: List[str] = []
self.encryption_type = "NONE"
self.key_id = None
self.consumers = []
self.key_id: Optional[str] = None
self.consumers: List[Consumer] = []
def delete_consumer(self, consumer_arn):
def delete_consumer(self, consumer_arn: str) -> None:
self.consumers = [c for c in self.consumers if c.consumer_arn != consumer_arn]
def get_consumer_by_arn(self, consumer_arn):
def get_consumer_by_arn(self, consumer_arn: str) -> Optional[Consumer]:
return next((c for c in self.consumers if c.consumer_arn == consumer_arn), None)
def init_shards(self, shard_count):
def init_shards(self, shard_count: int) -> None:
self.shard_count = shard_count
step = 2**128 // shard_count
@ -216,16 +231,16 @@ class Stream(CloudFormationModel):
shard = Shard(index, start, end)
self.shards[shard.shard_id] = shard
def split_shard(self, shard_to_split, new_starting_hash_key):
new_starting_hash_key = int(new_starting_hash_key)
def split_shard(self, shard_to_split: str, new_starting_hash_key: str) -> None:
new_starting_hash_int = int(new_starting_hash_key)
shard = self.shards[shard_to_split]
if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
if shard.starting_hash < new_starting_hash_int < shard.ending_hash:
pass
else:
raise InvalidArgumentError(
message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}."
message=f"NewStartingHashKey {new_starting_hash_int} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}."
)
if not shard.is_open:
@ -241,12 +256,12 @@ class Stream(CloudFormationModel):
new_shard_1 = Shard(
last_id + 1,
starting_hash=shard.starting_hash,
ending_hash=new_starting_hash_key - 1,
ending_hash=new_starting_hash_int - 1,
parent=shard.shard_id,
)
new_shard_2 = Shard(
last_id + 2,
starting_hash=new_starting_hash_key,
starting_hash=new_starting_hash_int,
ending_hash=shard.ending_hash,
parent=shard.shard_id,
)
@ -261,7 +276,7 @@ class Stream(CloudFormationModel):
record = records[index]
self.put_record(record.partition_key, record.explicit_hash_key, record.data)
def merge_shards(self, shard_to_merge, adjacent_shard_to_merge):
def merge_shards(self, shard_to_merge: str, adjacent_shard_to_merge: str) -> None:
shard1 = self.shards[shard_to_merge]
shard2 = self.shards[adjacent_shard_to_merge]
@ -300,8 +315,8 @@ class Stream(CloudFormationModel):
record.partition_key, record.data, record.explicit_hash_key
)
def update_shard_count(self, target_shard_count):
if self.stream_mode.get("StreamMode", "") == "ON_DEMAND":
def update_shard_count(self, target_shard_count: int) -> None:
if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": # type: ignore
raise StreamCannotBeUpdatedError(
stream_name=self.stream_name, account_id=self.account_id
)
@ -351,13 +366,15 @@ class Stream(CloudFormationModel):
self.shard_count = target_shard_count
def get_shard(self, shard_id):
def get_shard(self, shard_id: str) -> Shard:
if shard_id in self.shards:
return self.shards[shard_id]
else:
raise ShardNotFoundError(shard_id, stream="", account_id=self.account_id)
def get_shard_for_key(self, partition_key, explicit_hash_key):
def get_shard_for_key(
self, partition_key: str, explicit_hash_key: str
) -> Optional[Shard]:
if not isinstance(partition_key, str):
raise InvalidArgumentError("partition_key")
if len(partition_key) > 256:
@ -367,25 +384,28 @@ class Stream(CloudFormationModel):
if not isinstance(explicit_hash_key, str):
raise InvalidArgumentError("explicit_hash_key")
key = int(explicit_hash_key)
int_key = int(explicit_hash_key)
if key >= 2**128:
if int_key >= 2**128:
raise InvalidArgumentError("explicit_hash_key")
else:
key = int(md5_hash(partition_key.encode("utf-8")).hexdigest(), 16)
int_key = int(md5_hash(partition_key.encode("utf-8")).hexdigest(), 16)
for shard in self.shards.values():
if shard.starting_hash <= key < shard.ending_hash:
if shard.starting_hash <= int_key < shard.ending_hash:
return shard
return None
def put_record(self, partition_key, explicit_hash_key, data):
def put_record(
self, partition_key: str, explicit_hash_key: str, data: str
) -> Tuple[str, str]:
shard = self.get_shard_for_key(partition_key, explicit_hash_key)
sequence_number = shard.put_record(partition_key, data, explicit_hash_key)
return sequence_number, shard.shard_id
sequence_number = shard.put_record(partition_key, data, explicit_hash_key) # type: ignore
return sequence_number, shard.shard_id # type: ignore
def to_json(self, shard_limit=None):
def to_json(self, shard_limit: Optional[int] = None) -> Dict[str, Any]:
all_shards = list(self.shards.values())
requested_shards = all_shards[0 : shard_limit or len(all_shards)]
return {
@ -403,7 +423,7 @@ class Stream(CloudFormationModel):
}
}
def to_json_summary(self):
def to_json_summary(self) -> Dict[str, Any]:
return {
"StreamDescriptionSummary": {
"StreamARN": self.arn,
@ -420,18 +440,23 @@ class Stream(CloudFormationModel):
}
@staticmethod
def cloudformation_name_type():
def cloudformation_name_type() -> str:
return "Name"
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kinesis-stream.html
return "AWS::Kinesis::Stream"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Stream":
properties = cloudformation_json.get("Properties", {})
shard_count = properties.get("ShardCount", 1)
retention_period_hours = properties.get("RetentionPeriodHours", resource_name)
@ -451,14 +476,14 @@ class Stream(CloudFormationModel):
return stream
@classmethod
def update_from_cloudformation_json(
def update_from_cloudformation_json( # type: ignore[misc]
cls,
original_resource,
new_resource_name,
cloudformation_json,
account_id,
region_name,
):
original_resource: Any,
new_resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> "Stream":
properties = cloudformation_json["Properties"]
if Stream.is_replacement_update(properties):
@ -466,11 +491,17 @@ class Stream(CloudFormationModel):
if resource_name_property not in properties:
properties[resource_name_property] = new_resource_name
new_resource = cls.create_from_cloudformation_json(
properties[resource_name_property], cloudformation_json, region_name
resource_name=properties[resource_name_property],
cloudformation_json=cloudformation_json,
account_id=account_id,
region_name=region_name,
)
properties[resource_name_property] = original_resource.name
cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name
resource_name=original_resource.name,
cloudformation_json=cloudformation_json,
account_id=account_id,
region_name=region_name,
)
return new_resource
@ -489,14 +520,18 @@ class Stream(CloudFormationModel):
return original_resource
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name
):
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
) -> None:
backend: KinesisBackend = kinesis_backends[account_id][region_name]
backend.delete_stream(stream_arn=None, stream_name=resource_name)
@staticmethod
def is_replacement_update(properties):
def is_replacement_update(properties: List[str]) -> bool:
properties_requiring_replacement_update = ["BucketName", "ObjectLockEnabled"]
return any(
[
@ -506,10 +541,10 @@ class Stream(CloudFormationModel):
)
@classmethod
def has_cfn_attr(cls, attr):
def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["Arn"]
def get_cfn_attribute(self, attribute_name):
def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn":
@ -517,25 +552,31 @@ class Stream(CloudFormationModel):
raise UnformattedGetAttTemplateException()
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.stream_name
class KinesisBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.streams: Dict[str, Stream] = OrderedDict()
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "kinesis", special_service_name="kinesis-streams"
)
def create_stream(
self, stream_name, shard_count, stream_mode=None, retention_period_hours=None
):
self,
stream_name: str,
shard_count: int,
stream_mode: Optional[Dict[str, str]] = None,
retention_period_hours: Optional[int] = None,
) -> Stream:
if stream_name in self.streams:
raise ResourceInUseError(stream_name)
stream = Stream(
@ -560,14 +601,14 @@ class KinesisBackend(BaseBackend):
return stream
if stream_arn:
stream_name = stream_arn.split("/")[1]
raise StreamNotFoundError(stream_name, self.account_id)
raise StreamNotFoundError(stream_name, self.account_id) # type: ignore
def describe_stream_summary(
self, stream_arn: Optional[str], stream_name: Optional[str]
) -> Stream:
return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
def list_streams(self):
def list_streams(self) -> Iterable[Stream]:
return self.streams.values()
def delete_stream(
@ -582,9 +623,9 @@ class KinesisBackend(BaseBackend):
stream_name: Optional[str],
shard_id: str,
shard_iterator_type: str,
starting_sequence_number: str,
at_timestamp: str,
):
starting_sequence_number: int,
at_timestamp: datetime.datetime,
) -> str:
# Validate params
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
try:
@ -605,31 +646,31 @@ class KinesisBackend(BaseBackend):
def get_records(
self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int]
):
) -> Tuple[str, List[Record], int]:
decomposed = decompose_shard_iterator(shard_iterator)
stream_name, shard_id, last_sequence_id = decomposed
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shard = stream.get_shard(shard_id)
records, last_sequence_id, millis_behind_latest = shard.get_records(
records, last_sequence_id, millis_behind_latest = shard.get_records( # type: ignore
last_sequence_id, limit
)
next_shard_iterator = compose_shard_iterator(
stream_name, shard, last_sequence_id
stream_name, shard, last_sequence_id # type: ignore
)
return next_shard_iterator, records, millis_behind_latest
def put_record(
self,
stream_arn,
stream_name,
partition_key,
explicit_hash_key,
data,
):
stream_arn: str,
stream_name: str,
partition_key: str,
explicit_hash_key: str,
data: str,
) -> Tuple[str, str]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
sequence_number, shard_id = stream.put_record(
@ -638,10 +679,12 @@ class KinesisBackend(BaseBackend):
return sequence_number, shard_id
def put_records(self, stream_arn, stream_name, records):
def put_records(
self, stream_arn: str, stream_name: str, records: List[Dict[str, Any]]
) -> Dict[str, Any]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
response = {"FailedRecordCount": 0, "Records": []}
response: Dict[str, Any] = {"FailedRecordCount": 0, "Records": []}
if len(records) > 500:
raise TooManyRecords
@ -660,7 +703,7 @@ class KinesisBackend(BaseBackend):
data = record.get("Data")
sequence_number, shard_id = stream.put_record(
partition_key, explicit_hash_key, data
partition_key, explicit_hash_key, data # type: ignore[arg-type]
)
response["Records"].append(
{"SequenceNumber": sequence_number, "ShardId": shard_id}
@ -669,8 +712,12 @@ class KinesisBackend(BaseBackend):
return response
def split_shard(
self, stream_arn, stream_name, shard_to_split, new_starting_hash_key
):
self,
stream_arn: str,
stream_name: str,
shard_to_split: str,
new_starting_hash_key: str,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if not re.match("[a-zA-Z0-9_.-]+", shard_to_split):
@ -695,8 +742,12 @@ class KinesisBackend(BaseBackend):
stream.split_shard(shard_to_split, new_starting_hash_key)
def merge_shards(
self, stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
):
self,
stream_arn: str,
stream_name: str,
shard_to_merge: str,
adjacent_shard_to_merge: str,
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
if shard_to_merge not in stream.shards:
@ -713,7 +764,9 @@ class KinesisBackend(BaseBackend):
stream.merge_shards(shard_to_merge, adjacent_shard_to_merge)
def update_shard_count(self, stream_arn, stream_name, target_shard_count):
def update_shard_count(
self, stream_arn: str, stream_name: str, target_shard_count: int
) -> int:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_count = len([s for s in stream.shards.values() if s.is_open])
@ -722,7 +775,7 @@ class KinesisBackend(BaseBackend):
return current_shard_count
@paginate(pagination_model=PAGINATION_MODEL)
def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]):
def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]) -> List[Dict[str, Any]]: # type: ignore
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
shards = sorted(stream.shards.values(), key=lambda x: x.shard_id)
return [shard.to_json() for shard in shards]
@ -766,12 +819,16 @@ class KinesisBackend(BaseBackend):
stream.retention_period_hours = retention_period_hours
def list_tags_for_stream(
self, stream_arn, stream_name, exclusive_start_tag_key=None, limit=None
):
self,
stream_arn: str,
stream_name: str,
exclusive_start_tag_key: Optional[str] = None,
limit: Optional[int] = None,
) -> Dict[str, Any]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
tags = []
result = {"HasMoreTags": False, "Tags": tags}
tags: List[Dict[str, str]] = []
result: Dict[str, Any] = {"HasMoreTags": False, "Tags": tags}
for key, val in sorted(stream.tags.items(), key=lambda x: x[0]):
if limit and len(tags) >= limit:
result["HasMoreTags"] = True
@ -805,7 +862,7 @@ class KinesisBackend(BaseBackend):
stream_arn: Optional[str],
stream_name: Optional[str],
shard_level_metrics: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
) -> Tuple[str, str, List[str], List[str]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_shard_level_metrics = stream.shard_level_metrics
desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics))
@ -822,7 +879,7 @@ class KinesisBackend(BaseBackend):
stream_arn: Optional[str],
stream_name: Optional[str],
to_be_disabled: List[str],
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
) -> Tuple[str, str, List[str], List[str]]:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
current_metrics = stream.shard_level_metrics
if "ALL" in to_be_disabled:
@ -834,19 +891,19 @@ class KinesisBackend(BaseBackend):
stream.shard_level_metrics = desired_metrics
return stream.arn, stream.stream_name, current_metrics, desired_metrics
def _find_stream_by_arn(self, stream_arn):
def _find_stream_by_arn(self, stream_arn: str) -> Stream: # type: ignore[return]
for stream in self.streams.values():
if stream.arn == stream_arn:
return stream
def list_stream_consumers(self, stream_arn):
def list_stream_consumers(self, stream_arn: str) -> List[Consumer]:
"""
Pagination is not yet implemented
"""
stream = self._find_stream_by_arn(stream_arn)
return stream.consumers
def register_stream_consumer(self, stream_arn, consumer_name):
def register_stream_consumer(self, stream_arn: str, consumer_name: str) -> Consumer:
consumer = Consumer(
consumer_name, self.account_id, self.region_name, stream_arn
)
@ -854,7 +911,9 @@ class KinesisBackend(BaseBackend):
stream.consumers.append(consumer)
return consumer
def describe_stream_consumer(self, stream_arn, consumer_name, consumer_arn):
def describe_stream_consumer(
self, stream_arn: str, consumer_name: str, consumer_arn: str
) -> Consumer:
if stream_arn:
stream = self._find_stream_by_arn(stream_arn)
for consumer in stream.consumers:
@ -862,14 +921,16 @@ class KinesisBackend(BaseBackend):
return consumer
if consumer_arn:
for stream in self.streams.values():
consumer = stream.get_consumer_by_arn(consumer_arn)
if consumer:
return consumer
_consumer = stream.get_consumer_by_arn(consumer_arn)
if _consumer:
return _consumer
raise ConsumerNotFound(
consumer=consumer_name or consumer_arn, account_id=self.account_id
)
def deregister_stream_consumer(self, stream_arn, consumer_name, consumer_arn):
def deregister_stream_consumer(
self, stream_arn: str, consumer_name: str, consumer_arn: str
) -> None:
if stream_arn:
stream = self._find_stream_by_arn(stream_arn)
stream.consumers = [
@ -881,17 +942,19 @@ class KinesisBackend(BaseBackend):
# It will be a noop for other streams
stream.delete_consumer(consumer_arn)
def start_stream_encryption(self, stream_arn, stream_name, encryption_type, key_id):
def start_stream_encryption(
self, stream_arn: str, stream_name: str, encryption_type: str, key_id: str
) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = encryption_type
stream.key_id = key_id
def stop_stream_encryption(self, stream_arn, stream_name):
def stop_stream_encryption(self, stream_arn: str, stream_name: str) -> None:
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
stream.encryption_type = "NONE"
stream.key_id = None
def update_stream_mode(self, stream_arn, stream_mode):
def update_stream_mode(self, stream_arn: str, stream_mode: Dict[str, str]) -> None:
stream = self._find_stream_by_arn(stream_arn)
stream.stream_mode = stream_mode
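As a concrete companion to get_shard_for_key above: a record's partition key is MD5-hashed into the 128-bit key space that init_shards splits across the shards, and the shard whose hash range contains that value receives the record. A simplified standalone sketch, not moto's API, with the last shard's boundary handling approximated:

    import hashlib

    def pick_shard_index(partition_key: str, shard_count: int) -> int:
        # Hash the partition key into the 128-bit key space (moto uses its own
        # md5_hash helper; hashlib stands in here), then locate the owning range.
        key = int(hashlib.md5(partition_key.encode("utf-8")).hexdigest(), 16)
        step = 2**128 // shard_count
        return min(key // step, shard_count - 1)

    print(pick_shard_index("user-42", 4))  # an index in 0..3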

@ -5,45 +5,41 @@ from .models import kinesis_backends, KinesisBackend
class KinesisResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="kinesis")
@property
def parameters(self):
return json.loads(self.body)
@property
def kinesis_backend(self) -> KinesisBackend:
return kinesis_backends[self.current_account][self.region]
def create_stream(self):
stream_name = self.parameters.get("StreamName")
shard_count = self.parameters.get("ShardCount")
stream_mode = self.parameters.get("StreamModeDetails")
def create_stream(self) -> str:
stream_name = self._get_param("StreamName")
shard_count = self._get_param("ShardCount")
stream_mode = self._get_param("StreamModeDetails")
self.kinesis_backend.create_stream(
stream_name, shard_count, stream_mode=stream_mode
)
return ""
def describe_stream(self):
stream_name = self.parameters.get("StreamName")
stream_arn = self.parameters.get("StreamARN")
limit = self.parameters.get("Limit")
def describe_stream(self) -> str:
stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN")
limit = self._get_param("Limit")
stream = self.kinesis_backend.describe_stream(stream_arn, stream_name)
return json.dumps(stream.to_json(shard_limit=limit))
def describe_stream_summary(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
def describe_stream_summary(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
stream = self.kinesis_backend.describe_stream_summary(stream_arn, stream_name)
return json.dumps(stream.to_json_summary())
def list_streams(self):
def list_streams(self) -> str:
streams = self.kinesis_backend.list_streams()
stream_names = [stream.stream_name for stream in streams]
max_streams = self._get_param("Limit", 10)
try:
token = self.parameters.get("ExclusiveStartStreamName")
token = self._get_param("ExclusiveStartStreamName")
except ValueError:
token = self._get_param("ExclusiveStartStreamName")
if token:
@ -59,19 +55,19 @@ class KinesisResponse(BaseResponse):
{"HasMoreStreams": has_more_streams, "StreamNames": streams_resp}
)
def delete_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
def delete_stream(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
self.kinesis_backend.delete_stream(stream_arn, stream_name)
return ""
def get_shard_iterator(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_id = self.parameters.get("ShardId")
shard_iterator_type = self.parameters.get("ShardIteratorType")
starting_sequence_number = self.parameters.get("StartingSequenceNumber")
at_timestamp = self.parameters.get("Timestamp")
def get_shard_iterator(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
shard_id = self._get_param("ShardId")
shard_iterator_type = self._get_param("ShardIteratorType")
starting_sequence_number = self._get_param("StartingSequenceNumber")
at_timestamp = self._get_param("Timestamp")
shard_iterator = self.kinesis_backend.get_shard_iterator(
stream_arn,
@ -84,10 +80,10 @@ class KinesisResponse(BaseResponse):
return json.dumps({"ShardIterator": shard_iterator})
def get_records(self):
stream_arn = self.parameters.get("StreamARN")
shard_iterator = self.parameters.get("ShardIterator")
limit = self.parameters.get("Limit")
def get_records(self) -> str:
stream_arn = self._get_param("StreamARN")
shard_iterator = self._get_param("ShardIterator")
limit = self._get_param("Limit")
(
next_shard_iterator,
@ -103,12 +99,12 @@ class KinesisResponse(BaseResponse):
}
)
def put_record(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
partition_key = self.parameters.get("PartitionKey")
explicit_hash_key = self.parameters.get("ExplicitHashKey")
data = self.parameters.get("Data")
def put_record(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
partition_key = self._get_param("PartitionKey")
explicit_hash_key = self._get_param("ExplicitHashKey")
data = self._get_param("Data")
sequence_number, shard_id = self.kinesis_backend.put_record(
stream_arn,
@ -120,40 +116,40 @@ class KinesisResponse(BaseResponse):
return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id})
def put_records(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
records = self.parameters.get("Records")
def put_records(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
records = self._get_param("Records")
response = self.kinesis_backend.put_records(stream_arn, stream_name, records)
return json.dumps(response)
def split_shard(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_to_split = self.parameters.get("ShardToSplit")
new_starting_hash_key = self.parameters.get("NewStartingHashKey")
def split_shard(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
shard_to_split = self._get_param("ShardToSplit")
new_starting_hash_key = self._get_param("NewStartingHashKey")
self.kinesis_backend.split_shard(
stream_arn, stream_name, shard_to_split, new_starting_hash_key
)
return ""
def merge_shards(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_to_merge = self.parameters.get("ShardToMerge")
adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge")
def merge_shards(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
shard_to_merge = self._get_param("ShardToMerge")
adjacent_shard_to_merge = self._get_param("AdjacentShardToMerge")
self.kinesis_backend.merge_shards(
stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
)
return ""
def list_shards(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
next_token = self.parameters.get("NextToken")
max_results = self.parameters.get("MaxResults", 10000)
def list_shards(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
next_token = self._get_param("NextToken")
max_results = self._get_param("MaxResults", 10000)
shards, token = self.kinesis_backend.list_shards(
stream_arn=stream_arn,
stream_name=stream_name,
@ -165,10 +161,10 @@ class KinesisResponse(BaseResponse):
res["NextToken"] = token
return json.dumps(res)
def update_shard_count(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
target_shard_count = self.parameters.get("TargetShardCount")
def update_shard_count(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
target_shard_count = self._get_param("TargetShardCount")
current_shard_count = self.kinesis_backend.update_shard_count(
stream_arn=stream_arn,
stream_name=stream_name,
@ -182,52 +178,52 @@ class KinesisResponse(BaseResponse):
)
)
def increase_stream_retention_period(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
retention_period_hours = self.parameters.get("RetentionPeriodHours")
def increase_stream_retention_period(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
retention_period_hours = self._get_param("RetentionPeriodHours")
self.kinesis_backend.increase_stream_retention_period(
stream_arn, stream_name, retention_period_hours
)
return ""
def decrease_stream_retention_period(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
retention_period_hours = self.parameters.get("RetentionPeriodHours")
def decrease_stream_retention_period(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
retention_period_hours = self._get_param("RetentionPeriodHours")
self.kinesis_backend.decrease_stream_retention_period(
stream_arn, stream_name, retention_period_hours
)
return ""
def add_tags_to_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
tags = self.parameters.get("Tags")
def add_tags_to_stream(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
tags = self._get_param("Tags")
self.kinesis_backend.add_tags_to_stream(stream_arn, stream_name, tags)
return json.dumps({})
def list_tags_for_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey")
limit = self.parameters.get("Limit")
def list_tags_for_stream(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
exclusive_start_tag_key = self._get_param("ExclusiveStartTagKey")
limit = self._get_param("Limit")
response = self.kinesis_backend.list_tags_for_stream(
stream_arn, stream_name, exclusive_start_tag_key, limit
)
return json.dumps(response)
def remove_tags_from_stream(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
tag_keys = self.parameters.get("TagKeys")
def remove_tags_from_stream(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
tag_keys = self._get_param("TagKeys")
self.kinesis_backend.remove_tags_from_stream(stream_arn, stream_name, tag_keys)
return json.dumps({})
def enable_enhanced_monitoring(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics")
def enable_enhanced_monitoring(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
shard_level_metrics = self._get_param("ShardLevelMetrics")
arn, name, current, desired = self.kinesis_backend.enable_enhanced_monitoring(
stream_arn=stream_arn,
stream_name=stream_name,
@ -242,10 +238,10 @@ class KinesisResponse(BaseResponse):
)
)
def disable_enhanced_monitoring(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
shard_level_metrics = self.parameters.get("ShardLevelMetrics")
def disable_enhanced_monitoring(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
shard_level_metrics = self._get_param("ShardLevelMetrics")
arn, name, current, desired = self.kinesis_backend.disable_enhanced_monitoring(
stream_arn=stream_arn,
stream_name=stream_name,
@ -260,23 +256,23 @@ class KinesisResponse(BaseResponse):
)
)
def list_stream_consumers(self):
stream_arn = self.parameters.get("StreamARN")
def list_stream_consumers(self) -> str:
stream_arn = self._get_param("StreamARN")
consumers = self.kinesis_backend.list_stream_consumers(stream_arn=stream_arn)
return json.dumps(dict(Consumers=[c.to_json() for c in consumers]))
def register_stream_consumer(self):
stream_arn = self.parameters.get("StreamARN")
consumer_name = self.parameters.get("ConsumerName")
def register_stream_consumer(self) -> str:
stream_arn = self._get_param("StreamARN")
consumer_name = self._get_param("ConsumerName")
consumer = self.kinesis_backend.register_stream_consumer(
stream_arn=stream_arn, consumer_name=consumer_name
)
return json.dumps(dict(Consumer=consumer.to_json()))
def describe_stream_consumer(self):
stream_arn = self.parameters.get("StreamARN")
consumer_name = self.parameters.get("ConsumerName")
consumer_arn = self.parameters.get("ConsumerARN")
def describe_stream_consumer(self) -> str:
stream_arn = self._get_param("StreamARN")
consumer_name = self._get_param("ConsumerName")
consumer_arn = self._get_param("ConsumerARN")
consumer = self.kinesis_backend.describe_stream_consumer(
stream_arn=stream_arn,
consumer_name=consumer_name,
@ -286,10 +282,10 @@ class KinesisResponse(BaseResponse):
dict(ConsumerDescription=consumer.to_json(include_stream_arn=True))
)
def deregister_stream_consumer(self):
stream_arn = self.parameters.get("StreamARN")
consumer_name = self.parameters.get("ConsumerName")
consumer_arn = self.parameters.get("ConsumerARN")
def deregister_stream_consumer(self) -> str:
stream_arn = self._get_param("StreamARN")
consumer_name = self._get_param("ConsumerName")
consumer_arn = self._get_param("ConsumerARN")
self.kinesis_backend.deregister_stream_consumer(
stream_arn=stream_arn,
consumer_name=consumer_name,
@ -297,11 +293,11 @@ class KinesisResponse(BaseResponse):
)
return json.dumps(dict())
def start_stream_encryption(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
encryption_type = self.parameters.get("EncryptionType")
key_id = self.parameters.get("KeyId")
def start_stream_encryption(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
encryption_type = self._get_param("EncryptionType")
key_id = self._get_param("KeyId")
self.kinesis_backend.start_stream_encryption(
stream_arn=stream_arn,
stream_name=stream_name,
@ -310,16 +306,16 @@ class KinesisResponse(BaseResponse):
)
return json.dumps(dict())
def stop_stream_encryption(self):
stream_arn = self.parameters.get("StreamARN")
stream_name = self.parameters.get("StreamName")
def stop_stream_encryption(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_name = self._get_param("StreamName")
self.kinesis_backend.stop_stream_encryption(
stream_arn=stream_arn, stream_name=stream_name
)
return json.dumps(dict())
def update_stream_mode(self):
stream_arn = self.parameters.get("StreamARN")
stream_mode = self.parameters.get("StreamModeDetails")
def update_stream_mode(self) -> str:
stream_arn = self._get_param("StreamARN")
stream_mode = self._get_param("StreamModeDetails")
self.kinesis_backend.update_stream_mode(stream_arn, stream_mode)
return "{}"

@ -1,4 +1,6 @@
import base64
from datetime import datetime
from typing import Any, Optional, List
from .exceptions import InvalidArgumentError
@ -30,8 +32,12 @@ PAGINATION_MODEL = {
def compose_new_shard_iterator(
stream_name, shard, shard_iterator_type, starting_sequence_number, at_timestamp
):
stream_name: Optional[str],
shard: Any,
shard_iterator_type: str,
starting_sequence_number: int,
at_timestamp: datetime,
) -> str:
if shard_iterator_type == "AT_SEQUENCE_NUMBER":
last_sequence_id = int(starting_sequence_number) - 1
elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER":
@ -47,11 +53,13 @@ def compose_new_shard_iterator(
return compose_shard_iterator(stream_name, shard, last_sequence_id)
def compose_shard_iterator(stream_name, shard, last_sequence_id):
def compose_shard_iterator(
stream_name: Optional[str], shard: Any, last_sequence_id: int
) -> str:
return encode_method(
f"{stream_name}:{shard.shard_id}:{last_sequence_id}".encode("utf-8")
).decode("utf-8")
def decompose_shard_iterator(shard_iterator):
def decompose_shard_iterator(shard_iterator: str) -> List[str]:
return decode_method(shard_iterator.encode("utf-8")).decode("utf-8").split(":")
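A standalone round-trip of the shard-iterator format above; this assumes encode_method and decode_method are plain base64 helpers (suggested by the base64 import at the top of the file) and is not part of the commit:

    import base64
    from typing import List

    def compose(stream_name: str, shard_id: str, last_sequence_id: int) -> str:
        # Same "stream:shard:sequence" layout as compose_shard_iterator above.
        raw = f"{stream_name}:{shard_id}:{last_sequence_id}".encode("utf-8")
        return base64.b64encode(raw).decode("utf-8")

    def decompose(shard_iterator: str) -> List[str]:
        return base64.b64decode(shard_iterator.encode("utf-8")).decode("utf-8").split(":")

    assert decompose(compose("my-stream", "shardId-000000000000", 41)) == [
        "my-stream", "shardId-000000000000", "41"
    ]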

@ -6,7 +6,7 @@ class KinesisvideoClientError(RESTError):
class ResourceNotFoundException(KinesisvideoClientError):
def __init__(self):
def __init__(self) -> None:
self.code = 404
super().__init__(
"ResourceNotFoundException",
@ -15,6 +15,6 @@ class ResourceNotFoundException(KinesisvideoClientError):
class ResourceInUseException(KinesisvideoClientError):
def __init__(self, message):
def __init__(self, message: str):
self.code = 400
super().__init__("ResourceInUseException", message)

@ -1,5 +1,6 @@
from moto.core import BaseBackend, BackendDict, BaseModel
from datetime import datetime
from typing import Any, Dict, List
from moto.core import BaseBackend, BackendDict, BaseModel
from .exceptions import ResourceNotFoundException, ResourceInUseException
from moto.moto_api._internal import mock_random as random
@ -7,14 +8,14 @@ from moto.moto_api._internal import mock_random as random
class Stream(BaseModel):
def __init__(
self,
account_id,
region_name,
device_name,
stream_name,
media_type,
kms_key_id,
data_retention_in_hours,
tags,
account_id: str,
region_name: str,
device_name: str,
stream_name: str,
media_type: str,
kms_key_id: str,
data_retention_in_hours: int,
tags: Dict[str, str],
):
self.region_name = region_name
self.stream_name = stream_name
@ -30,11 +31,11 @@ class Stream(BaseModel):
self.data_endpoint_number = random.get_random_hex()
self.arn = stream_arn
def get_data_endpoint(self, api_name):
def get_data_endpoint(self, api_name: str) -> str:
data_endpoint_prefix = "s-" if api_name in ("PUT_MEDIA", "GET_MEDIA") else "b-"
return f"https://{data_endpoint_prefix}{self.data_endpoint_number}.kinesisvideo.{self.region_name}.amazonaws.com"
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
return {
"DeviceName": self.device_name,
"StreamName": self.stream_name,
@ -49,19 +50,19 @@ class Stream(BaseModel):
class KinesisVideoBackend(BaseBackend):
def __init__(self, region_name, account_id):
def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.streams = {}
self.streams: Dict[str, Stream] = {}
def create_stream(
self,
device_name,
stream_name,
media_type,
kms_key_id,
data_retention_in_hours,
tags,
):
device_name: str,
stream_name: str,
media_type: str,
kms_key_id: str,
data_retention_in_hours: int,
tags: Dict[str, str],
) -> str:
streams = [_ for _ in self.streams.values() if _.stream_name == stream_name]
if len(streams) > 0:
raise ResourceInUseException(f"The stream {stream_name} already exists.")
@ -78,7 +79,7 @@ class KinesisVideoBackend(BaseBackend):
self.streams[stream.arn] = stream
return stream.arn
def _get_stream(self, stream_name, stream_arn):
def _get_stream(self, stream_name: str, stream_arn: str) -> Stream:
if stream_name:
streams = [_ for _ in self.streams.values() if _.stream_name == stream_name]
if len(streams) == 0:
@ -90,20 +91,17 @@ class KinesisVideoBackend(BaseBackend):
raise ResourceNotFoundException()
return stream
def describe_stream(self, stream_name, stream_arn):
def describe_stream(self, stream_name: str, stream_arn: str) -> Dict[str, Any]:
stream = self._get_stream(stream_name, stream_arn)
stream_info = stream.to_dict()
return stream_info
return stream.to_dict()
def list_streams(self):
def list_streams(self) -> List[Dict[str, Any]]:
"""
Pagination and the StreamNameCondition-parameter are not yet implemented
"""
stream_info_list = [_.to_dict() for _ in self.streams.values()]
next_token = None
return stream_info_list, next_token
return [_.to_dict() for _ in self.streams.values()]
def delete_stream(self, stream_arn):
def delete_stream(self, stream_arn: str) -> None:
"""
The CurrentVersion-parameter is not yet implemented
"""
@ -112,11 +110,11 @@ class KinesisVideoBackend(BaseBackend):
raise ResourceNotFoundException()
del self.streams[stream_arn]
def get_data_endpoint(self, stream_name, stream_arn, api_name):
def get_data_endpoint(
self, stream_name: str, stream_arn: str, api_name: str
) -> str:
stream = self._get_stream(stream_name, stream_arn)
return stream.get_data_endpoint(api_name)
# add methods from here
kinesisvideo_backends = BackendDict(KinesisVideoBackend, "kinesisvideo")
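To make get_data_endpoint above concrete: PUT_MEDIA and GET_MEDIA get an "s-" host prefix, every other API name gets "b-". An illustrative restatement with made-up values, not part of the commit:

    def data_endpoint(api_name: str, endpoint_number: str, region_name: str) -> str:
        prefix = "s-" if api_name in ("PUT_MEDIA", "GET_MEDIA") else "b-"
        return f"https://{prefix}{endpoint_number}.kinesisvideo.{region_name}.amazonaws.com"

    print(data_endpoint("GET_MEDIA", "1a2b3c4d", "us-east-1"))
    # https://s-1a2b3c4d.kinesisvideo.us-east-1.amazonaws.com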

@ -1,17 +1,17 @@
from moto.core.responses import BaseResponse
from .models import kinesisvideo_backends
from .models import kinesisvideo_backends, KinesisVideoBackend
import json
class KinesisVideoResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="kinesisvideo")
@property
def kinesisvideo_backend(self):
def kinesisvideo_backend(self) -> KinesisVideoBackend:
return kinesisvideo_backends[self.current_account][self.region]
def create_stream(self):
def create_stream(self) -> str:
device_name = self._get_param("DeviceName")
stream_name = self._get_param("StreamName")
media_type = self._get_param("MediaType")
@ -28,7 +28,7 @@ class KinesisVideoResponse(BaseResponse):
)
return json.dumps(dict(StreamARN=stream_arn))
def describe_stream(self):
def describe_stream(self) -> str:
stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN")
stream_info = self.kinesisvideo_backend.describe_stream(
@ -36,16 +36,16 @@ class KinesisVideoResponse(BaseResponse):
)
return json.dumps(dict(StreamInfo=stream_info))
def list_streams(self):
stream_info_list, next_token = self.kinesisvideo_backend.list_streams()
return json.dumps(dict(StreamInfoList=stream_info_list, NextToken=next_token))
def list_streams(self) -> str:
stream_info_list = self.kinesisvideo_backend.list_streams()
return json.dumps(dict(StreamInfoList=stream_info_list, NextToken=None))
def delete_stream(self):
def delete_stream(self) -> str:
stream_arn = self._get_param("StreamARN")
self.kinesisvideo_backend.delete_stream(stream_arn=stream_arn)
return json.dumps(dict())
def get_data_endpoint(self):
def get_data_endpoint(self) -> str:
stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN")
api_name = self._get_param("APIName")
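A short usage sketch that drives these handlers through boto3, in the usual moto test style (assumes moto 4.x, where the mock_kinesisvideo decorator exists):

    import boto3
    from moto import mock_kinesisvideo

    @mock_kinesisvideo
    def test_stream_lifecycle():
        client = boto3.client("kinesisvideo", region_name="us-east-1")
        arn = client.create_stream(StreamName="clips", MediaType="video/h264")["StreamARN"]
        info = client.describe_stream(StreamARN=arn)["StreamInfo"]
        assert info["StreamName"] == "clips"
        client.delete_stream(StreamARN=arn)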

@ -1,14 +1,17 @@
from typing import Tuple
from moto.core import BaseBackend, BackendDict
from moto.kinesisvideo import kinesisvideo_backends
from moto.kinesisvideo.models import kinesisvideo_backends, KinesisVideoBackend
from moto.sts.utils import random_session_token
class KinesisVideoArchivedMediaBackend(BaseBackend):
@property
def backend(self):
def backend(self) -> KinesisVideoBackend:
return kinesisvideo_backends[self.account_id][self.region_name]
def _get_streaming_url(self, stream_name, stream_arn, api_name):
def _get_streaming_url(
self, stream_name: str, stream_arn: str, api_name: str
) -> str:
stream = self.backend._get_stream(stream_name, stream_arn)
data_endpoint = stream.get_data_endpoint(api_name)
session_token = random_session_token()
@ -19,19 +22,17 @@ class KinesisVideoArchivedMediaBackend(BaseBackend):
relative_path = api_to_relative_path[api_name]
return f"{data_endpoint}{relative_path}?SessionToken={session_token}"
def get_hls_streaming_session_url(self, stream_name, stream_arn):
# Ignore option paramters as the format of hls_url does't depends on them
def get_hls_streaming_session_url(self, stream_name: str, stream_arn: str) -> str:
# Ignore optional parameters as the format of hls_url doesn't depend on them
api_name = "GET_HLS_STREAMING_SESSION_URL"
url = self._get_streaming_url(stream_name, stream_arn, api_name)
return url
return self._get_streaming_url(stream_name, stream_arn, api_name)
def get_dash_streaming_session_url(self, stream_name, stream_arn):
# Ignore option paramters as the format of hls_url does't depends on them
def get_dash_streaming_session_url(self, stream_name: str, stream_arn: str) -> str:
# Ignore optional parameters as the format of hls_url doesn't depend on them
api_name = "GET_DASH_STREAMING_SESSION_URL"
url = self._get_streaming_url(stream_name, stream_arn, api_name)
return url
return self._get_streaming_url(stream_name, stream_arn, api_name)
def get_clip(self, stream_name, stream_arn):
def get_clip(self, stream_name: str, stream_arn: str) -> Tuple[str, bytes]:
self.backend._get_stream(stream_name, stream_arn)
content_type = "video/mp4" # Fixed content_type as it depends on input stream
payload = b"sample-mp4-video"
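For reference, the URL assembled by _get_streaming_url above has the shape data_endpoint + relative_path + "?SessionToken=" + token. The per-API relative path is elided in this diff, so the path and token below are placeholders only:

    data_endpoint = "https://b-1a2b3c4d.kinesisvideo.us-east-1.amazonaws.com"
    relative_path = "/hls/v1/getHLSMasterPlaylist.m3u8"  # placeholder, not taken from this diff
    session_token = "EXAMPLESESSIONTOKEN"
    hls_url = f"{data_endpoint}{relative_path}?SessionToken={session_token}"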

@ -1,17 +1,18 @@
from typing import Dict, Tuple
from moto.core.responses import BaseResponse
from .models import kinesisvideoarchivedmedia_backends
from .models import kinesisvideoarchivedmedia_backends, KinesisVideoArchivedMediaBackend
import json
class KinesisVideoArchivedMediaResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="kinesis-video-archived-media")
@property
def kinesisvideoarchivedmedia_backend(self):
def kinesisvideoarchivedmedia_backend(self) -> KinesisVideoArchivedMediaBackend:
return kinesisvideoarchivedmedia_backends[self.current_account][self.region]
def get_hls_streaming_session_url(self):
def get_hls_streaming_session_url(self) -> str:
stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN")
hls_streaming_session_url = (
@ -21,7 +22,7 @@ class KinesisVideoArchivedMediaResponse(BaseResponse):
)
return json.dumps(dict(HLSStreamingSessionURL=hls_streaming_session_url))
def get_dash_streaming_session_url(self):
def get_dash_streaming_session_url(self) -> str:
stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN")
dash_streaming_session_url = (
@ -31,7 +32,7 @@ class KinesisVideoArchivedMediaResponse(BaseResponse):
)
return json.dumps(dict(DASHStreamingSessionURL=dash_streaming_session_url))
def get_clip(self):
def get_clip(self) -> Tuple[bytes, Dict[str, str]]:
stream_name = self._get_param("StreamName")
stream_arn = self._get_param("StreamARN")
content_type, payload = self.kinesisvideoarchivedmedia_backend.get_clip(

@ -4,29 +4,29 @@ from moto.core.exceptions import JsonRESTError
class NotFoundException(JsonRESTError):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("NotFoundException", message)
class ValidationException(JsonRESTError):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("ValidationException", message)
class AlreadyExistsException(JsonRESTError):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("AlreadyExistsException", message)
class NotAuthorizedException(JsonRESTError):
code = 400
def __init__(self):
super().__init__("NotAuthorizedException", None)
def __init__(self) -> None:
super().__init__("NotAuthorizedException", "")
self.description = '{"__type":"NotAuthorizedException"}'
@ -34,7 +34,7 @@ class NotAuthorizedException(JsonRESTError):
class AccessDeniedException(JsonRESTError):
code = 400
def __init__(self, message):
def __init__(self, message: str):
super().__init__("AccessDeniedException", message)
self.description = '{"__type":"AccessDeniedException"}'
@ -43,7 +43,7 @@ class AccessDeniedException(JsonRESTError):
class InvalidCiphertextException(JsonRESTError):
code = 400
def __init__(self):
super().__init__("InvalidCiphertextException", None)
def __init__(self) -> None:
super().__init__("InvalidCiphertextException", "")
self.description = '{"__type":"InvalidCiphertextException"}'

@ -5,8 +5,8 @@ from copy import copy
from datetime import datetime, timedelta
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from typing import Any, Dict, List, Tuple, Optional, Iterable, Set
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
from moto.core.utils import unix_time
@ -28,12 +28,12 @@ from .utils import (
class Grant(BaseModel):
def __init__(
self,
key_id,
name,
grantee_principal,
operations,
constraints,
retiring_principal,
key_id: str,
name: str,
grantee_principal: str,
operations: List[str],
constraints: Dict[str, Any],
retiring_principal: str,
):
self.key_id = key_id
self.name = name
@ -44,7 +44,7 @@ class Grant(BaseModel):
self.id = mock_random.get_random_hex()
self.token = mock_random.get_random_hex()
def to_json(self):
def to_json(self) -> Dict[str, Any]:
return {
"KeyId": self.key_id,
"GrantId": self.id,
@ -59,13 +59,13 @@ class Grant(BaseModel):
class Key(CloudFormationModel):
def __init__(
self,
policy,
key_usage,
key_spec,
description,
account_id,
region,
multi_region=False,
policy: Optional[str],
key_usage: str,
key_spec: str,
description: str,
account_id: str,
region: str,
multi_region: bool = False,
):
self.id = generate_key_id(multi_region)
self.creation_date = unix_time()
@ -78,7 +78,7 @@ class Key(CloudFormationModel):
self.region = region
self.multi_region = multi_region
self.key_rotation_status = False
self.deletion_date = None
self.deletion_date: Optional[datetime] = None
self.key_material = generate_master_key()
self.private_key = generate_private_key()
self.origin = "AWS_KMS"
@ -86,10 +86,15 @@ class Key(CloudFormationModel):
self.key_spec = key_spec or "SYMMETRIC_DEFAULT"
self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}"
self.grants = dict()
self.grants: Dict[str, Grant] = dict()
def add_grant(
self, name, grantee_principal, operations, constraints, retiring_principal
self,
name: str,
grantee_principal: str,
operations: List[str],
constraints: Dict[str, Any],
retiring_principal: str,
) -> Grant:
grant = Grant(
self.id,
@ -102,32 +107,32 @@ class Key(CloudFormationModel):
self.grants[grant.id] = grant
return grant
def list_grants(self, grant_id) -> [Grant]:
def list_grants(self, grant_id: str) -> List[Grant]:
grant_ids = [grant_id] if grant_id else self.grants.keys()
return [grant for _id, grant in self.grants.items() if _id in grant_ids]
def list_retirable_grants(self, retiring_principal) -> [Grant]:
def list_retirable_grants(self, retiring_principal: str) -> List[Grant]:
return [
grant
for grant in self.grants.values()
if grant.retiring_principal == retiring_principal
]
def revoke_grant(self, grant_id) -> None:
def revoke_grant(self, grant_id: str) -> None:
if not self.grants.pop(grant_id, None):
raise JsonRESTError("NotFoundException", f"Grant ID {grant_id} not found")
def retire_grant(self, grant_id) -> None:
def retire_grant(self, grant_id: str) -> None:
self.grants.pop(grant_id, None)
def retire_grant_by_token(self, grant_token) -> None:
def retire_grant_by_token(self, grant_token: str) -> None:
self.grants = {
_id: grant
for _id, grant in self.grants.items()
if grant.token != grant_token
}
def generate_default_policy(self):
def generate_default_policy(self) -> str:
return json.dumps(
{
"Version": "2012-10-17",
@ -145,11 +150,11 @@ class Key(CloudFormationModel):
)
@property
def physical_resource_id(self):
def physical_resource_id(self) -> str:
return self.id
@property
def encryption_algorithms(self):
def encryption_algorithms(self) -> Optional[List[str]]:
if self.key_usage == "SIGN_VERIFY":
return None
elif self.key_spec == "SYMMETRIC_DEFAULT":
@ -158,9 +163,9 @@ class Key(CloudFormationModel):
return ["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"]
@property
def signing_algorithms(self):
def signing_algorithms(self) -> List[str]:
if self.key_usage == "ENCRYPT_DECRYPT":
return None
return None # type: ignore[return-value]
elif self.key_spec in ["ECC_NIST_P256", "ECC_SECG_P256K1"]:
return ["ECDSA_SHA_256"]
elif self.key_spec == "ECC_NIST_P384":
@ -177,7 +182,7 @@ class Key(CloudFormationModel):
"RSASSA_PSS_SHA_512",
]
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
key_dict = {
"KeyMetadata": {
"AWSAccountId": self.account_id,
@ -201,22 +206,27 @@ class Key(CloudFormationModel):
key_dict["KeyMetadata"]["DeletionDate"] = unix_time(self.deletion_date)
return key_dict
def delete(self, account_id, region_name):
def delete(self, account_id: str, region_name: str) -> None:
kms_backends[account_id][region_name].delete_key(self.id)
@staticmethod
def cloudformation_name_type():
return None
def cloudformation_name_type() -> str:
return ""
@staticmethod
def cloudformation_type():
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kms-key.html
return "AWS::KMS::Key"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Any,
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Key":
kms_backend = kms_backends[account_id][region_name]
properties = cloudformation_json["Properties"]
@ -233,10 +243,10 @@ class Key(CloudFormationModel):
return key
@classmethod
def has_cfn_attr(cls, attr):
def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["Arn"]
def get_cfn_attribute(self, attribute_name):
def get_cfn_attribute(self, attribute_name: str) -> str:
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn":
@ -245,20 +255,22 @@ class Key(CloudFormationModel):
class KmsBackend(BaseBackend):
def __init__(self, region_name, account_id=None):
super().__init__(region_name=region_name, account_id=account_id)
self.keys = {}
self.key_to_aliases = defaultdict(set)
def __init__(self, region_name: str, account_id: Optional[str] = None):
super().__init__(region_name=region_name, account_id=account_id) # type: ignore
self.keys: Dict[str, Key] = {}
self.key_to_aliases: Dict[str, Set[str]] = defaultdict(set)
self.tagger = TaggingService(key_name="TagKey", value_name="TagValue")
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
def default_vpc_endpoint_service(
service_region: str, zones: List[str]
) -> List[Dict[str, str]]:
"""Default VPC endpoint service."""
return BaseBackend.default_vpc_endpoint_service_factory(
service_region, zones, "kms"
)
def _generate_default_keys(self, alias_name):
def _generate_default_keys(self, alias_name: str) -> Optional[str]:
"""Creates default kms keys"""
if alias_name in RESERVED_ALIASES:
key = self.create_key(
@ -270,10 +282,17 @@ class KmsBackend(BaseBackend):
)
self.add_alias(key.id, alias_name)
return key.id
return None
def create_key(
self, policy, key_usage, key_spec, description, tags, multi_region=False
):
self,
policy: Optional[str],
key_usage: str,
key_spec: str,
description: str,
tags: Optional[List[Dict[str, str]]],
multi_region: bool = False,
) -> Key:
"""
The provided Policy currently does not need to be valid. If it is valid, Moto will perform authorization checks on key-related operations, just like AWS does.
@ -303,7 +322,7 @@ class KmsBackend(BaseBackend):
#
# In our implementation we just create a copy of all the properties once, without any protection from change,
# as the exact implementation is currently infeasible.
def replicate_key(self, key_id, replica_region):
def replicate_key(self, key_id: str, replica_region: str) -> None:
# Using copy() instead of deepcopy(), as the latter results in an exception:
# TypeError: cannot pickle '_cffi_backend.FFI' object
# Since we only update top level properties, copy() should suffice.
@ -312,31 +331,31 @@ class KmsBackend(BaseBackend):
to_region_backend = kms_backends[self.account_id][replica_region]
to_region_backend.keys[replica_key.id] = replica_key
def update_key_description(self, key_id, description):
def update_key_description(self, key_id: str, description: str) -> None:
key = self.keys[self.get_key_id(key_id)]
key.description = description
def delete_key(self, key_id):
def delete_key(self, key_id: str) -> None:
if key_id in self.keys:
if key_id in self.key_to_aliases:
self.key_to_aliases.pop(key_id)
self.tagger.delete_all_tags_for_resource(key_id)
return self.keys.pop(key_id)
self.keys.pop(key_id)
def describe_key(self, key_id) -> Key:
def describe_key(self, key_id: str) -> Key:
# allow the different methods (alias, ARN :key/, keyId, ARN alias) to
# describe key not just KeyId
key_id = self.get_key_id(key_id)
if r"alias/" in str(key_id).lower():
key_id = self.get_key_id_from_alias(key_id)
key_id = self.get_key_id_from_alias(key_id) # type: ignore[assignment]
return self.keys[self.get_key_id(key_id)]
def list_keys(self):
def list_keys(self) -> Iterable[Key]:
return self.keys.values()
@staticmethod
def get_key_id(key_id):
def get_key_id(key_id: str) -> str:
# Allow use of ARN as well as pure KeyId
if key_id.startswith("arn:") and ":key/" in key_id:
return key_id.split(":key/")[1]
@ -344,14 +363,14 @@ class KmsBackend(BaseBackend):
return key_id
@staticmethod
def get_alias_name(alias_name):
def get_alias_name(alias_name: str) -> str:
# Allow use of ARN as well as alias name
if alias_name.startswith("arn:") and ":alias/" in alias_name:
return "alias/" + alias_name.split(":alias/")[1]
return alias_name
def any_id_to_key_id(self, key_id):
def any_id_to_key_id(self, key_id: str) -> str:
"""Go from any valid key ID to the raw key ID.
Acceptable inputs:
@ -363,66 +382,65 @@ class KmsBackend(BaseBackend):
key_id = self.get_alias_name(key_id)
key_id = self.get_key_id(key_id)
if key_id.startswith("alias/"):
key_id = self.get_key_id(self.get_key_id_from_alias(key_id))
key_id = self.get_key_id(self.get_key_id_from_alias(key_id)) # type: ignore[arg-type]
return key_id
def alias_exists(self, alias_name):
def alias_exists(self, alias_name: str) -> bool:
for aliases in self.key_to_aliases.values():
if alias_name in aliases:
return True
return False
def add_alias(self, target_key_id, alias_name):
def add_alias(self, target_key_id: str, alias_name: str) -> None:
raw_key_id = self.get_key_id(target_key_id)
self.key_to_aliases[raw_key_id].add(alias_name)
def delete_alias(self, alias_name):
def delete_alias(self, alias_name: str) -> None:
"""Delete the alias."""
for aliases in self.key_to_aliases.values():
if alias_name in aliases:
aliases.remove(alias_name)
def get_all_aliases(self):
def get_all_aliases(self) -> Dict[str, Set[str]]:
return self.key_to_aliases
def get_key_id_from_alias(self, alias_name):
def get_key_id_from_alias(self, alias_name: str) -> Optional[str]:
for key_id, aliases in dict(self.key_to_aliases).items():
if alias_name in ",".join(aliases):
return key_id
if alias_name in RESERVED_ALIASES:
key_id = self._generate_default_keys(alias_name)
return key_id
return self._generate_default_keys(alias_name)
return None
def enable_key_rotation(self, key_id):
def enable_key_rotation(self, key_id: str) -> None:
self.keys[self.get_key_id(key_id)].key_rotation_status = True
def disable_key_rotation(self, key_id):
def disable_key_rotation(self, key_id: str) -> None:
self.keys[self.get_key_id(key_id)].key_rotation_status = False
def get_key_rotation_status(self, key_id):
def get_key_rotation_status(self, key_id: str) -> bool:
return self.keys[self.get_key_id(key_id)].key_rotation_status
def put_key_policy(self, key_id, policy):
def put_key_policy(self, key_id: str, policy: str) -> None:
self.keys[self.get_key_id(key_id)].policy = policy
def get_key_policy(self, key_id):
def get_key_policy(self, key_id: str) -> str:
return self.keys[self.get_key_id(key_id)].policy
def disable_key(self, key_id):
def disable_key(self, key_id: str) -> None:
self.keys[key_id].enabled = False
self.keys[key_id].key_state = "Disabled"
def enable_key(self, key_id):
def enable_key(self, key_id: str) -> None:
self.keys[key_id].enabled = True
self.keys[key_id].key_state = "Enabled"
def cancel_key_deletion(self, key_id):
def cancel_key_deletion(self, key_id: str) -> None:
self.keys[key_id].key_state = "Disabled"
self.keys[key_id].deletion_date = None
def schedule_key_deletion(self, key_id, pending_window_in_days):
def schedule_key_deletion(self, key_id: str, pending_window_in_days: int) -> float: # type: ignore[return]
if 7 <= pending_window_in_days <= 30:
self.keys[key_id].enabled = False
self.keys[key_id].key_state = "PendingDeletion"
@ -431,7 +449,9 @@ class KmsBackend(BaseBackend):
)
return unix_time(self.keys[key_id].deletion_date)
def encrypt(self, key_id, plaintext, encryption_context):
def encrypt(
self, key_id: str, plaintext: bytes, encryption_context: Dict[str, str]
) -> Tuple[bytes, str]:
key_id = self.any_id_to_key_id(key_id)
ciphertext_blob = encrypt(
@ -443,7 +463,9 @@ class KmsBackend(BaseBackend):
arn = self.keys[key_id].arn
return ciphertext_blob, arn
def decrypt(self, ciphertext_blob, encryption_context):
def decrypt(
self, ciphertext_blob: bytes, encryption_context: Dict[str, str]
) -> Tuple[bytes, str]:
plaintext, key_id = decrypt(
master_keys=self.keys,
ciphertext_blob=ciphertext_blob,
@ -454,11 +476,11 @@ class KmsBackend(BaseBackend):
def re_encrypt(
self,
ciphertext_blob,
source_encryption_context,
destination_key_id,
destination_encryption_context,
):
ciphertext_blob: bytes,
source_encryption_context: Dict[str, str],
destination_key_id: str,
destination_encryption_context: Dict[str, str],
) -> Tuple[bytes, str, str]:
destination_key_id = self.any_id_to_key_id(destination_key_id)
plaintext, decrypting_arn = self.decrypt(
@ -472,7 +494,13 @@ class KmsBackend(BaseBackend):
)
return new_ciphertext_blob, decrypting_arn, encrypting_arn
def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec):
def generate_data_key(
self,
key_id: str,
encryption_context: Dict[str, str],
number_of_bytes: int,
key_spec: str,
) -> Tuple[bytes, bytes, str]:
key_id = self.any_id_to_key_id(key_id)
if key_spec:
@ -492,7 +520,7 @@ class KmsBackend(BaseBackend):
return plaintext, ciphertext_blob, arn
def list_resource_tags(self, key_id_or_arn):
def list_resource_tags(self, key_id_or_arn: str) -> Dict[str, List[Dict[str, str]]]:
key_id = self.get_key_id(key_id_or_arn)
if key_id in self.keys:
return self.tagger.list_tags_for_resource(key_id)
@ -501,21 +529,21 @@ class KmsBackend(BaseBackend):
"The request was rejected because the specified entity or resource could not be found.",
)
def tag_resource(self, key_id_or_arn, tags):
def tag_resource(self, key_id_or_arn: str, tags: List[Dict[str, str]]) -> None:
key_id = self.get_key_id(key_id_or_arn)
if key_id in self.keys:
self.tagger.tag_resource(key_id, tags)
return {}
return
raise JsonRESTError(
"NotFoundException",
"The request was rejected because the specified entity or resource could not be found.",
)
def untag_resource(self, key_id_or_arn, tag_names):
def untag_resource(self, key_id_or_arn: str, tag_names: List[str]) -> None:
key_id = self.get_key_id(key_id_or_arn)
if key_id in self.keys:
self.tagger.untag_resource_using_names(key_id, tag_names)
return {}
return
raise JsonRESTError(
"NotFoundException",
"The request was rejected because the specified entity or resource could not be found.",
@ -523,13 +551,13 @@ class KmsBackend(BaseBackend):
def create_grant(
self,
key_id,
grantee_principal,
operations,
name,
constraints,
retiring_principal,
):
key_id: str,
grantee_principal: str,
operations: List[str],
name: str,
constraints: Dict[str, Any],
retiring_principal: str,
) -> Tuple[str, str]:
key = self.describe_key(key_id)
grant = key.add_grant(
name,
@ -540,21 +568,21 @@ class KmsBackend(BaseBackend):
)
return grant.id, grant.token
def list_grants(self, key_id, grant_id) -> [Grant]:
def list_grants(self, key_id: str, grant_id: str) -> List[Grant]:
key = self.describe_key(key_id)
return key.list_grants(grant_id)
def list_retirable_grants(self, retiring_principal):
def list_retirable_grants(self, retiring_principal: str) -> List[Grant]:
grants = []
for key in self.keys.values():
grants.extend(key.list_retirable_grants(retiring_principal))
return grants
def revoke_grant(self, key_id, grant_id) -> None:
def revoke_grant(self, key_id: str, grant_id: str) -> None:
key = self.describe_key(key_id)
key.revoke_grant(grant_id)
def retire_grant(self, key_id, grant_id, grant_token) -> None:
def retire_grant(self, key_id: str, grant_id: str, grant_token: str) -> None:
if grant_token:
for key in self.keys.values():
key.retire_grant_by_token(grant_token)
@ -562,7 +590,7 @@ class KmsBackend(BaseBackend):
key = self.describe_key(key_id)
key.retire_grant(grant_id)
def __ensure_valid_sign_and_verify_key(self, key: Key):
def __ensure_valid_sign_and_verify_key(self, key: Key) -> None:
if key.key_usage != "SIGN_VERIFY":
raise ValidationException(
(
@ -571,7 +599,9 @@ class KmsBackend(BaseBackend):
).format(key_id=key.id)
)
def __ensure_valid_signing_augorithm(self, key: Key, signing_algorithm):
def __ensure_valid_signing_augorithm(
self, key: Key, signing_algorithm: str
) -> None:
if signing_algorithm not in key.signing_algorithms:
raise ValidationException(
(
@ -584,7 +614,9 @@ class KmsBackend(BaseBackend):
)
)
def sign(self, key_id, message, signing_algorithm):
def sign(
self, key_id: str, message: bytes, signing_algorithm: str
) -> Tuple[str, bytes, str]:
"""Sign message using generated private key.
- signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256
@ -607,7 +639,9 @@ class KmsBackend(BaseBackend):
return key.arn, signature, signing_algorithm
def verify(self, key_id, message, signature, signing_algorithm):
def verify(
self, key_id: str, message: bytes, signature: bytes, signing_algorithm: str
) -> Tuple[str, bool, str]:
"""Verify message using public key from generated private key.
- signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256
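As a rough illustration of the property typing above, a sketch under the assumption that Key can be constructed standalone with the signature shown earlier:
# An ENCRYPT_DECRYPT key still yields None from signing_algorithms, which is
# why that property keeps the inline "type: ignore[return-value]" escape hatch,
# while encryption_algorithms is honestly annotated as Optional[List[str]].
key = Key(
    policy=None,
    key_usage="ENCRYPT_DECRYPT",
    key_spec="SYMMETRIC_DEFAULT",
    description="example key",
    account_id="123456789012",
    region="us-east-1",
)
assert key.signing_algorithms is None
# A SIGN_VERIFY key takes the opposite branch: encryption_algorithms returns None.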

View File

@ -1,4 +1,5 @@
from collections import defaultdict
from typing import Any, Dict, List
import json
from .models import Key
from .exceptions import AccessDeniedException
@ -8,7 +9,7 @@ ALTERNATIVE_ACTIONS = defaultdict(list)
ALTERNATIVE_ACTIONS["kms:DescribeKey"] = ["kms:*", "kms:Describe*", "kms:DescribeKey"]
def validate_policy(key: Key, action: str):
def validate_policy(key: Key, action: str) -> None:
"""
Relevant docs:
- https://docs.aws.amazon.com/kms/latest/developerguide/key-policy-default.html
@ -29,20 +30,22 @@ def validate_policy(key: Key, action: str):
)
def check_statement(statement, resource, action):
def check_statement(statement: Dict[str, Any], resource: str, action: str) -> bool:
return action_matches(statement.get("Action", []), action) and resource_matches(
statement.get("Resource", ""), resource
)
def action_matches(applicable_actions, action):
def action_matches(applicable_actions: List[str], action: str) -> bool:
alternatives = ALTERNATIVE_ACTIONS[action]
if any(alt in applicable_actions for alt in alternatives):
return True
return False
def resource_matches(applicable_resources, resource): # pylint: disable=unused-argument
def resource_matches(
applicable_resources: str, resource: str # pylint: disable=unused-argument
) -> bool:
if applicable_resources == "*":
return True
return False
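A small sketch of how the typed policy helpers above compose; the statement shape is assumed from the AWS key-policy format linked in the docstring:
statement = {
    "Effect": "Allow",
    "Action": ["kms:Describe*"],
    "Resource": "*",
}
key_arn = "arn:aws:kms:us-east-1:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab"
# ALTERNATIVE_ACTIONS expands kms:DescribeKey to ["kms:*", "kms:Describe*", "kms:DescribeKey"],
# so the wildcard action matches, and "*" matches any resource.
assert action_matches(["kms:Describe*"], "kms:DescribeKey") is True
assert resource_matches("*", key_arn) is True
assert check_statement(statement, key_arn, "kms:DescribeKey") is True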

View File

@ -3,6 +3,7 @@ import json
import os
import re
import warnings
from typing import Any, Dict
from moto.core.responses import BaseResponse
from moto.kms.utils import RESERVED_ALIASES, RESERVED_ALIASE_TARGET_KEY_IDS
@ -17,24 +18,23 @@ from .exceptions import (
class KmsResponse(BaseResponse):
def __init__(self):
def __init__(self) -> None:
super().__init__(service_name="kms")
@property
def parameters(self):
def _get_param(self, param_name: str, if_none: Any = None) -> Any: # type: ignore
params = json.loads(self.body)
for key in ("Plaintext", "CiphertextBlob"):
if key in params:
params[key] = base64.b64decode(params[key].encode("utf-8"))
return params
return params.get(param_name, if_none)
@property
def kms_backend(self) -> KmsBackend:
return kms_backends[self.current_account][self.region]
def _display_arn(self, key_id):
def _display_arn(self, key_id: str) -> str:
if key_id.startswith("arn:"):
return key_id
@ -45,7 +45,7 @@ class KmsResponse(BaseResponse):
return f"arn:aws:kms:{self.region}:{self.current_account}:{id_type}{key_id}"
def _validate_cmk_id(self, key_id):
def _validate_cmk_id(self, key_id: str) -> None:
"""Determine whether a CMK ID exists.
- raw key ID
@ -69,7 +69,7 @@ class KmsResponse(BaseResponse):
if cmk_id not in self.kms_backend.keys:
raise NotFoundException(f"Key '{self._display_arn(key_id)}' does not exist")
def _validate_alias(self, key_id):
def _validate_alias(self, key_id: str) -> None:
"""Determine whether an alias exists.
- alias name
@ -88,8 +88,8 @@ class KmsResponse(BaseResponse):
if cmk_id is None:
raise error
def _validate_key_id(self, key_id):
"""Determine whether or not a key ID exists.
def _validate_key_id(self, key_id: str) -> None:
"""Determine whether a key ID exists.
- raw key ID
- key ARN
@ -105,77 +105,77 @@ class KmsResponse(BaseResponse):
self._validate_cmk_id(key_id)
def _validate_key_policy(self, key_id, action):
def _validate_key_policy(self, key_id: str, action: str) -> None:
"""
Validate whether the specified action is allowed, given the key policy
"""
key = self.kms_backend.describe_key(self.kms_backend.get_key_id(key_id))
validate_policy(key, action)
def create_key(self):
def create_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html"""
policy = self.parameters.get("Policy")
key_usage = self.parameters.get("KeyUsage")
key_spec = self.parameters.get("KeySpec") or self.parameters.get(
policy = self._get_param("Policy")
key_usage = self._get_param("KeyUsage")
key_spec = self._get_param("KeySpec") or self._get_param(
"CustomerMasterKeySpec"
)
description = self.parameters.get("Description")
tags = self.parameters.get("Tags")
multi_region = self.parameters.get("MultiRegion")
description = self._get_param("Description")
tags = self._get_param("Tags")
multi_region = self._get_param("MultiRegion")
key = self.kms_backend.create_key(
policy, key_usage, key_spec, description, tags, multi_region
)
return json.dumps(key.to_dict())
def replicate_key(self):
key_id = self.parameters.get("KeyId")
def replicate_key(self) -> None:
key_id = self._get_param("KeyId")
self._validate_key_id(key_id)
replica_region = self.parameters.get("ReplicaRegion")
replica_region = self._get_param("ReplicaRegion")
self.kms_backend.replicate_key(key_id, replica_region)
def update_key_description(self):
def update_key_description(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html"""
key_id = self.parameters.get("KeyId")
description = self.parameters.get("Description")
key_id = self._get_param("KeyId")
description = self._get_param("Description")
self._validate_cmk_id(key_id)
self.kms_backend.update_key_description(key_id, description)
return json.dumps(None)
def tag_resource(self):
def tag_resource(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html"""
key_id = self.parameters.get("KeyId")
tags = self.parameters.get("Tags")
key_id = self._get_param("KeyId")
tags = self._get_param("Tags")
self._validate_cmk_id(key_id)
result = self.kms_backend.tag_resource(key_id, tags)
return json.dumps(result)
self.kms_backend.tag_resource(key_id, tags)
return "{}"
def untag_resource(self):
def untag_resource(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UntagResource.html"""
key_id = self.parameters.get("KeyId")
tag_names = self.parameters.get("TagKeys")
key_id = self._get_param("KeyId")
tag_names = self._get_param("TagKeys")
self._validate_cmk_id(key_id)
result = self.kms_backend.untag_resource(key_id, tag_names)
return json.dumps(result)
self.kms_backend.untag_resource(key_id, tag_names)
return "{}"
def list_resource_tags(self):
def list_resource_tags(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html"""
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id)
tags = self.kms_backend.list_resource_tags(key_id)
tags: Dict[str, Any] = self.kms_backend.list_resource_tags(key_id)
tags.update({"NextMarker": None, "Truncated": False})
return json.dumps(tags)
def describe_key(self):
def describe_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html"""
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
self._validate_key_id(key_id)
self._validate_key_policy(key_id, "kms:DescribeKey")
@ -184,7 +184,7 @@ class KmsResponse(BaseResponse):
return json.dumps(key.to_dict())
def list_keys(self):
def list_keys(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html"""
keys = self.kms_backend.list_keys()
@ -196,17 +196,17 @@ class KmsResponse(BaseResponse):
}
)
def create_alias(self):
def create_alias(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html"""
return self._set_alias()
def update_alias(self):
def update_alias(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateAlias.html"""
return self._set_alias(update=True)
def _set_alias(self, update=False):
alias_name = self.parameters["AliasName"]
target_key_id = self.parameters["TargetKeyId"]
def _set_alias(self, update: bool = False) -> str:
alias_name = self._get_param("AliasName")
target_key_id = self._get_param("TargetKeyId")
if not alias_name.startswith("alias/"):
raise ValidationException("Invalid identifier")
@ -243,9 +243,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None)
def delete_alias(self):
def delete_alias(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html"""
alias_name = self.parameters["AliasName"]
alias_name = self._get_param("AliasName")
if not alias_name.startswith("alias/"):
raise ValidationException("Invalid identifier")
@ -256,10 +256,10 @@ class KmsResponse(BaseResponse):
return json.dumps(None)
def list_aliases(self):
def list_aliases(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html"""
region = self.region
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
if key_id is not None:
self._validate_key_id(key_id)
key_id = self.kms_backend.get_key_id(key_id)
@ -298,13 +298,13 @@ class KmsResponse(BaseResponse):
return json.dumps({"Truncated": False, "Aliases": response_aliases})
def create_grant(self):
key_id = self.parameters.get("KeyId")
grantee_principal = self.parameters.get("GranteePrincipal")
retiring_principal = self.parameters.get("RetiringPrincipal")
operations = self.parameters.get("Operations")
name = self.parameters.get("Name")
constraints = self.parameters.get("Constraints")
def create_grant(self) -> str:
key_id = self._get_param("KeyId")
grantee_principal = self._get_param("GranteePrincipal")
retiring_principal = self._get_param("RetiringPrincipal")
operations = self._get_param("Operations")
name = self._get_param("Name")
constraints = self._get_param("Constraints")
grant_id, grant_token = self.kms_backend.create_grant(
key_id,
@ -316,9 +316,9 @@ class KmsResponse(BaseResponse):
)
return json.dumps({"GrantId": grant_id, "GrantToken": grant_token})
def list_grants(self):
key_id = self.parameters.get("KeyId")
grant_id = self.parameters.get("GrantId")
def list_grants(self) -> str:
key_id = self._get_param("KeyId")
grant_id = self._get_param("GrantId")
grants = self.kms_backend.list_grants(key_id=key_id, grant_id=grant_id)
return json.dumps(
@ -329,8 +329,8 @@ class KmsResponse(BaseResponse):
}
)
def list_retirable_grants(self):
retiring_principal = self.parameters.get("RetiringPrincipal")
def list_retirable_grants(self) -> str:
retiring_principal = self._get_param("RetiringPrincipal")
grants = self.kms_backend.list_retirable_grants(retiring_principal)
return json.dumps(
@ -341,24 +341,24 @@ class KmsResponse(BaseResponse):
}
)
def revoke_grant(self):
key_id = self.parameters.get("KeyId")
grant_id = self.parameters.get("GrantId")
def revoke_grant(self) -> str:
key_id = self._get_param("KeyId")
grant_id = self._get_param("GrantId")
self.kms_backend.revoke_grant(key_id, grant_id)
return "{}"
def retire_grant(self):
key_id = self.parameters.get("KeyId")
grant_id = self.parameters.get("GrantId")
grant_token = self.parameters.get("GrantToken")
def retire_grant(self) -> str:
key_id = self._get_param("KeyId")
grant_id = self._get_param("GrantId")
grant_token = self._get_param("GrantToken")
self.kms_backend.retire_grant(key_id, grant_id, grant_token)
return "{}"
def enable_key_rotation(self):
def enable_key_rotation(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id)
@ -366,9 +366,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None)
def disable_key_rotation(self):
def disable_key_rotation(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id)
@ -376,9 +376,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None)
def get_key_rotation_status(self):
def get_key_rotation_status(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html"""
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id)
@ -386,11 +386,11 @@ class KmsResponse(BaseResponse):
return json.dumps({"KeyRotationEnabled": rotation_enabled})
def put_key_policy(self):
def put_key_policy(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html"""
key_id = self.parameters.get("KeyId")
policy_name = self.parameters.get("PolicyName")
policy = self.parameters.get("Policy")
key_id = self._get_param("KeyId")
policy_name = self._get_param("PolicyName")
policy = self._get_param("Policy")
_assert_default_policy(policy_name)
self._validate_cmk_id(key_id)
@ -399,10 +399,10 @@ class KmsResponse(BaseResponse):
return json.dumps(None)
def get_key_policy(self):
def get_key_policy(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html"""
key_id = self.parameters.get("KeyId")
policy_name = self.parameters.get("PolicyName")
key_id = self._get_param("KeyId")
policy_name = self._get_param("PolicyName")
_assert_default_policy(policy_name)
self._validate_cmk_id(key_id)
@ -410,9 +410,9 @@ class KmsResponse(BaseResponse):
policy = self.kms_backend.get_key_policy(key_id) or "{}"
return json.dumps({"Policy": policy})
def list_key_policies(self):
def list_key_policies(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html"""
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id)
@ -420,11 +420,11 @@ class KmsResponse(BaseResponse):
return json.dumps({"Truncated": False, "PolicyNames": ["default"]})
def encrypt(self):
def encrypt(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html"""
key_id = self.parameters.get("KeyId")
encryption_context = self.parameters.get("EncryptionContext", {})
plaintext = self.parameters.get("Plaintext")
key_id = self._get_param("KeyId")
encryption_context = self._get_param("EncryptionContext", {})
plaintext = self._get_param("Plaintext")
self._validate_key_id(key_id)
@ -438,10 +438,10 @@ class KmsResponse(BaseResponse):
return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn})
def decrypt(self):
def decrypt(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob")
encryption_context = self.parameters.get("EncryptionContext", {})
ciphertext_blob = self._get_param("CiphertextBlob")
encryption_context = self._get_param("EncryptionContext", {})
plaintext, arn = self.kms_backend.decrypt(
ciphertext_blob=ciphertext_blob, encryption_context=encryption_context
@ -451,12 +451,12 @@ class KmsResponse(BaseResponse):
return json.dumps({"Plaintext": plaintext_response, "KeyId": arn})
def re_encrypt(self):
def re_encrypt(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob")
source_encryption_context = self.parameters.get("SourceEncryptionContext", {})
destination_key_id = self.parameters.get("DestinationKeyId")
destination_encryption_context = self.parameters.get(
ciphertext_blob = self._get_param("CiphertextBlob")
source_encryption_context = self._get_param("SourceEncryptionContext", {})
destination_key_id = self._get_param("DestinationKeyId")
destination_encryption_context = self._get_param(
"DestinationEncryptionContext", {}
)
@ -483,9 +483,9 @@ class KmsResponse(BaseResponse):
}
)
def disable_key(self):
def disable_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html"""
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id)
@ -493,9 +493,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None)
def enable_key(self):
def enable_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html"""
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id)
@ -503,9 +503,9 @@ class KmsResponse(BaseResponse):
return json.dumps(None)
def cancel_key_deletion(self):
def cancel_key_deletion(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html"""
key_id = self.parameters.get("KeyId")
key_id = self._get_param("KeyId")
self._validate_cmk_id(key_id)
@ -513,13 +513,13 @@ class KmsResponse(BaseResponse):
return json.dumps({"KeyId": key_id})
def schedule_key_deletion(self):
def schedule_key_deletion(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html"""
key_id = self.parameters.get("KeyId")
if self.parameters.get("PendingWindowInDays") is None:
key_id = self._get_param("KeyId")
if self._get_param("PendingWindowInDays") is None:
pending_window_in_days = 30
else:
pending_window_in_days = self.parameters.get("PendingWindowInDays")
pending_window_in_days = self._get_param("PendingWindowInDays")
self._validate_cmk_id(key_id)
@ -532,12 +532,12 @@ class KmsResponse(BaseResponse):
}
)
def generate_data_key(self):
def generate_data_key(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html"""
key_id = self.parameters.get("KeyId")
encryption_context = self.parameters.get("EncryptionContext", {})
number_of_bytes = self.parameters.get("NumberOfBytes")
key_spec = self.parameters.get("KeySpec")
key_id = self._get_param("KeyId")
encryption_context = self._get_param("EncryptionContext", {})
number_of_bytes = self._get_param("NumberOfBytes")
key_spec = self._get_param("KeySpec")
# Param validation
self._validate_key_id(key_id)
@ -587,16 +587,16 @@ class KmsResponse(BaseResponse):
}
)
def generate_data_key_without_plaintext(self):
def generate_data_key_without_plaintext(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html"""
result = json.loads(self.generate_data_key())
del result["Plaintext"]
return json.dumps(result)
def generate_random(self):
def generate_random(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html"""
number_of_bytes = self.parameters.get("NumberOfBytes")
number_of_bytes = self._get_param("NumberOfBytes")
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1):
raise ValidationException(
@ -613,13 +613,13 @@ class KmsResponse(BaseResponse):
return json.dumps({"Plaintext": response_entropy})
def sign(self):
def sign(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Sign.html"""
key_id = self.parameters.get("KeyId")
message = self.parameters.get("Message")
message_type = self.parameters.get("MessageType")
grant_tokens = self.parameters.get("GrantTokens")
signing_algorithm = self.parameters.get("SigningAlgorithm")
key_id = self._get_param("KeyId")
message = self._get_param("Message")
message_type = self._get_param("MessageType")
grant_tokens = self._get_param("GrantTokens")
signing_algorithm = self._get_param("SigningAlgorithm")
self._validate_key_id(key_id)
@ -660,14 +660,14 @@ class KmsResponse(BaseResponse):
}
)
def verify(self):
def verify(self) -> str:
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Verify.html"""
key_id = self.parameters.get("KeyId")
message = self.parameters.get("Message")
message_type = self.parameters.get("MessageType")
signature = self.parameters.get("Signature")
signing_algorithm = self.parameters.get("SigningAlgorithm")
grant_tokens = self.parameters.get("GrantTokens")
key_id = self._get_param("KeyId")
message = self._get_param("Message")
message_type = self._get_param("MessageType")
signature = self._get_param("Signature")
signing_algorithm = self._get_param("SigningAlgorithm")
grant_tokens = self._get_param("GrantTokens")
self._validate_key_id(key_id)
@ -722,6 +722,6 @@ class KmsResponse(BaseResponse):
)
def _assert_default_policy(policy_name):
def _assert_default_policy(policy_name: str) -> None:
if policy_name != "default":
raise NotFoundException("No such policy exists")
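Roughly, the new _get_param override behaves like the standalone sketch below (not the moto class itself, just the same parsing logic):
import base64
import json
from typing import Any

def _get_param_sketch(body: str, param_name: str, if_none: Any = None) -> Any:
    # Parse the JSON request body and transparently base64-decode the binary
    # fields before returning the requested parameter, as _get_param does above.
    params = json.loads(body)
    for key in ("Plaintext", "CiphertextBlob"):
        if key in params:
            params[key] = base64.b64decode(params[key].encode("utf-8"))
    return params.get(param_name, if_none)

body = json.dumps({"KeyId": "1234abcd", "Plaintext": base64.b64encode(b"secret").decode()})
assert _get_param_sketch(body, "Plaintext") == b"secret"
assert _get_param_sketch(body, "EncryptionContext", {}) == {}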

View File

@ -1,4 +1,5 @@
from collections import namedtuple
from typing import Any, Dict, Tuple
import io
import os
import struct
@ -47,7 +48,7 @@ RESERVED_ALIASE_TARGET_KEY_IDS = {
RESERVED_ALIASES = list(RESERVED_ALIASE_TARGET_KEY_IDS.keys())
def generate_key_id(multi_region=False):
def generate_key_id(multi_region: bool = False) -> str:
key = str(mock_random.uuid4())
# https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html
# "Notice that multi-Region keys have a distinctive key ID that begins with mrk-. You can use the mrk- prefix to
@ -58,17 +59,17 @@ def generate_key_id(multi_region=False):
return key
def generate_data_key(number_of_bytes):
def generate_data_key(number_of_bytes: int) -> bytes:
"""Generate a data key."""
return os.urandom(number_of_bytes)
def generate_master_key():
def generate_master_key() -> bytes:
"""Generate a master key."""
return generate_data_key(MASTER_KEY_LEN)
def generate_private_key():
def generate_private_key() -> rsa.RSAPrivateKey:
"""Generate a private key to be used on asymmetric sign/verify.
NOTE: KeySpec is not taken into consideration and the key is always RSA_2048
@ -80,7 +81,7 @@ def generate_private_key():
)
def _serialize_ciphertext_blob(ciphertext):
def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes:
"""Serialize Ciphertext object into a ciphertext blob.
NOTE: This is just a simple binary format. It is not what KMS actually does.
@ -94,7 +95,7 @@ def _serialize_ciphertext_blob(ciphertext):
return header + ciphertext.ciphertext
def _deserialize_ciphertext_blob(ciphertext_blob):
def _deserialize_ciphertext_blob(ciphertext_blob: bytes) -> Ciphertext:
"""Deserialize ciphertext blob into a Ciphertext object.
NOTE: This is just a simple binary format. It is not what KMS actually does.
@ -107,7 +108,7 @@ def _deserialize_ciphertext_blob(ciphertext_blob):
)
def _serialize_encryption_context(encryption_context):
def _serialize_encryption_context(encryption_context: Dict[str, str]) -> bytes:
"""Serialize encryption context for use a AAD.
NOTE: This is not necessarily what KMS does, but it retains the same properties.
@ -119,7 +120,12 @@ def _serialize_encryption_context(encryption_context):
return aad.getvalue()
def encrypt(master_keys, key_id, plaintext, encryption_context):
def encrypt(
master_keys: Dict[str, Any],
key_id: str,
plaintext: bytes,
encryption_context: Dict[str, str],
) -> bytes:
"""Encrypt data using a master key material.
NOTE: This is not necessarily what KMS does, but it retains the same properties.
@ -159,7 +165,11 @@ def encrypt(master_keys, key_id, plaintext, encryption_context):
)
def decrypt(master_keys, ciphertext_blob, encryption_context):
def decrypt(
master_keys: Dict[str, Any],
ciphertext_blob: bytes,
encryption_context: Dict[str, str],
) -> Tuple[bytes, str]:
"""Decrypt a ciphertext blob using a master key material.
NOTE: This is not necessarily what KMS does, but it retains the same properties.
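And a quick sketch of the typed key-id helper; the "mrk-" behaviour is taken from the multi-Region comment above:
# Standard keys are plain UUIDs; multi-Region keys carry the "mrk-" prefix.
standard_id = generate_key_id()
replica_capable_id = generate_key_id(multi_region=True)
assert not standard_id.startswith("mrk-")
assert replica_capable_id.startswith("mrk-")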

View File

@ -9,7 +9,7 @@ SESSION_TOKEN_PREFIX = "FQoGZXIvYXdzEBYaD"
DEFAULT_STS_SESSION_DURATION = 3600
def random_session_token():
def random_session_token() -> str:
return (
SESSION_TOKEN_PREFIX
+ base64.b64encode(os.urandom(266))[len(SESSION_TOKEN_PREFIX) :].decode()

View File

@ -235,7 +235,7 @@ disable = W,C,R,E
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
[mypy]
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/moto_api,moto/neptune
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/moto_api,moto/neptune
show_column_numbers=True
show_error_codes = True
disable_error_code=abstract