Techdebt: MyPy K (#6111)
This commit is contained in:
parent
64abff588f
commit
c3460b8a1a
@ -1,46 +1,47 @@
|
||||
from moto.core.exceptions import JsonRESTError
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ResourceNotFoundError(JsonRESTError):
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(error_type="ResourceNotFoundException", message=message)
|
||||
|
||||
|
||||
class ResourceInUseError(JsonRESTError):
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(error_type="ResourceInUseException", message=message)
|
||||
|
||||
|
||||
class StreamNotFoundError(ResourceNotFoundError):
|
||||
def __init__(self, stream_name, account_id):
|
||||
def __init__(self, stream_name: str, account_id: str):
|
||||
super().__init__(f"Stream {stream_name} under account {account_id} not found.")
|
||||
|
||||
|
||||
class StreamCannotBeUpdatedError(JsonRESTError):
|
||||
def __init__(self, stream_name, account_id):
|
||||
def __init__(self, stream_name: str, account_id: str):
|
||||
message = f"Request is invalid. Stream {stream_name} under account {account_id} is in On-Demand mode."
|
||||
super().__init__(error_type="ValidationException", message=message)
|
||||
|
||||
|
||||
class ShardNotFoundError(ResourceNotFoundError):
|
||||
def __init__(self, shard_id, stream, account_id):
|
||||
def __init__(self, shard_id: str, stream: str, account_id: str):
|
||||
super().__init__(
|
||||
f"Could not find shard {shard_id} in stream {stream} under account {account_id}."
|
||||
)
|
||||
|
||||
|
||||
class ConsumerNotFound(ResourceNotFoundError):
|
||||
def __init__(self, consumer, account_id):
|
||||
def __init__(self, consumer: str, account_id: str):
|
||||
super().__init__(f"Consumer {consumer}, account {account_id} not found.")
|
||||
|
||||
|
||||
class InvalidArgumentError(JsonRESTError):
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(error_type="InvalidArgumentException", message=message)
|
||||
|
||||
|
||||
class InvalidRetentionPeriod(InvalidArgumentError):
|
||||
def __init__(self, hours, too_short):
|
||||
def __init__(self, hours: int, too_short: bool):
|
||||
if too_short:
|
||||
msg = f"Minimum allowed retention period is 24 hours. Requested retention period ({hours} hours) is too short."
|
||||
else:
|
||||
@ -49,31 +50,31 @@ class InvalidRetentionPeriod(InvalidArgumentError):
|
||||
|
||||
|
||||
class InvalidDecreaseRetention(InvalidArgumentError):
|
||||
def __init__(self, name, requested, existing):
|
||||
def __init__(self, name: Optional[str], requested: int, existing: int):
|
||||
msg = f"Requested retention period ({requested} hours) for stream {name} can not be longer than existing retention period ({existing} hours). Use IncreaseRetentionPeriod API."
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class InvalidIncreaseRetention(InvalidArgumentError):
|
||||
def __init__(self, name, requested, existing):
|
||||
def __init__(self, name: Optional[str], requested: int, existing: int):
|
||||
msg = f"Requested retention period ({requested} hours) for stream {name} can not be shorter than existing retention period ({existing} hours). Use DecreaseRetentionPeriod API."
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class ValidationException(JsonRESTError):
|
||||
def __init__(self, value, position, regex_to_match):
|
||||
def __init__(self, value: str, position: str, regex_to_match: str):
|
||||
msg = f"1 validation error detected: Value '{value}' at '{position}' failed to satisfy constraint: Member must satisfy regular expression pattern: {regex_to_match}"
|
||||
super().__init__(error_type="ValidationException", message=msg)
|
||||
|
||||
|
||||
class RecordSizeExceedsLimit(JsonRESTError):
|
||||
def __init__(self, position):
|
||||
def __init__(self, position: int):
|
||||
msg = f"1 validation error detected: Value at 'records.{position}.member.data' failed to satisfy constraint: Member must have length less than or equal to 1048576"
|
||||
super().__init__(error_type="ValidationException", message=msg)
|
||||
|
||||
|
||||
class TotalRecordsSizeExceedsLimit(JsonRESTError):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
error_type="InvalidArgumentException",
|
||||
message="Records size exceeds 5 MB limit",
|
||||
@ -81,6 +82,6 @@ class TotalRecordsSizeExceedsLimit(JsonRESTError):
|
||||
|
||||
|
||||
class TooManyRecords(JsonRESTError):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
msg = "1 validation error detected: Value at 'records' failed to satisfy constraint: Member must have length less than or equal to 500"
|
||||
super().__init__(error_type="ValidationException", message=msg)
|
||||
|
@ -4,7 +4,7 @@ import re
|
||||
import itertools
|
||||
|
||||
from operator import attrgetter
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Iterable
|
||||
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
|
||||
from moto.core.utils import unix_time
|
||||
@ -35,14 +35,16 @@ from .utils import (
|
||||
|
||||
|
||||
class Consumer(BaseModel):
|
||||
def __init__(self, consumer_name, account_id, region_name, stream_arn):
|
||||
def __init__(
|
||||
self, consumer_name: str, account_id: str, region_name: str, stream_arn: str
|
||||
):
|
||||
self.consumer_name = consumer_name
|
||||
self.created = unix_time()
|
||||
self.stream_arn = stream_arn
|
||||
stream_name = stream_arn.split("/")[-1]
|
||||
self.consumer_arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}/consumer/{consumer_name}"
|
||||
|
||||
def to_json(self, include_stream_arn=False):
|
||||
def to_json(self, include_stream_arn: bool = False) -> Dict[str, Any]:
|
||||
resp = {
|
||||
"ConsumerName": self.consumer_name,
|
||||
"ConsumerARN": self.consumer_arn,
|
||||
@ -55,7 +57,13 @@ class Consumer(BaseModel):
|
||||
|
||||
|
||||
class Record(BaseModel):
|
||||
def __init__(self, partition_key, data, sequence_number, explicit_hash_key):
|
||||
def __init__(
|
||||
self,
|
||||
partition_key: str,
|
||||
data: str,
|
||||
sequence_number: int,
|
||||
explicit_hash_key: str,
|
||||
):
|
||||
self.partition_key = partition_key
|
||||
self.data = data
|
||||
self.sequence_number = sequence_number
|
||||
@ -63,7 +71,7 @@ class Record(BaseModel):
|
||||
self.created_at_datetime = datetime.datetime.utcnow()
|
||||
self.created_at = unix_time(self.created_at_datetime)
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"Data": self.data,
|
||||
"PartitionKey": self.partition_key,
|
||||
@ -74,29 +82,36 @@ class Record(BaseModel):
|
||||
|
||||
class Shard(BaseModel):
|
||||
def __init__(
|
||||
self, shard_id, starting_hash, ending_hash, parent=None, adjacent_parent=None
|
||||
self,
|
||||
shard_id: int,
|
||||
starting_hash: int,
|
||||
ending_hash: int,
|
||||
parent: Optional[str] = None,
|
||||
adjacent_parent: Optional[str] = None,
|
||||
):
|
||||
self._shard_id = shard_id
|
||||
self.starting_hash = starting_hash
|
||||
self.ending_hash = ending_hash
|
||||
self.records = OrderedDict()
|
||||
self.records: Dict[int, Record] = OrderedDict()
|
||||
self.is_open = True
|
||||
self.parent = parent
|
||||
self.adjacent_parent = adjacent_parent
|
||||
|
||||
@property
|
||||
def shard_id(self):
|
||||
def shard_id(self) -> str:
|
||||
return f"shardId-{str(self._shard_id).zfill(12)}"
|
||||
|
||||
def get_records(self, last_sequence_id, limit):
|
||||
last_sequence_id = int(last_sequence_id)
|
||||
def get_records(
|
||||
self, last_sequence_id: str, limit: Optional[int]
|
||||
) -> Tuple[List[Record], int, int]:
|
||||
last_sequence_int = int(last_sequence_id)
|
||||
results = []
|
||||
secs_behind_latest = 0
|
||||
secs_behind_latest = 0.0
|
||||
|
||||
for sequence_number, record in self.records.items():
|
||||
if sequence_number > last_sequence_id:
|
||||
if sequence_number > last_sequence_int:
|
||||
results.append(record)
|
||||
last_sequence_id = sequence_number
|
||||
last_sequence_int = sequence_number
|
||||
|
||||
very_last_record = self.records[next(reversed(self.records))]
|
||||
secs_behind_latest = very_last_record.created_at - record.created_at
|
||||
@ -105,9 +120,9 @@ class Shard(BaseModel):
|
||||
break
|
||||
|
||||
millis_behind_latest = int(secs_behind_latest * 1000)
|
||||
return results, last_sequence_id, millis_behind_latest
|
||||
return results, last_sequence_int, millis_behind_latest
|
||||
|
||||
def put_record(self, partition_key, data, explicit_hash_key):
|
||||
def put_record(self, partition_key: str, data: str, explicit_hash_key: str) -> str:
|
||||
# Note: this function is not safe for concurrency
|
||||
if self.records:
|
||||
last_sequence_number = self.get_max_sequence_number()
|
||||
@ -119,17 +134,17 @@ class Shard(BaseModel):
|
||||
)
|
||||
return str(sequence_number)
|
||||
|
||||
def get_min_sequence_number(self):
|
||||
def get_min_sequence_number(self) -> int:
|
||||
if self.records:
|
||||
return list(self.records.keys())[0]
|
||||
return 0
|
||||
|
||||
def get_max_sequence_number(self):
|
||||
def get_max_sequence_number(self) -> int:
|
||||
if self.records:
|
||||
return list(self.records.keys())[-1]
|
||||
return 0
|
||||
|
||||
def get_sequence_number_at(self, at_timestamp):
|
||||
def get_sequence_number_at(self, at_timestamp: float) -> int:
|
||||
if not self.records or at_timestamp < list(self.records.values())[0].created_at:
|
||||
return 0
|
||||
else:
|
||||
@ -143,10 +158,10 @@ class Shard(BaseModel):
|
||||
),
|
||||
None,
|
||||
)
|
||||
return r.sequence_number
|
||||
return r.sequence_number # type: ignore
|
||||
|
||||
def to_json(self):
|
||||
response = {
|
||||
def to_json(self) -> Dict[str, Any]:
|
||||
response: Dict[str, Any] = {
|
||||
"HashKeyRange": {
|
||||
"EndingHashKey": str(self.ending_hash),
|
||||
"StartingHashKey": str(self.starting_hash),
|
||||
@ -170,12 +185,12 @@ class Shard(BaseModel):
|
||||
class Stream(CloudFormationModel):
|
||||
def __init__(
|
||||
self,
|
||||
stream_name,
|
||||
shard_count,
|
||||
stream_mode,
|
||||
retention_period_hours,
|
||||
account_id,
|
||||
region_name,
|
||||
stream_name: str,
|
||||
shard_count: int,
|
||||
stream_mode: Optional[Dict[str, str]],
|
||||
retention_period_hours: Optional[int],
|
||||
account_id: str,
|
||||
region_name: str,
|
||||
):
|
||||
self.stream_name = stream_name
|
||||
self.creation_datetime = datetime.datetime.now().strftime(
|
||||
@ -184,27 +199,27 @@ class Stream(CloudFormationModel):
|
||||
self.region = region_name
|
||||
self.account_id = account_id
|
||||
self.arn = f"arn:aws:kinesis:{region_name}:{account_id}:stream/{stream_name}"
|
||||
self.shards = {}
|
||||
self.tags = {}
|
||||
self.shards: Dict[str, Shard] = {}
|
||||
self.tags: Dict[str, str] = {}
|
||||
self.status = "ACTIVE"
|
||||
self.shard_count = None
|
||||
self.shard_count: Optional[int] = None
|
||||
self.stream_mode = stream_mode or {"StreamMode": "PROVISIONED"}
|
||||
if self.stream_mode.get("StreamMode", "") == "ON_DEMAND":
|
||||
shard_count = 4
|
||||
self.init_shards(shard_count)
|
||||
self.retention_period_hours = retention_period_hours or 24
|
||||
self.shard_level_metrics = []
|
||||
self.shard_level_metrics: List[str] = []
|
||||
self.encryption_type = "NONE"
|
||||
self.key_id = None
|
||||
self.consumers = []
|
||||
self.key_id: Optional[str] = None
|
||||
self.consumers: List[Consumer] = []
|
||||
|
||||
def delete_consumer(self, consumer_arn):
|
||||
def delete_consumer(self, consumer_arn: str) -> None:
|
||||
self.consumers = [c for c in self.consumers if c.consumer_arn != consumer_arn]
|
||||
|
||||
def get_consumer_by_arn(self, consumer_arn):
|
||||
def get_consumer_by_arn(self, consumer_arn: str) -> Optional[Consumer]:
|
||||
return next((c for c in self.consumers if c.consumer_arn == consumer_arn), None)
|
||||
|
||||
def init_shards(self, shard_count):
|
||||
def init_shards(self, shard_count: int) -> None:
|
||||
self.shard_count = shard_count
|
||||
|
||||
step = 2**128 // shard_count
|
||||
@ -216,16 +231,16 @@ class Stream(CloudFormationModel):
|
||||
shard = Shard(index, start, end)
|
||||
self.shards[shard.shard_id] = shard
|
||||
|
||||
def split_shard(self, shard_to_split, new_starting_hash_key):
|
||||
new_starting_hash_key = int(new_starting_hash_key)
|
||||
def split_shard(self, shard_to_split: str, new_starting_hash_key: str) -> None:
|
||||
new_starting_hash_int = int(new_starting_hash_key)
|
||||
|
||||
shard = self.shards[shard_to_split]
|
||||
|
||||
if shard.starting_hash < new_starting_hash_key < shard.ending_hash:
|
||||
if shard.starting_hash < new_starting_hash_int < shard.ending_hash:
|
||||
pass
|
||||
else:
|
||||
raise InvalidArgumentError(
|
||||
message=f"NewStartingHashKey {new_starting_hash_key} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}."
|
||||
message=f"NewStartingHashKey {new_starting_hash_int} used in SplitShard() on shard {shard_to_split} in stream {self.stream_name} under account {self.account_id} is not both greater than one plus the shard's StartingHashKey {shard.starting_hash} and less than the shard's EndingHashKey {(shard.ending_hash - 1)}."
|
||||
)
|
||||
|
||||
if not shard.is_open:
|
||||
@ -241,12 +256,12 @@ class Stream(CloudFormationModel):
|
||||
new_shard_1 = Shard(
|
||||
last_id + 1,
|
||||
starting_hash=shard.starting_hash,
|
||||
ending_hash=new_starting_hash_key - 1,
|
||||
ending_hash=new_starting_hash_int - 1,
|
||||
parent=shard.shard_id,
|
||||
)
|
||||
new_shard_2 = Shard(
|
||||
last_id + 2,
|
||||
starting_hash=new_starting_hash_key,
|
||||
starting_hash=new_starting_hash_int,
|
||||
ending_hash=shard.ending_hash,
|
||||
parent=shard.shard_id,
|
||||
)
|
||||
@ -261,7 +276,7 @@ class Stream(CloudFormationModel):
|
||||
record = records[index]
|
||||
self.put_record(record.partition_key, record.explicit_hash_key, record.data)
|
||||
|
||||
def merge_shards(self, shard_to_merge, adjacent_shard_to_merge):
|
||||
def merge_shards(self, shard_to_merge: str, adjacent_shard_to_merge: str) -> None:
|
||||
shard1 = self.shards[shard_to_merge]
|
||||
shard2 = self.shards[adjacent_shard_to_merge]
|
||||
|
||||
@ -300,8 +315,8 @@ class Stream(CloudFormationModel):
|
||||
record.partition_key, record.data, record.explicit_hash_key
|
||||
)
|
||||
|
||||
def update_shard_count(self, target_shard_count):
|
||||
if self.stream_mode.get("StreamMode", "") == "ON_DEMAND":
|
||||
def update_shard_count(self, target_shard_count: int) -> None:
|
||||
if self.stream_mode.get("StreamMode", "") == "ON_DEMAND": # type: ignore
|
||||
raise StreamCannotBeUpdatedError(
|
||||
stream_name=self.stream_name, account_id=self.account_id
|
||||
)
|
||||
@ -351,13 +366,15 @@ class Stream(CloudFormationModel):
|
||||
|
||||
self.shard_count = target_shard_count
|
||||
|
||||
def get_shard(self, shard_id):
|
||||
def get_shard(self, shard_id: str) -> Shard:
|
||||
if shard_id in self.shards:
|
||||
return self.shards[shard_id]
|
||||
else:
|
||||
raise ShardNotFoundError(shard_id, stream="", account_id=self.account_id)
|
||||
|
||||
def get_shard_for_key(self, partition_key, explicit_hash_key):
|
||||
def get_shard_for_key(
|
||||
self, partition_key: str, explicit_hash_key: str
|
||||
) -> Optional[Shard]:
|
||||
if not isinstance(partition_key, str):
|
||||
raise InvalidArgumentError("partition_key")
|
||||
if len(partition_key) > 256:
|
||||
@ -367,25 +384,28 @@ class Stream(CloudFormationModel):
|
||||
if not isinstance(explicit_hash_key, str):
|
||||
raise InvalidArgumentError("explicit_hash_key")
|
||||
|
||||
key = int(explicit_hash_key)
|
||||
int_key = int(explicit_hash_key)
|
||||
|
||||
if key >= 2**128:
|
||||
if int_key >= 2**128:
|
||||
raise InvalidArgumentError("explicit_hash_key")
|
||||
|
||||
else:
|
||||
key = int(md5_hash(partition_key.encode("utf-8")).hexdigest(), 16)
|
||||
int_key = int(md5_hash(partition_key.encode("utf-8")).hexdigest(), 16)
|
||||
|
||||
for shard in self.shards.values():
|
||||
if shard.starting_hash <= key < shard.ending_hash:
|
||||
if shard.starting_hash <= int_key < shard.ending_hash:
|
||||
return shard
|
||||
return None
|
||||
|
||||
def put_record(self, partition_key, explicit_hash_key, data):
|
||||
def put_record(
|
||||
self, partition_key: str, explicit_hash_key: str, data: str
|
||||
) -> Tuple[str, str]:
|
||||
shard = self.get_shard_for_key(partition_key, explicit_hash_key)
|
||||
|
||||
sequence_number = shard.put_record(partition_key, data, explicit_hash_key)
|
||||
return sequence_number, shard.shard_id
|
||||
sequence_number = shard.put_record(partition_key, data, explicit_hash_key) # type: ignore
|
||||
return sequence_number, shard.shard_id # type: ignore
|
||||
|
||||
def to_json(self, shard_limit=None):
|
||||
def to_json(self, shard_limit: Optional[int] = None) -> Dict[str, Any]:
|
||||
all_shards = list(self.shards.values())
|
||||
requested_shards = all_shards[0 : shard_limit or len(all_shards)]
|
||||
return {
|
||||
@ -403,7 +423,7 @@ class Stream(CloudFormationModel):
|
||||
}
|
||||
}
|
||||
|
||||
def to_json_summary(self):
|
||||
def to_json_summary(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"StreamDescriptionSummary": {
|
||||
"StreamARN": self.arn,
|
||||
@ -420,18 +440,23 @@ class Stream(CloudFormationModel):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def cloudformation_name_type():
|
||||
def cloudformation_name_type() -> str:
|
||||
return "Name"
|
||||
|
||||
@staticmethod
|
||||
def cloudformation_type():
|
||||
def cloudformation_type() -> str:
|
||||
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kinesis-stream.html
|
||||
return "AWS::Kinesis::Stream"
|
||||
|
||||
@classmethod
|
||||
def create_from_cloudformation_json(
|
||||
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
|
||||
):
|
||||
def create_from_cloudformation_json( # type: ignore[misc]
|
||||
cls,
|
||||
resource_name: str,
|
||||
cloudformation_json: Any,
|
||||
account_id: str,
|
||||
region_name: str,
|
||||
**kwargs: Any,
|
||||
) -> "Stream":
|
||||
properties = cloudformation_json.get("Properties", {})
|
||||
shard_count = properties.get("ShardCount", 1)
|
||||
retention_period_hours = properties.get("RetentionPeriodHours", resource_name)
|
||||
@ -451,14 +476,14 @@ class Stream(CloudFormationModel):
|
||||
return stream
|
||||
|
||||
@classmethod
|
||||
def update_from_cloudformation_json(
|
||||
def update_from_cloudformation_json( # type: ignore[misc]
|
||||
cls,
|
||||
original_resource,
|
||||
new_resource_name,
|
||||
cloudformation_json,
|
||||
account_id,
|
||||
region_name,
|
||||
):
|
||||
original_resource: Any,
|
||||
new_resource_name: str,
|
||||
cloudformation_json: Any,
|
||||
account_id: str,
|
||||
region_name: str,
|
||||
) -> "Stream":
|
||||
properties = cloudformation_json["Properties"]
|
||||
|
||||
if Stream.is_replacement_update(properties):
|
||||
@ -466,11 +491,17 @@ class Stream(CloudFormationModel):
|
||||
if resource_name_property not in properties:
|
||||
properties[resource_name_property] = new_resource_name
|
||||
new_resource = cls.create_from_cloudformation_json(
|
||||
properties[resource_name_property], cloudformation_json, region_name
|
||||
resource_name=properties[resource_name_property],
|
||||
cloudformation_json=cloudformation_json,
|
||||
account_id=account_id,
|
||||
region_name=region_name,
|
||||
)
|
||||
properties[resource_name_property] = original_resource.name
|
||||
cls.delete_from_cloudformation_json(
|
||||
original_resource.name, cloudformation_json, region_name
|
||||
resource_name=original_resource.name,
|
||||
cloudformation_json=cloudformation_json,
|
||||
account_id=account_id,
|
||||
region_name=region_name,
|
||||
)
|
||||
return new_resource
|
||||
|
||||
@ -489,14 +520,18 @@ class Stream(CloudFormationModel):
|
||||
return original_resource
|
||||
|
||||
@classmethod
|
||||
def delete_from_cloudformation_json(
|
||||
cls, resource_name, cloudformation_json, account_id, region_name
|
||||
):
|
||||
def delete_from_cloudformation_json( # type: ignore[misc]
|
||||
cls,
|
||||
resource_name: str,
|
||||
cloudformation_json: Any,
|
||||
account_id: str,
|
||||
region_name: str,
|
||||
) -> None:
|
||||
backend: KinesisBackend = kinesis_backends[account_id][region_name]
|
||||
backend.delete_stream(stream_arn=None, stream_name=resource_name)
|
||||
|
||||
@staticmethod
|
||||
def is_replacement_update(properties):
|
||||
def is_replacement_update(properties: List[str]) -> bool:
|
||||
properties_requiring_replacement_update = ["BucketName", "ObjectLockEnabled"]
|
||||
return any(
|
||||
[
|
||||
@ -506,10 +541,10 @@ class Stream(CloudFormationModel):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def has_cfn_attr(cls, attr):
|
||||
def has_cfn_attr(cls, attr: str) -> bool:
|
||||
return attr in ["Arn"]
|
||||
|
||||
def get_cfn_attribute(self, attribute_name):
|
||||
def get_cfn_attribute(self, attribute_name: str) -> str:
|
||||
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
|
||||
|
||||
if attribute_name == "Arn":
|
||||
@ -517,25 +552,31 @@ class Stream(CloudFormationModel):
|
||||
raise UnformattedGetAttTemplateException()
|
||||
|
||||
@property
|
||||
def physical_resource_id(self):
|
||||
def physical_resource_id(self) -> str:
|
||||
return self.stream_name
|
||||
|
||||
|
||||
class KinesisBackend(BaseBackend):
|
||||
def __init__(self, region_name, account_id):
|
||||
def __init__(self, region_name: str, account_id: str):
|
||||
super().__init__(region_name, account_id)
|
||||
self.streams: Dict[str, Stream] = OrderedDict()
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(service_region, zones):
|
||||
def default_vpc_endpoint_service(
|
||||
service_region: str, zones: List[str]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Default VPC endpoint service."""
|
||||
return BaseBackend.default_vpc_endpoint_service_factory(
|
||||
service_region, zones, "kinesis", special_service_name="kinesis-streams"
|
||||
)
|
||||
|
||||
def create_stream(
|
||||
self, stream_name, shard_count, stream_mode=None, retention_period_hours=None
|
||||
):
|
||||
self,
|
||||
stream_name: str,
|
||||
shard_count: int,
|
||||
stream_mode: Optional[Dict[str, str]] = None,
|
||||
retention_period_hours: Optional[int] = None,
|
||||
) -> Stream:
|
||||
if stream_name in self.streams:
|
||||
raise ResourceInUseError(stream_name)
|
||||
stream = Stream(
|
||||
@ -560,14 +601,14 @@ class KinesisBackend(BaseBackend):
|
||||
return stream
|
||||
if stream_arn:
|
||||
stream_name = stream_arn.split("/")[1]
|
||||
raise StreamNotFoundError(stream_name, self.account_id)
|
||||
raise StreamNotFoundError(stream_name, self.account_id) # type: ignore
|
||||
|
||||
def describe_stream_summary(
|
||||
self, stream_arn: Optional[str], stream_name: Optional[str]
|
||||
) -> Stream:
|
||||
return self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
|
||||
def list_streams(self):
|
||||
def list_streams(self) -> Iterable[Stream]:
|
||||
return self.streams.values()
|
||||
|
||||
def delete_stream(
|
||||
@ -582,9 +623,9 @@ class KinesisBackend(BaseBackend):
|
||||
stream_name: Optional[str],
|
||||
shard_id: str,
|
||||
shard_iterator_type: str,
|
||||
starting_sequence_number: str,
|
||||
at_timestamp: str,
|
||||
):
|
||||
starting_sequence_number: int,
|
||||
at_timestamp: datetime.datetime,
|
||||
) -> str:
|
||||
# Validate params
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
try:
|
||||
@ -605,31 +646,31 @@ class KinesisBackend(BaseBackend):
|
||||
|
||||
def get_records(
|
||||
self, stream_arn: Optional[str], shard_iterator: str, limit: Optional[int]
|
||||
):
|
||||
) -> Tuple[str, List[Record], int]:
|
||||
decomposed = decompose_shard_iterator(shard_iterator)
|
||||
stream_name, shard_id, last_sequence_id = decomposed
|
||||
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
shard = stream.get_shard(shard_id)
|
||||
|
||||
records, last_sequence_id, millis_behind_latest = shard.get_records(
|
||||
records, last_sequence_id, millis_behind_latest = shard.get_records( # type: ignore
|
||||
last_sequence_id, limit
|
||||
)
|
||||
|
||||
next_shard_iterator = compose_shard_iterator(
|
||||
stream_name, shard, last_sequence_id
|
||||
stream_name, shard, last_sequence_id # type: ignore
|
||||
)
|
||||
|
||||
return next_shard_iterator, records, millis_behind_latest
|
||||
|
||||
def put_record(
|
||||
self,
|
||||
stream_arn,
|
||||
stream_name,
|
||||
partition_key,
|
||||
explicit_hash_key,
|
||||
data,
|
||||
):
|
||||
stream_arn: str,
|
||||
stream_name: str,
|
||||
partition_key: str,
|
||||
explicit_hash_key: str,
|
||||
data: str,
|
||||
) -> Tuple[str, str]:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
|
||||
sequence_number, shard_id = stream.put_record(
|
||||
@ -638,10 +679,12 @@ class KinesisBackend(BaseBackend):
|
||||
|
||||
return sequence_number, shard_id
|
||||
|
||||
def put_records(self, stream_arn, stream_name, records):
|
||||
def put_records(
|
||||
self, stream_arn: str, stream_name: str, records: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
|
||||
response = {"FailedRecordCount": 0, "Records": []}
|
||||
response: Dict[str, Any] = {"FailedRecordCount": 0, "Records": []}
|
||||
|
||||
if len(records) > 500:
|
||||
raise TooManyRecords
|
||||
@ -660,7 +703,7 @@ class KinesisBackend(BaseBackend):
|
||||
data = record.get("Data")
|
||||
|
||||
sequence_number, shard_id = stream.put_record(
|
||||
partition_key, explicit_hash_key, data
|
||||
partition_key, explicit_hash_key, data # type: ignore[arg-type]
|
||||
)
|
||||
response["Records"].append(
|
||||
{"SequenceNumber": sequence_number, "ShardId": shard_id}
|
||||
@ -669,8 +712,12 @@ class KinesisBackend(BaseBackend):
|
||||
return response
|
||||
|
||||
def split_shard(
|
||||
self, stream_arn, stream_name, shard_to_split, new_starting_hash_key
|
||||
):
|
||||
self,
|
||||
stream_arn: str,
|
||||
stream_name: str,
|
||||
shard_to_split: str,
|
||||
new_starting_hash_key: str,
|
||||
) -> None:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
|
||||
if not re.match("[a-zA-Z0-9_.-]+", shard_to_split):
|
||||
@ -695,8 +742,12 @@ class KinesisBackend(BaseBackend):
|
||||
stream.split_shard(shard_to_split, new_starting_hash_key)
|
||||
|
||||
def merge_shards(
|
||||
self, stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
|
||||
):
|
||||
self,
|
||||
stream_arn: str,
|
||||
stream_name: str,
|
||||
shard_to_merge: str,
|
||||
adjacent_shard_to_merge: str,
|
||||
) -> None:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
|
||||
if shard_to_merge not in stream.shards:
|
||||
@ -713,7 +764,9 @@ class KinesisBackend(BaseBackend):
|
||||
|
||||
stream.merge_shards(shard_to_merge, adjacent_shard_to_merge)
|
||||
|
||||
def update_shard_count(self, stream_arn, stream_name, target_shard_count):
|
||||
def update_shard_count(
|
||||
self, stream_arn: str, stream_name: str, target_shard_count: int
|
||||
) -> int:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
current_shard_count = len([s for s in stream.shards.values() if s.is_open])
|
||||
|
||||
@ -722,7 +775,7 @@ class KinesisBackend(BaseBackend):
|
||||
return current_shard_count
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL)
|
||||
def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]):
|
||||
def list_shards(self, stream_arn: Optional[str], stream_name: Optional[str]) -> List[Dict[str, Any]]: # type: ignore
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
shards = sorted(stream.shards.values(), key=lambda x: x.shard_id)
|
||||
return [shard.to_json() for shard in shards]
|
||||
@ -766,12 +819,16 @@ class KinesisBackend(BaseBackend):
|
||||
stream.retention_period_hours = retention_period_hours
|
||||
|
||||
def list_tags_for_stream(
|
||||
self, stream_arn, stream_name, exclusive_start_tag_key=None, limit=None
|
||||
):
|
||||
self,
|
||||
stream_arn: str,
|
||||
stream_name: str,
|
||||
exclusive_start_tag_key: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
|
||||
tags = []
|
||||
result = {"HasMoreTags": False, "Tags": tags}
|
||||
tags: List[Dict[str, str]] = []
|
||||
result: Dict[str, Any] = {"HasMoreTags": False, "Tags": tags}
|
||||
for key, val in sorted(stream.tags.items(), key=lambda x: x[0]):
|
||||
if limit and len(tags) >= limit:
|
||||
result["HasMoreTags"] = True
|
||||
@ -805,7 +862,7 @@ class KinesisBackend(BaseBackend):
|
||||
stream_arn: Optional[str],
|
||||
stream_name: Optional[str],
|
||||
shard_level_metrics: List[str],
|
||||
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
|
||||
) -> Tuple[str, str, List[str], List[str]]:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
current_shard_level_metrics = stream.shard_level_metrics
|
||||
desired_metrics = list(set(current_shard_level_metrics + shard_level_metrics))
|
||||
@ -822,7 +879,7 @@ class KinesisBackend(BaseBackend):
|
||||
stream_arn: Optional[str],
|
||||
stream_name: Optional[str],
|
||||
to_be_disabled: List[str],
|
||||
) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]:
|
||||
) -> Tuple[str, str, List[str], List[str]]:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
current_metrics = stream.shard_level_metrics
|
||||
if "ALL" in to_be_disabled:
|
||||
@ -834,19 +891,19 @@ class KinesisBackend(BaseBackend):
|
||||
stream.shard_level_metrics = desired_metrics
|
||||
return stream.arn, stream.stream_name, current_metrics, desired_metrics
|
||||
|
||||
def _find_stream_by_arn(self, stream_arn):
|
||||
def _find_stream_by_arn(self, stream_arn: str) -> Stream: # type: ignore[return]
|
||||
for stream in self.streams.values():
|
||||
if stream.arn == stream_arn:
|
||||
return stream
|
||||
|
||||
def list_stream_consumers(self, stream_arn):
|
||||
def list_stream_consumers(self, stream_arn: str) -> List[Consumer]:
|
||||
"""
|
||||
Pagination is not yet implemented
|
||||
"""
|
||||
stream = self._find_stream_by_arn(stream_arn)
|
||||
return stream.consumers
|
||||
|
||||
def register_stream_consumer(self, stream_arn, consumer_name):
|
||||
def register_stream_consumer(self, stream_arn: str, consumer_name: str) -> Consumer:
|
||||
consumer = Consumer(
|
||||
consumer_name, self.account_id, self.region_name, stream_arn
|
||||
)
|
||||
@ -854,7 +911,9 @@ class KinesisBackend(BaseBackend):
|
||||
stream.consumers.append(consumer)
|
||||
return consumer
|
||||
|
||||
def describe_stream_consumer(self, stream_arn, consumer_name, consumer_arn):
|
||||
def describe_stream_consumer(
|
||||
self, stream_arn: str, consumer_name: str, consumer_arn: str
|
||||
) -> Consumer:
|
||||
if stream_arn:
|
||||
stream = self._find_stream_by_arn(stream_arn)
|
||||
for consumer in stream.consumers:
|
||||
@ -862,14 +921,16 @@ class KinesisBackend(BaseBackend):
|
||||
return consumer
|
||||
if consumer_arn:
|
||||
for stream in self.streams.values():
|
||||
consumer = stream.get_consumer_by_arn(consumer_arn)
|
||||
if consumer:
|
||||
return consumer
|
||||
_consumer = stream.get_consumer_by_arn(consumer_arn)
|
||||
if _consumer:
|
||||
return _consumer
|
||||
raise ConsumerNotFound(
|
||||
consumer=consumer_name or consumer_arn, account_id=self.account_id
|
||||
)
|
||||
|
||||
def deregister_stream_consumer(self, stream_arn, consumer_name, consumer_arn):
|
||||
def deregister_stream_consumer(
|
||||
self, stream_arn: str, consumer_name: str, consumer_arn: str
|
||||
) -> None:
|
||||
if stream_arn:
|
||||
stream = self._find_stream_by_arn(stream_arn)
|
||||
stream.consumers = [
|
||||
@ -881,17 +942,19 @@ class KinesisBackend(BaseBackend):
|
||||
# It will be a noop for other streams
|
||||
stream.delete_consumer(consumer_arn)
|
||||
|
||||
def start_stream_encryption(self, stream_arn, stream_name, encryption_type, key_id):
|
||||
def start_stream_encryption(
|
||||
self, stream_arn: str, stream_name: str, encryption_type: str, key_id: str
|
||||
) -> None:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
stream.encryption_type = encryption_type
|
||||
stream.key_id = key_id
|
||||
|
||||
def stop_stream_encryption(self, stream_arn, stream_name):
|
||||
def stop_stream_encryption(self, stream_arn: str, stream_name: str) -> None:
|
||||
stream = self.describe_stream(stream_arn=stream_arn, stream_name=stream_name)
|
||||
stream.encryption_type = "NONE"
|
||||
stream.key_id = None
|
||||
|
||||
def update_stream_mode(self, stream_arn, stream_mode):
|
||||
def update_stream_mode(self, stream_arn: str, stream_mode: Dict[str, str]) -> None:
|
||||
stream = self._find_stream_by_arn(stream_arn)
|
||||
stream.stream_mode = stream_mode
|
||||
|
||||
|
@ -5,45 +5,41 @@ from .models import kinesis_backends, KinesisBackend
|
||||
|
||||
|
||||
class KinesisResponse(BaseResponse):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(service_name="kinesis")
|
||||
|
||||
@property
|
||||
def parameters(self):
|
||||
return json.loads(self.body)
|
||||
|
||||
@property
|
||||
def kinesis_backend(self) -> KinesisBackend:
|
||||
return kinesis_backends[self.current_account][self.region]
|
||||
|
||||
def create_stream(self):
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
shard_count = self.parameters.get("ShardCount")
|
||||
stream_mode = self.parameters.get("StreamModeDetails")
|
||||
def create_stream(self) -> str:
|
||||
stream_name = self._get_param("StreamName")
|
||||
shard_count = self._get_param("ShardCount")
|
||||
stream_mode = self._get_param("StreamModeDetails")
|
||||
self.kinesis_backend.create_stream(
|
||||
stream_name, shard_count, stream_mode=stream_mode
|
||||
)
|
||||
return ""
|
||||
|
||||
def describe_stream(self):
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
limit = self.parameters.get("Limit")
|
||||
def describe_stream(self) -> str:
|
||||
stream_name = self._get_param("StreamName")
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
limit = self._get_param("Limit")
|
||||
stream = self.kinesis_backend.describe_stream(stream_arn, stream_name)
|
||||
return json.dumps(stream.to_json(shard_limit=limit))
|
||||
|
||||
def describe_stream_summary(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
def describe_stream_summary(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
stream = self.kinesis_backend.describe_stream_summary(stream_arn, stream_name)
|
||||
return json.dumps(stream.to_json_summary())
|
||||
|
||||
def list_streams(self):
|
||||
def list_streams(self) -> str:
|
||||
streams = self.kinesis_backend.list_streams()
|
||||
stream_names = [stream.stream_name for stream in streams]
|
||||
max_streams = self._get_param("Limit", 10)
|
||||
try:
|
||||
token = self.parameters.get("ExclusiveStartStreamName")
|
||||
token = self._get_param("ExclusiveStartStreamName")
|
||||
except ValueError:
|
||||
token = self._get_param("ExclusiveStartStreamName")
|
||||
if token:
|
||||
@ -59,19 +55,19 @@ class KinesisResponse(BaseResponse):
|
||||
{"HasMoreStreams": has_more_streams, "StreamNames": streams_resp}
|
||||
)
|
||||
|
||||
def delete_stream(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
def delete_stream(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
self.kinesis_backend.delete_stream(stream_arn, stream_name)
|
||||
return ""
|
||||
|
||||
def get_shard_iterator(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
shard_id = self.parameters.get("ShardId")
|
||||
shard_iterator_type = self.parameters.get("ShardIteratorType")
|
||||
starting_sequence_number = self.parameters.get("StartingSequenceNumber")
|
||||
at_timestamp = self.parameters.get("Timestamp")
|
||||
def get_shard_iterator(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
shard_id = self._get_param("ShardId")
|
||||
shard_iterator_type = self._get_param("ShardIteratorType")
|
||||
starting_sequence_number = self._get_param("StartingSequenceNumber")
|
||||
at_timestamp = self._get_param("Timestamp")
|
||||
|
||||
shard_iterator = self.kinesis_backend.get_shard_iterator(
|
||||
stream_arn,
|
||||
@ -84,10 +80,10 @@ class KinesisResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"ShardIterator": shard_iterator})
|
||||
|
||||
def get_records(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
shard_iterator = self.parameters.get("ShardIterator")
|
||||
limit = self.parameters.get("Limit")
|
||||
def get_records(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
shard_iterator = self._get_param("ShardIterator")
|
||||
limit = self._get_param("Limit")
|
||||
|
||||
(
|
||||
next_shard_iterator,
|
||||
@ -103,12 +99,12 @@ class KinesisResponse(BaseResponse):
|
||||
}
|
||||
)
|
||||
|
||||
def put_record(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
partition_key = self.parameters.get("PartitionKey")
|
||||
explicit_hash_key = self.parameters.get("ExplicitHashKey")
|
||||
data = self.parameters.get("Data")
|
||||
def put_record(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
partition_key = self._get_param("PartitionKey")
|
||||
explicit_hash_key = self._get_param("ExplicitHashKey")
|
||||
data = self._get_param("Data")
|
||||
|
||||
sequence_number, shard_id = self.kinesis_backend.put_record(
|
||||
stream_arn,
|
||||
@ -120,40 +116,40 @@ class KinesisResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id})
|
||||
|
||||
def put_records(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
records = self.parameters.get("Records")
|
||||
def put_records(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
records = self._get_param("Records")
|
||||
|
||||
response = self.kinesis_backend.put_records(stream_arn, stream_name, records)
|
||||
|
||||
return json.dumps(response)
|
||||
|
||||
def split_shard(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
shard_to_split = self.parameters.get("ShardToSplit")
|
||||
new_starting_hash_key = self.parameters.get("NewStartingHashKey")
|
||||
def split_shard(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
shard_to_split = self._get_param("ShardToSplit")
|
||||
new_starting_hash_key = self._get_param("NewStartingHashKey")
|
||||
self.kinesis_backend.split_shard(
|
||||
stream_arn, stream_name, shard_to_split, new_starting_hash_key
|
||||
)
|
||||
return ""
|
||||
|
||||
def merge_shards(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
shard_to_merge = self.parameters.get("ShardToMerge")
|
||||
adjacent_shard_to_merge = self.parameters.get("AdjacentShardToMerge")
|
||||
def merge_shards(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
shard_to_merge = self._get_param("ShardToMerge")
|
||||
adjacent_shard_to_merge = self._get_param("AdjacentShardToMerge")
|
||||
self.kinesis_backend.merge_shards(
|
||||
stream_arn, stream_name, shard_to_merge, adjacent_shard_to_merge
|
||||
)
|
||||
return ""
|
||||
|
||||
def list_shards(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
next_token = self.parameters.get("NextToken")
|
||||
max_results = self.parameters.get("MaxResults", 10000)
|
||||
def list_shards(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
next_token = self._get_param("NextToken")
|
||||
max_results = self._get_param("MaxResults", 10000)
|
||||
shards, token = self.kinesis_backend.list_shards(
|
||||
stream_arn=stream_arn,
|
||||
stream_name=stream_name,
|
||||
@ -165,10 +161,10 @@ class KinesisResponse(BaseResponse):
|
||||
res["NextToken"] = token
|
||||
return json.dumps(res)
|
||||
|
||||
def update_shard_count(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
target_shard_count = self.parameters.get("TargetShardCount")
|
||||
def update_shard_count(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
target_shard_count = self._get_param("TargetShardCount")
|
||||
current_shard_count = self.kinesis_backend.update_shard_count(
|
||||
stream_arn=stream_arn,
|
||||
stream_name=stream_name,
|
||||
@ -182,52 +178,52 @@ class KinesisResponse(BaseResponse):
|
||||
)
|
||||
)
|
||||
|
||||
def increase_stream_retention_period(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
retention_period_hours = self.parameters.get("RetentionPeriodHours")
|
||||
def increase_stream_retention_period(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
retention_period_hours = self._get_param("RetentionPeriodHours")
|
||||
self.kinesis_backend.increase_stream_retention_period(
|
||||
stream_arn, stream_name, retention_period_hours
|
||||
)
|
||||
return ""
|
||||
|
||||
def decrease_stream_retention_period(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
retention_period_hours = self.parameters.get("RetentionPeriodHours")
|
||||
def decrease_stream_retention_period(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
retention_period_hours = self._get_param("RetentionPeriodHours")
|
||||
self.kinesis_backend.decrease_stream_retention_period(
|
||||
stream_arn, stream_name, retention_period_hours
|
||||
)
|
||||
return ""
|
||||
|
||||
def add_tags_to_stream(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
tags = self.parameters.get("Tags")
|
||||
def add_tags_to_stream(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
tags = self._get_param("Tags")
|
||||
self.kinesis_backend.add_tags_to_stream(stream_arn, stream_name, tags)
|
||||
return json.dumps({})
|
||||
|
||||
def list_tags_for_stream(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey")
|
||||
limit = self.parameters.get("Limit")
|
||||
def list_tags_for_stream(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
exclusive_start_tag_key = self._get_param("ExclusiveStartTagKey")
|
||||
limit = self._get_param("Limit")
|
||||
response = self.kinesis_backend.list_tags_for_stream(
|
||||
stream_arn, stream_name, exclusive_start_tag_key, limit
|
||||
)
|
||||
return json.dumps(response)
|
||||
|
||||
def remove_tags_from_stream(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
tag_keys = self.parameters.get("TagKeys")
|
||||
def remove_tags_from_stream(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
tag_keys = self._get_param("TagKeys")
|
||||
self.kinesis_backend.remove_tags_from_stream(stream_arn, stream_name, tag_keys)
|
||||
return json.dumps({})
|
||||
|
||||
def enable_enhanced_monitoring(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
shard_level_metrics = self.parameters.get("ShardLevelMetrics")
|
||||
def enable_enhanced_monitoring(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
shard_level_metrics = self._get_param("ShardLevelMetrics")
|
||||
arn, name, current, desired = self.kinesis_backend.enable_enhanced_monitoring(
|
||||
stream_arn=stream_arn,
|
||||
stream_name=stream_name,
|
||||
@ -242,10 +238,10 @@ class KinesisResponse(BaseResponse):
|
||||
)
|
||||
)
|
||||
|
||||
def disable_enhanced_monitoring(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
shard_level_metrics = self.parameters.get("ShardLevelMetrics")
|
||||
def disable_enhanced_monitoring(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
shard_level_metrics = self._get_param("ShardLevelMetrics")
|
||||
arn, name, current, desired = self.kinesis_backend.disable_enhanced_monitoring(
|
||||
stream_arn=stream_arn,
|
||||
stream_name=stream_name,
|
||||
@ -260,23 +256,23 @@ class KinesisResponse(BaseResponse):
|
||||
)
|
||||
)
|
||||
|
||||
def list_stream_consumers(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
def list_stream_consumers(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
consumers = self.kinesis_backend.list_stream_consumers(stream_arn=stream_arn)
|
||||
return json.dumps(dict(Consumers=[c.to_json() for c in consumers]))
|
||||
|
||||
def register_stream_consumer(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
consumer_name = self.parameters.get("ConsumerName")
|
||||
def register_stream_consumer(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
consumer_name = self._get_param("ConsumerName")
|
||||
consumer = self.kinesis_backend.register_stream_consumer(
|
||||
stream_arn=stream_arn, consumer_name=consumer_name
|
||||
)
|
||||
return json.dumps(dict(Consumer=consumer.to_json()))
|
||||
|
||||
def describe_stream_consumer(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
consumer_name = self.parameters.get("ConsumerName")
|
||||
consumer_arn = self.parameters.get("ConsumerARN")
|
||||
def describe_stream_consumer(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
consumer_name = self._get_param("ConsumerName")
|
||||
consumer_arn = self._get_param("ConsumerARN")
|
||||
consumer = self.kinesis_backend.describe_stream_consumer(
|
||||
stream_arn=stream_arn,
|
||||
consumer_name=consumer_name,
|
||||
@ -286,10 +282,10 @@ class KinesisResponse(BaseResponse):
|
||||
dict(ConsumerDescription=consumer.to_json(include_stream_arn=True))
|
||||
)
|
||||
|
||||
def deregister_stream_consumer(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
consumer_name = self.parameters.get("ConsumerName")
|
||||
consumer_arn = self.parameters.get("ConsumerARN")
|
||||
def deregister_stream_consumer(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
consumer_name = self._get_param("ConsumerName")
|
||||
consumer_arn = self._get_param("ConsumerARN")
|
||||
self.kinesis_backend.deregister_stream_consumer(
|
||||
stream_arn=stream_arn,
|
||||
consumer_name=consumer_name,
|
||||
@ -297,11 +293,11 @@ class KinesisResponse(BaseResponse):
|
||||
)
|
||||
return json.dumps(dict())
|
||||
|
||||
def start_stream_encryption(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
encryption_type = self.parameters.get("EncryptionType")
|
||||
key_id = self.parameters.get("KeyId")
|
||||
def start_stream_encryption(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
encryption_type = self._get_param("EncryptionType")
|
||||
key_id = self._get_param("KeyId")
|
||||
self.kinesis_backend.start_stream_encryption(
|
||||
stream_arn=stream_arn,
|
||||
stream_name=stream_name,
|
||||
@ -310,16 +306,16 @@ class KinesisResponse(BaseResponse):
|
||||
)
|
||||
return json.dumps(dict())
|
||||
|
||||
def stop_stream_encryption(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_name = self.parameters.get("StreamName")
|
||||
def stop_stream_encryption(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_name = self._get_param("StreamName")
|
||||
self.kinesis_backend.stop_stream_encryption(
|
||||
stream_arn=stream_arn, stream_name=stream_name
|
||||
)
|
||||
return json.dumps(dict())
|
||||
|
||||
def update_stream_mode(self):
|
||||
stream_arn = self.parameters.get("StreamARN")
|
||||
stream_mode = self.parameters.get("StreamModeDetails")
|
||||
def update_stream_mode(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_mode = self._get_param("StreamModeDetails")
|
||||
self.kinesis_backend.update_stream_mode(stream_arn, stream_mode)
|
||||
return "{}"
|
||||
|
@ -1,4 +1,6 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, List
|
||||
|
||||
from .exceptions import InvalidArgumentError
|
||||
|
||||
@ -30,8 +32,12 @@ PAGINATION_MODEL = {
|
||||
|
||||
|
||||
def compose_new_shard_iterator(
|
||||
stream_name, shard, shard_iterator_type, starting_sequence_number, at_timestamp
|
||||
):
|
||||
stream_name: Optional[str],
|
||||
shard: Any,
|
||||
shard_iterator_type: str,
|
||||
starting_sequence_number: int,
|
||||
at_timestamp: datetime,
|
||||
) -> str:
|
||||
if shard_iterator_type == "AT_SEQUENCE_NUMBER":
|
||||
last_sequence_id = int(starting_sequence_number) - 1
|
||||
elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER":
|
||||
@ -47,11 +53,13 @@ def compose_new_shard_iterator(
|
||||
return compose_shard_iterator(stream_name, shard, last_sequence_id)
|
||||
|
||||
|
||||
def compose_shard_iterator(stream_name, shard, last_sequence_id):
|
||||
def compose_shard_iterator(
|
||||
stream_name: Optional[str], shard: Any, last_sequence_id: int
|
||||
) -> str:
|
||||
return encode_method(
|
||||
f"{stream_name}:{shard.shard_id}:{last_sequence_id}".encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
|
||||
def decompose_shard_iterator(shard_iterator):
|
||||
def decompose_shard_iterator(shard_iterator: str) -> List[str]:
|
||||
return decode_method(shard_iterator.encode("utf-8")).decode("utf-8").split(":")
|
||||
|
@ -6,7 +6,7 @@ class KinesisvideoClientError(RESTError):
|
||||
|
||||
|
||||
class ResourceNotFoundException(KinesisvideoClientError):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.code = 404
|
||||
super().__init__(
|
||||
"ResourceNotFoundException",
|
||||
@ -15,6 +15,6 @@ class ResourceNotFoundException(KinesisvideoClientError):
|
||||
|
||||
|
||||
class ResourceInUseException(KinesisvideoClientError):
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str):
|
||||
self.code = 400
|
||||
super().__init__("ResourceInUseException", message)
|
||||
|
@ -1,5 +1,6 @@
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel
|
||||
from .exceptions import ResourceNotFoundException, ResourceInUseException
|
||||
from moto.moto_api._internal import mock_random as random
|
||||
|
||||
@ -7,14 +8,14 @@ from moto.moto_api._internal import mock_random as random
|
||||
class Stream(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
account_id,
|
||||
region_name,
|
||||
device_name,
|
||||
stream_name,
|
||||
media_type,
|
||||
kms_key_id,
|
||||
data_retention_in_hours,
|
||||
tags,
|
||||
account_id: str,
|
||||
region_name: str,
|
||||
device_name: str,
|
||||
stream_name: str,
|
||||
media_type: str,
|
||||
kms_key_id: str,
|
||||
data_retention_in_hours: int,
|
||||
tags: Dict[str, str],
|
||||
):
|
||||
self.region_name = region_name
|
||||
self.stream_name = stream_name
|
||||
@ -30,11 +31,11 @@ class Stream(BaseModel):
|
||||
self.data_endpoint_number = random.get_random_hex()
|
||||
self.arn = stream_arn
|
||||
|
||||
def get_data_endpoint(self, api_name):
|
||||
def get_data_endpoint(self, api_name: str) -> str:
|
||||
data_endpoint_prefix = "s-" if api_name in ("PUT_MEDIA", "GET_MEDIA") else "b-"
|
||||
return f"https://{data_endpoint_prefix}{self.data_endpoint_number}.kinesisvideo.{self.region_name}.amazonaws.com"
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"DeviceName": self.device_name,
|
||||
"StreamName": self.stream_name,
|
||||
@ -49,19 +50,19 @@ class Stream(BaseModel):
|
||||
|
||||
|
||||
class KinesisVideoBackend(BaseBackend):
|
||||
def __init__(self, region_name, account_id):
|
||||
def __init__(self, region_name: str, account_id: str):
|
||||
super().__init__(region_name, account_id)
|
||||
self.streams = {}
|
||||
self.streams: Dict[str, Stream] = {}
|
||||
|
||||
def create_stream(
|
||||
self,
|
||||
device_name,
|
||||
stream_name,
|
||||
media_type,
|
||||
kms_key_id,
|
||||
data_retention_in_hours,
|
||||
tags,
|
||||
):
|
||||
device_name: str,
|
||||
stream_name: str,
|
||||
media_type: str,
|
||||
kms_key_id: str,
|
||||
data_retention_in_hours: int,
|
||||
tags: Dict[str, str],
|
||||
) -> str:
|
||||
streams = [_ for _ in self.streams.values() if _.stream_name == stream_name]
|
||||
if len(streams) > 0:
|
||||
raise ResourceInUseException(f"The stream {stream_name} already exists.")
|
||||
@ -78,7 +79,7 @@ class KinesisVideoBackend(BaseBackend):
|
||||
self.streams[stream.arn] = stream
|
||||
return stream.arn
|
||||
|
||||
def _get_stream(self, stream_name, stream_arn):
|
||||
def _get_stream(self, stream_name: str, stream_arn: str) -> Stream:
|
||||
if stream_name:
|
||||
streams = [_ for _ in self.streams.values() if _.stream_name == stream_name]
|
||||
if len(streams) == 0:
|
||||
@ -90,20 +91,17 @@ class KinesisVideoBackend(BaseBackend):
|
||||
raise ResourceNotFoundException()
|
||||
return stream
|
||||
|
||||
def describe_stream(self, stream_name, stream_arn):
|
||||
def describe_stream(self, stream_name: str, stream_arn: str) -> Dict[str, Any]:
|
||||
stream = self._get_stream(stream_name, stream_arn)
|
||||
stream_info = stream.to_dict()
|
||||
return stream_info
|
||||
return stream.to_dict()
|
||||
|
||||
def list_streams(self):
|
||||
def list_streams(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Pagination and the StreamNameCondition-parameter are not yet implemented
|
||||
"""
|
||||
stream_info_list = [_.to_dict() for _ in self.streams.values()]
|
||||
next_token = None
|
||||
return stream_info_list, next_token
|
||||
return [_.to_dict() for _ in self.streams.values()]
|
||||
|
||||
def delete_stream(self, stream_arn):
|
||||
def delete_stream(self, stream_arn: str) -> None:
|
||||
"""
|
||||
The CurrentVersion-parameter is not yet implemented
|
||||
"""
|
||||
@ -112,11 +110,11 @@ class KinesisVideoBackend(BaseBackend):
|
||||
raise ResourceNotFoundException()
|
||||
del self.streams[stream_arn]
|
||||
|
||||
def get_data_endpoint(self, stream_name, stream_arn, api_name):
|
||||
def get_data_endpoint(
|
||||
self, stream_name: str, stream_arn: str, api_name: str
|
||||
) -> str:
|
||||
stream = self._get_stream(stream_name, stream_arn)
|
||||
return stream.get_data_endpoint(api_name)
|
||||
|
||||
# add methods from here
|
||||
|
||||
|
||||
kinesisvideo_backends = BackendDict(KinesisVideoBackend, "kinesisvideo")
|
||||
|
@ -1,17 +1,17 @@
|
||||
from moto.core.responses import BaseResponse
|
||||
from .models import kinesisvideo_backends
|
||||
from .models import kinesisvideo_backends, KinesisVideoBackend
|
||||
import json
|
||||
|
||||
|
||||
class KinesisVideoResponse(BaseResponse):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(service_name="kinesisvideo")
|
||||
|
||||
@property
|
||||
def kinesisvideo_backend(self):
|
||||
def kinesisvideo_backend(self) -> KinesisVideoBackend:
|
||||
return kinesisvideo_backends[self.current_account][self.region]
|
||||
|
||||
def create_stream(self):
|
||||
def create_stream(self) -> str:
|
||||
device_name = self._get_param("DeviceName")
|
||||
stream_name = self._get_param("StreamName")
|
||||
media_type = self._get_param("MediaType")
|
||||
@ -28,7 +28,7 @@ class KinesisVideoResponse(BaseResponse):
|
||||
)
|
||||
return json.dumps(dict(StreamARN=stream_arn))
|
||||
|
||||
def describe_stream(self):
|
||||
def describe_stream(self) -> str:
|
||||
stream_name = self._get_param("StreamName")
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
stream_info = self.kinesisvideo_backend.describe_stream(
|
||||
@ -36,16 +36,16 @@ class KinesisVideoResponse(BaseResponse):
|
||||
)
|
||||
return json.dumps(dict(StreamInfo=stream_info))
|
||||
|
||||
def list_streams(self):
|
||||
stream_info_list, next_token = self.kinesisvideo_backend.list_streams()
|
||||
return json.dumps(dict(StreamInfoList=stream_info_list, NextToken=next_token))
|
||||
def list_streams(self) -> str:
|
||||
stream_info_list = self.kinesisvideo_backend.list_streams()
|
||||
return json.dumps(dict(StreamInfoList=stream_info_list, NextToken=None))
|
||||
|
||||
def delete_stream(self):
|
||||
def delete_stream(self) -> str:
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
self.kinesisvideo_backend.delete_stream(stream_arn=stream_arn)
|
||||
return json.dumps(dict())
|
||||
|
||||
def get_data_endpoint(self):
|
||||
def get_data_endpoint(self) -> str:
|
||||
stream_name = self._get_param("StreamName")
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
api_name = self._get_param("APIName")
|
||||
|
@ -1,14 +1,17 @@
|
||||
from typing import Tuple
|
||||
from moto.core import BaseBackend, BackendDict
|
||||
from moto.kinesisvideo import kinesisvideo_backends
|
||||
from moto.kinesisvideo.models import kinesisvideo_backends, KinesisVideoBackend
|
||||
from moto.sts.utils import random_session_token
|
||||
|
||||
|
||||
class KinesisVideoArchivedMediaBackend(BaseBackend):
|
||||
@property
|
||||
def backend(self):
|
||||
def backend(self) -> KinesisVideoBackend:
|
||||
return kinesisvideo_backends[self.account_id][self.region_name]
|
||||
|
||||
def _get_streaming_url(self, stream_name, stream_arn, api_name):
|
||||
def _get_streaming_url(
|
||||
self, stream_name: str, stream_arn: str, api_name: str
|
||||
) -> str:
|
||||
stream = self.backend._get_stream(stream_name, stream_arn)
|
||||
data_endpoint = stream.get_data_endpoint(api_name)
|
||||
session_token = random_session_token()
|
||||
@ -19,19 +22,17 @@ class KinesisVideoArchivedMediaBackend(BaseBackend):
|
||||
relative_path = api_to_relative_path[api_name]
|
||||
return f"{data_endpoint}{relative_path}?SessionToken={session_token}"
|
||||
|
||||
def get_hls_streaming_session_url(self, stream_name, stream_arn):
|
||||
# Ignore option paramters as the format of hls_url does't depends on them
|
||||
def get_hls_streaming_session_url(self, stream_name: str, stream_arn: str) -> str:
|
||||
# Ignore option paramters as the format of hls_url doesn't depend on them
|
||||
api_name = "GET_HLS_STREAMING_SESSION_URL"
|
||||
url = self._get_streaming_url(stream_name, stream_arn, api_name)
|
||||
return url
|
||||
return self._get_streaming_url(stream_name, stream_arn, api_name)
|
||||
|
||||
def get_dash_streaming_session_url(self, stream_name, stream_arn):
|
||||
# Ignore option paramters as the format of hls_url does't depends on them
|
||||
def get_dash_streaming_session_url(self, stream_name: str, stream_arn: str) -> str:
|
||||
# Ignore option paramters as the format of hls_url doesn't depend on them
|
||||
api_name = "GET_DASH_STREAMING_SESSION_URL"
|
||||
url = self._get_streaming_url(stream_name, stream_arn, api_name)
|
||||
return url
|
||||
return self._get_streaming_url(stream_name, stream_arn, api_name)
|
||||
|
||||
def get_clip(self, stream_name, stream_arn):
|
||||
def get_clip(self, stream_name: str, stream_arn: str) -> Tuple[str, bytes]:
|
||||
self.backend._get_stream(stream_name, stream_arn)
|
||||
content_type = "video/mp4" # Fixed content_type as it depends on input stream
|
||||
payload = b"sample-mp4-video"
|
||||
|
@ -1,17 +1,18 @@
|
||||
from typing import Dict, Tuple
|
||||
from moto.core.responses import BaseResponse
|
||||
from .models import kinesisvideoarchivedmedia_backends
|
||||
from .models import kinesisvideoarchivedmedia_backends, KinesisVideoArchivedMediaBackend
|
||||
import json
|
||||
|
||||
|
||||
class KinesisVideoArchivedMediaResponse(BaseResponse):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(service_name="kinesis-video-archived-media")
|
||||
|
||||
@property
|
||||
def kinesisvideoarchivedmedia_backend(self):
|
||||
def kinesisvideoarchivedmedia_backend(self) -> KinesisVideoArchivedMediaBackend:
|
||||
return kinesisvideoarchivedmedia_backends[self.current_account][self.region]
|
||||
|
||||
def get_hls_streaming_session_url(self):
|
||||
def get_hls_streaming_session_url(self) -> str:
|
||||
stream_name = self._get_param("StreamName")
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
hls_streaming_session_url = (
|
||||
@ -21,7 +22,7 @@ class KinesisVideoArchivedMediaResponse(BaseResponse):
|
||||
)
|
||||
return json.dumps(dict(HLSStreamingSessionURL=hls_streaming_session_url))
|
||||
|
||||
def get_dash_streaming_session_url(self):
|
||||
def get_dash_streaming_session_url(self) -> str:
|
||||
stream_name = self._get_param("StreamName")
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
dash_streaming_session_url = (
|
||||
@ -31,7 +32,7 @@ class KinesisVideoArchivedMediaResponse(BaseResponse):
|
||||
)
|
||||
return json.dumps(dict(DASHStreamingSessionURL=dash_streaming_session_url))
|
||||
|
||||
def get_clip(self):
|
||||
def get_clip(self) -> Tuple[bytes, Dict[str, str]]:
|
||||
stream_name = self._get_param("StreamName")
|
||||
stream_arn = self._get_param("StreamARN")
|
||||
content_type, payload = self.kinesisvideoarchivedmedia_backend.get_clip(
|
||||
|
@ -4,29 +4,29 @@ from moto.core.exceptions import JsonRESTError
|
||||
class NotFoundException(JsonRESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str):
|
||||
super().__init__("NotFoundException", message)
|
||||
|
||||
|
||||
class ValidationException(JsonRESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str):
|
||||
super().__init__("ValidationException", message)
|
||||
|
||||
|
||||
class AlreadyExistsException(JsonRESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str):
|
||||
super().__init__("AlreadyExistsException", message)
|
||||
|
||||
|
||||
class NotAuthorizedException(JsonRESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("NotAuthorizedException", None)
|
||||
def __init__(self) -> None:
|
||||
super().__init__("NotAuthorizedException", "")
|
||||
|
||||
self.description = '{"__type":"NotAuthorizedException"}'
|
||||
|
||||
@ -34,7 +34,7 @@ class NotAuthorizedException(JsonRESTError):
|
||||
class AccessDeniedException(JsonRESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self, message):
|
||||
def __init__(self, message: str):
|
||||
super().__init__("AccessDeniedException", message)
|
||||
|
||||
self.description = '{"__type":"AccessDeniedException"}'
|
||||
@ -43,7 +43,7 @@ class AccessDeniedException(JsonRESTError):
|
||||
class InvalidCiphertextException(JsonRESTError):
|
||||
code = 400
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("InvalidCiphertextException", None)
|
||||
def __init__(self) -> None:
|
||||
super().__init__("InvalidCiphertextException", "")
|
||||
|
||||
self.description = '{"__type":"InvalidCiphertextException"}'
|
||||
|
@ -5,8 +5,8 @@ from copy import copy
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from typing import Any, Dict, List, Tuple, Optional, Iterable, Set
|
||||
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
|
||||
from moto.core.utils import unix_time
|
||||
@ -28,12 +28,12 @@ from .utils import (
|
||||
class Grant(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
key_id,
|
||||
name,
|
||||
grantee_principal,
|
||||
operations,
|
||||
constraints,
|
||||
retiring_principal,
|
||||
key_id: str,
|
||||
name: str,
|
||||
grantee_principal: str,
|
||||
operations: List[str],
|
||||
constraints: Dict[str, Any],
|
||||
retiring_principal: str,
|
||||
):
|
||||
self.key_id = key_id
|
||||
self.name = name
|
||||
@ -44,7 +44,7 @@ class Grant(BaseModel):
|
||||
self.id = mock_random.get_random_hex()
|
||||
self.token = mock_random.get_random_hex()
|
||||
|
||||
def to_json(self):
|
||||
def to_json(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"KeyId": self.key_id,
|
||||
"GrantId": self.id,
|
||||
@ -59,13 +59,13 @@ class Grant(BaseModel):
|
||||
class Key(CloudFormationModel):
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
key_usage,
|
||||
key_spec,
|
||||
description,
|
||||
account_id,
|
||||
region,
|
||||
multi_region=False,
|
||||
policy: Optional[str],
|
||||
key_usage: str,
|
||||
key_spec: str,
|
||||
description: str,
|
||||
account_id: str,
|
||||
region: str,
|
||||
multi_region: bool = False,
|
||||
):
|
||||
self.id = generate_key_id(multi_region)
|
||||
self.creation_date = unix_time()
|
||||
@ -78,7 +78,7 @@ class Key(CloudFormationModel):
|
||||
self.region = region
|
||||
self.multi_region = multi_region
|
||||
self.key_rotation_status = False
|
||||
self.deletion_date = None
|
||||
self.deletion_date: Optional[datetime] = None
|
||||
self.key_material = generate_master_key()
|
||||
self.private_key = generate_private_key()
|
||||
self.origin = "AWS_KMS"
|
||||
@ -86,10 +86,15 @@ class Key(CloudFormationModel):
|
||||
self.key_spec = key_spec or "SYMMETRIC_DEFAULT"
|
||||
self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}"
|
||||
|
||||
self.grants = dict()
|
||||
self.grants: Dict[str, Grant] = dict()
|
||||
|
||||
def add_grant(
|
||||
self, name, grantee_principal, operations, constraints, retiring_principal
|
||||
self,
|
||||
name: str,
|
||||
grantee_principal: str,
|
||||
operations: List[str],
|
||||
constraints: Dict[str, Any],
|
||||
retiring_principal: str,
|
||||
) -> Grant:
|
||||
grant = Grant(
|
||||
self.id,
|
||||
@ -102,32 +107,32 @@ class Key(CloudFormationModel):
|
||||
self.grants[grant.id] = grant
|
||||
return grant
|
||||
|
||||
def list_grants(self, grant_id) -> [Grant]:
|
||||
def list_grants(self, grant_id: str) -> List[Grant]:
|
||||
grant_ids = [grant_id] if grant_id else self.grants.keys()
|
||||
return [grant for _id, grant in self.grants.items() if _id in grant_ids]
|
||||
|
||||
def list_retirable_grants(self, retiring_principal) -> [Grant]:
|
||||
def list_retirable_grants(self, retiring_principal: str) -> List[Grant]:
|
||||
return [
|
||||
grant
|
||||
for grant in self.grants.values()
|
||||
if grant.retiring_principal == retiring_principal
|
||||
]
|
||||
|
||||
def revoke_grant(self, grant_id) -> None:
|
||||
def revoke_grant(self, grant_id: str) -> None:
|
||||
if not self.grants.pop(grant_id, None):
|
||||
raise JsonRESTError("NotFoundException", f"Grant ID {grant_id} not found")
|
||||
|
||||
def retire_grant(self, grant_id) -> None:
|
||||
def retire_grant(self, grant_id: str) -> None:
|
||||
self.grants.pop(grant_id, None)
|
||||
|
||||
def retire_grant_by_token(self, grant_token) -> None:
|
||||
def retire_grant_by_token(self, grant_token: str) -> None:
|
||||
self.grants = {
|
||||
_id: grant
|
||||
for _id, grant in self.grants.items()
|
||||
if grant.token != grant_token
|
||||
}
|
||||
|
||||
def generate_default_policy(self):
|
||||
def generate_default_policy(self) -> str:
|
||||
return json.dumps(
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
@ -145,11 +150,11 @@ class Key(CloudFormationModel):
|
||||
)
|
||||
|
||||
@property
|
||||
def physical_resource_id(self):
|
||||
def physical_resource_id(self) -> str:
|
||||
return self.id
|
||||
|
||||
@property
|
||||
def encryption_algorithms(self):
|
||||
def encryption_algorithms(self) -> Optional[List[str]]:
|
||||
if self.key_usage == "SIGN_VERIFY":
|
||||
return None
|
||||
elif self.key_spec == "SYMMETRIC_DEFAULT":
|
||||
@ -158,9 +163,9 @@ class Key(CloudFormationModel):
|
||||
return ["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"]
|
||||
|
||||
@property
|
||||
def signing_algorithms(self):
|
||||
def signing_algorithms(self) -> List[str]:
|
||||
if self.key_usage == "ENCRYPT_DECRYPT":
|
||||
return None
|
||||
return None # type: ignore[return-value]
|
||||
elif self.key_spec in ["ECC_NIST_P256", "ECC_SECG_P256K1"]:
|
||||
return ["ECDSA_SHA_256"]
|
||||
elif self.key_spec == "ECC_NIST_P384":
|
||||
@ -177,7 +182,7 @@ class Key(CloudFormationModel):
|
||||
"RSASSA_PSS_SHA_512",
|
||||
]
|
||||
|
||||
def to_dict(self):
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
key_dict = {
|
||||
"KeyMetadata": {
|
||||
"AWSAccountId": self.account_id,
|
||||
@ -201,22 +206,27 @@ class Key(CloudFormationModel):
|
||||
key_dict["KeyMetadata"]["DeletionDate"] = unix_time(self.deletion_date)
|
||||
return key_dict
|
||||
|
||||
def delete(self, account_id, region_name):
|
||||
def delete(self, account_id: str, region_name: str) -> None:
|
||||
kms_backends[account_id][region_name].delete_key(self.id)
|
||||
|
||||
@staticmethod
|
||||
def cloudformation_name_type():
|
||||
return None
|
||||
def cloudformation_name_type() -> str:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def cloudformation_type():
|
||||
def cloudformation_type() -> str:
|
||||
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-kms-key.html
|
||||
return "AWS::KMS::Key"
|
||||
|
||||
@classmethod
|
||||
def create_from_cloudformation_json(
|
||||
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
|
||||
):
|
||||
def create_from_cloudformation_json( # type: ignore[misc]
|
||||
cls,
|
||||
resource_name: str,
|
||||
cloudformation_json: Any,
|
||||
account_id: str,
|
||||
region_name: str,
|
||||
**kwargs: Any,
|
||||
) -> "Key":
|
||||
kms_backend = kms_backends[account_id][region_name]
|
||||
properties = cloudformation_json["Properties"]
|
||||
|
||||
@ -233,10 +243,10 @@ class Key(CloudFormationModel):
|
||||
return key
|
||||
|
||||
@classmethod
|
||||
def has_cfn_attr(cls, attr):
|
||||
def has_cfn_attr(cls, attr: str) -> bool:
|
||||
return attr in ["Arn"]
|
||||
|
||||
def get_cfn_attribute(self, attribute_name):
|
||||
def get_cfn_attribute(self, attribute_name: str) -> str:
|
||||
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
|
||||
|
||||
if attribute_name == "Arn":
|
||||
@ -245,20 +255,22 @@ class Key(CloudFormationModel):
|
||||
|
||||
|
||||
class KmsBackend(BaseBackend):
|
||||
def __init__(self, region_name, account_id=None):
|
||||
super().__init__(region_name=region_name, account_id=account_id)
|
||||
self.keys = {}
|
||||
self.key_to_aliases = defaultdict(set)
|
||||
def __init__(self, region_name: str, account_id: Optional[str] = None):
|
||||
super().__init__(region_name=region_name, account_id=account_id) # type: ignore
|
||||
self.keys: Dict[str, Key] = {}
|
||||
self.key_to_aliases: Dict[str, Set[str]] = defaultdict(set)
|
||||
self.tagger = TaggingService(key_name="TagKey", value_name="TagValue")
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(service_region, zones):
|
||||
def default_vpc_endpoint_service(
|
||||
service_region: str, zones: List[str]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Default VPC endpoint service."""
|
||||
return BaseBackend.default_vpc_endpoint_service_factory(
|
||||
service_region, zones, "kms"
|
||||
)
|
||||
|
||||
def _generate_default_keys(self, alias_name):
|
||||
def _generate_default_keys(self, alias_name: str) -> Optional[str]:
|
||||
"""Creates default kms keys"""
|
||||
if alias_name in RESERVED_ALIASES:
|
||||
key = self.create_key(
|
||||
@ -270,10 +282,17 @@ class KmsBackend(BaseBackend):
|
||||
)
|
||||
self.add_alias(key.id, alias_name)
|
||||
return key.id
|
||||
return None
|
||||
|
||||
def create_key(
|
||||
self, policy, key_usage, key_spec, description, tags, multi_region=False
|
||||
):
|
||||
self,
|
||||
policy: Optional[str],
|
||||
key_usage: str,
|
||||
key_spec: str,
|
||||
description: str,
|
||||
tags: Optional[List[Dict[str, str]]],
|
||||
multi_region: bool = False,
|
||||
) -> Key:
|
||||
"""
|
||||
The provided Policy currently does not need to be valid. If it is valid, Moto will perform authorization checks on key-related operations, just like AWS does.
|
||||
|
||||
@ -303,7 +322,7 @@ class KmsBackend(BaseBackend):
|
||||
#
|
||||
# In our implementation with just create a copy of all the properties once without any protection from change,
|
||||
# as the exact implementation is currently infeasible.
|
||||
def replicate_key(self, key_id, replica_region):
|
||||
def replicate_key(self, key_id: str, replica_region: str) -> None:
|
||||
# Using copy() instead of deepcopy(), as the latter results in exception:
|
||||
# TypeError: cannot pickle '_cffi_backend.FFI' object
|
||||
# Since we only update top level properties, copy() should suffice.
|
||||
@ -312,31 +331,31 @@ class KmsBackend(BaseBackend):
|
||||
to_region_backend = kms_backends[self.account_id][replica_region]
|
||||
to_region_backend.keys[replica_key.id] = replica_key
|
||||
|
||||
def update_key_description(self, key_id, description):
|
||||
def update_key_description(self, key_id: str, description: str) -> None:
|
||||
key = self.keys[self.get_key_id(key_id)]
|
||||
key.description = description
|
||||
|
||||
def delete_key(self, key_id):
|
||||
def delete_key(self, key_id: str) -> None:
|
||||
if key_id in self.keys:
|
||||
if key_id in self.key_to_aliases:
|
||||
self.key_to_aliases.pop(key_id)
|
||||
self.tagger.delete_all_tags_for_resource(key_id)
|
||||
|
||||
return self.keys.pop(key_id)
|
||||
self.keys.pop(key_id)
|
||||
|
||||
def describe_key(self, key_id) -> Key:
|
||||
def describe_key(self, key_id: str) -> Key:
|
||||
# allow the different methods (alias, ARN :key/, keyId, ARN alias) to
|
||||
# describe key not just KeyId
|
||||
key_id = self.get_key_id(key_id)
|
||||
if r"alias/" in str(key_id).lower():
|
||||
key_id = self.get_key_id_from_alias(key_id)
|
||||
key_id = self.get_key_id_from_alias(key_id) # type: ignore[assignment]
|
||||
return self.keys[self.get_key_id(key_id)]
|
||||
|
||||
def list_keys(self):
|
||||
def list_keys(self) -> Iterable[Key]:
|
||||
return self.keys.values()
|
||||
|
||||
@staticmethod
|
||||
def get_key_id(key_id):
|
||||
def get_key_id(key_id: str) -> str:
|
||||
# Allow use of ARN as well as pure KeyId
|
||||
if key_id.startswith("arn:") and ":key/" in key_id:
|
||||
return key_id.split(":key/")[1]
|
||||
@ -344,14 +363,14 @@ class KmsBackend(BaseBackend):
|
||||
return key_id
|
||||
|
||||
@staticmethod
|
||||
def get_alias_name(alias_name):
|
||||
def get_alias_name(alias_name: str) -> str:
|
||||
# Allow use of ARN as well as alias name
|
||||
if alias_name.startswith("arn:") and ":alias/" in alias_name:
|
||||
return "alias/" + alias_name.split(":alias/")[1]
|
||||
|
||||
return alias_name
|
||||
|
||||
def any_id_to_key_id(self, key_id):
|
||||
def any_id_to_key_id(self, key_id: str) -> str:
|
||||
"""Go from any valid key ID to the raw key ID.
|
||||
|
||||
Acceptable inputs:
|
||||
@ -363,66 +382,65 @@ class KmsBackend(BaseBackend):
|
||||
key_id = self.get_alias_name(key_id)
|
||||
key_id = self.get_key_id(key_id)
|
||||
if key_id.startswith("alias/"):
|
||||
key_id = self.get_key_id(self.get_key_id_from_alias(key_id))
|
||||
key_id = self.get_key_id(self.get_key_id_from_alias(key_id)) # type: ignore[arg-type]
|
||||
return key_id
|
||||
|
||||
def alias_exists(self, alias_name):
|
||||
def alias_exists(self, alias_name: str) -> bool:
|
||||
for aliases in self.key_to_aliases.values():
|
||||
if alias_name in aliases:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def add_alias(self, target_key_id, alias_name):
|
||||
def add_alias(self, target_key_id: str, alias_name: str) -> None:
|
||||
raw_key_id = self.get_key_id(target_key_id)
|
||||
self.key_to_aliases[raw_key_id].add(alias_name)
|
||||
|
||||
def delete_alias(self, alias_name):
|
||||
def delete_alias(self, alias_name: str) -> None:
|
||||
"""Delete the alias."""
|
||||
for aliases in self.key_to_aliases.values():
|
||||
if alias_name in aliases:
|
||||
aliases.remove(alias_name)
|
||||
|
||||
def get_all_aliases(self):
|
||||
def get_all_aliases(self) -> Dict[str, Set[str]]:
|
||||
return self.key_to_aliases
|
||||
|
||||
def get_key_id_from_alias(self, alias_name):
|
||||
def get_key_id_from_alias(self, alias_name: str) -> Optional[str]:
|
||||
for key_id, aliases in dict(self.key_to_aliases).items():
|
||||
if alias_name in ",".join(aliases):
|
||||
return key_id
|
||||
if alias_name in RESERVED_ALIASES:
|
||||
key_id = self._generate_default_keys(alias_name)
|
||||
return key_id
|
||||
return self._generate_default_keys(alias_name)
|
||||
return None
|
||||
|
||||
def enable_key_rotation(self, key_id):
|
||||
def enable_key_rotation(self, key_id: str) -> None:
|
||||
self.keys[self.get_key_id(key_id)].key_rotation_status = True
|
||||
|
||||
def disable_key_rotation(self, key_id):
|
||||
def disable_key_rotation(self, key_id: str) -> None:
|
||||
self.keys[self.get_key_id(key_id)].key_rotation_status = False
|
||||
|
||||
def get_key_rotation_status(self, key_id):
|
||||
def get_key_rotation_status(self, key_id: str) -> bool:
|
||||
return self.keys[self.get_key_id(key_id)].key_rotation_status
|
||||
|
||||
def put_key_policy(self, key_id, policy):
|
||||
def put_key_policy(self, key_id: str, policy: str) -> None:
|
||||
self.keys[self.get_key_id(key_id)].policy = policy
|
||||
|
||||
def get_key_policy(self, key_id):
|
||||
def get_key_policy(self, key_id: str) -> str:
|
||||
return self.keys[self.get_key_id(key_id)].policy
|
||||
|
||||
def disable_key(self, key_id):
|
||||
def disable_key(self, key_id: str) -> None:
|
||||
self.keys[key_id].enabled = False
|
||||
self.keys[key_id].key_state = "Disabled"
|
||||
|
||||
def enable_key(self, key_id):
|
||||
def enable_key(self, key_id: str) -> None:
|
||||
self.keys[key_id].enabled = True
|
||||
self.keys[key_id].key_state = "Enabled"
|
||||
|
||||
def cancel_key_deletion(self, key_id):
|
||||
def cancel_key_deletion(self, key_id: str) -> None:
|
||||
self.keys[key_id].key_state = "Disabled"
|
||||
self.keys[key_id].deletion_date = None
|
||||
|
||||
def schedule_key_deletion(self, key_id, pending_window_in_days):
|
||||
def schedule_key_deletion(self, key_id: str, pending_window_in_days: int) -> float: # type: ignore[return]
|
||||
if 7 <= pending_window_in_days <= 30:
|
||||
self.keys[key_id].enabled = False
|
||||
self.keys[key_id].key_state = "PendingDeletion"
|
||||
@ -431,7 +449,9 @@ class KmsBackend(BaseBackend):
|
||||
)
|
||||
return unix_time(self.keys[key_id].deletion_date)
|
||||
|
||||
def encrypt(self, key_id, plaintext, encryption_context):
|
||||
def encrypt(
|
||||
self, key_id: str, plaintext: bytes, encryption_context: Dict[str, str]
|
||||
) -> Tuple[bytes, str]:
|
||||
key_id = self.any_id_to_key_id(key_id)
|
||||
|
||||
ciphertext_blob = encrypt(
|
||||
@ -443,7 +463,9 @@ class KmsBackend(BaseBackend):
|
||||
arn = self.keys[key_id].arn
|
||||
return ciphertext_blob, arn
|
||||
|
||||
def decrypt(self, ciphertext_blob, encryption_context):
|
||||
def decrypt(
|
||||
self, ciphertext_blob: bytes, encryption_context: Dict[str, str]
|
||||
) -> Tuple[bytes, str]:
|
||||
plaintext, key_id = decrypt(
|
||||
master_keys=self.keys,
|
||||
ciphertext_blob=ciphertext_blob,
|
||||
@ -454,11 +476,11 @@ class KmsBackend(BaseBackend):
|
||||
|
||||
def re_encrypt(
|
||||
self,
|
||||
ciphertext_blob,
|
||||
source_encryption_context,
|
||||
destination_key_id,
|
||||
destination_encryption_context,
|
||||
):
|
||||
ciphertext_blob: bytes,
|
||||
source_encryption_context: Dict[str, str],
|
||||
destination_key_id: str,
|
||||
destination_encryption_context: Dict[str, str],
|
||||
) -> Tuple[bytes, str, str]:
|
||||
destination_key_id = self.any_id_to_key_id(destination_key_id)
|
||||
|
||||
plaintext, decrypting_arn = self.decrypt(
|
||||
@ -472,7 +494,13 @@ class KmsBackend(BaseBackend):
|
||||
)
|
||||
return new_ciphertext_blob, decrypting_arn, encrypting_arn
|
||||
|
||||
def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec):
|
||||
def generate_data_key(
|
||||
self,
|
||||
key_id: str,
|
||||
encryption_context: Dict[str, str],
|
||||
number_of_bytes: int,
|
||||
key_spec: str,
|
||||
) -> Tuple[bytes, bytes, str]:
|
||||
key_id = self.any_id_to_key_id(key_id)
|
||||
|
||||
if key_spec:
|
||||
@ -492,7 +520,7 @@ class KmsBackend(BaseBackend):
|
||||
|
||||
return plaintext, ciphertext_blob, arn
|
||||
|
||||
def list_resource_tags(self, key_id_or_arn):
|
||||
def list_resource_tags(self, key_id_or_arn: str) -> Dict[str, List[Dict[str, str]]]:
|
||||
key_id = self.get_key_id(key_id_or_arn)
|
||||
if key_id in self.keys:
|
||||
return self.tagger.list_tags_for_resource(key_id)
|
||||
@ -501,21 +529,21 @@ class KmsBackend(BaseBackend):
|
||||
"The request was rejected because the specified entity or resource could not be found.",
|
||||
)
|
||||
|
||||
def tag_resource(self, key_id_or_arn, tags):
|
||||
def tag_resource(self, key_id_or_arn: str, tags: List[Dict[str, str]]) -> None:
|
||||
key_id = self.get_key_id(key_id_or_arn)
|
||||
if key_id in self.keys:
|
||||
self.tagger.tag_resource(key_id, tags)
|
||||
return {}
|
||||
return
|
||||
raise JsonRESTError(
|
||||
"NotFoundException",
|
||||
"The request was rejected because the specified entity or resource could not be found.",
|
||||
)
|
||||
|
||||
def untag_resource(self, key_id_or_arn, tag_names):
|
||||
def untag_resource(self, key_id_or_arn: str, tag_names: List[str]) -> None:
|
||||
key_id = self.get_key_id(key_id_or_arn)
|
||||
if key_id in self.keys:
|
||||
self.tagger.untag_resource_using_names(key_id, tag_names)
|
||||
return {}
|
||||
return
|
||||
raise JsonRESTError(
|
||||
"NotFoundException",
|
||||
"The request was rejected because the specified entity or resource could not be found.",
|
||||
@ -523,13 +551,13 @@ class KmsBackend(BaseBackend):
|
||||
|
||||
def create_grant(
|
||||
self,
|
||||
key_id,
|
||||
grantee_principal,
|
||||
operations,
|
||||
name,
|
||||
constraints,
|
||||
retiring_principal,
|
||||
):
|
||||
key_id: str,
|
||||
grantee_principal: str,
|
||||
operations: List[str],
|
||||
name: str,
|
||||
constraints: Dict[str, Any],
|
||||
retiring_principal: str,
|
||||
) -> Tuple[str, str]:
|
||||
key = self.describe_key(key_id)
|
||||
grant = key.add_grant(
|
||||
name,
|
||||
@ -540,21 +568,21 @@ class KmsBackend(BaseBackend):
|
||||
)
|
||||
return grant.id, grant.token
|
||||
|
||||
def list_grants(self, key_id, grant_id) -> [Grant]:
|
||||
def list_grants(self, key_id: str, grant_id: str) -> List[Grant]:
|
||||
key = self.describe_key(key_id)
|
||||
return key.list_grants(grant_id)
|
||||
|
||||
def list_retirable_grants(self, retiring_principal):
|
||||
def list_retirable_grants(self, retiring_principal: str) -> List[Grant]:
|
||||
grants = []
|
||||
for key in self.keys.values():
|
||||
grants.extend(key.list_retirable_grants(retiring_principal))
|
||||
return grants
|
||||
|
||||
def revoke_grant(self, key_id, grant_id) -> None:
|
||||
def revoke_grant(self, key_id: str, grant_id: str) -> None:
|
||||
key = self.describe_key(key_id)
|
||||
key.revoke_grant(grant_id)
|
||||
|
||||
def retire_grant(self, key_id, grant_id, grant_token) -> None:
|
||||
def retire_grant(self, key_id: str, grant_id: str, grant_token: str) -> None:
|
||||
if grant_token:
|
||||
for key in self.keys.values():
|
||||
key.retire_grant_by_token(grant_token)
|
||||
@ -562,7 +590,7 @@ class KmsBackend(BaseBackend):
|
||||
key = self.describe_key(key_id)
|
||||
key.retire_grant(grant_id)
|
||||
|
||||
def __ensure_valid_sign_and_verify_key(self, key: Key):
|
||||
def __ensure_valid_sign_and_verify_key(self, key: Key) -> None:
|
||||
if key.key_usage != "SIGN_VERIFY":
|
||||
raise ValidationException(
|
||||
(
|
||||
@ -571,7 +599,9 @@ class KmsBackend(BaseBackend):
|
||||
).format(key_id=key.id)
|
||||
)
|
||||
|
||||
def __ensure_valid_signing_augorithm(self, key: Key, signing_algorithm):
|
||||
def __ensure_valid_signing_augorithm(
|
||||
self, key: Key, signing_algorithm: str
|
||||
) -> None:
|
||||
if signing_algorithm not in key.signing_algorithms:
|
||||
raise ValidationException(
|
||||
(
|
||||
@ -584,7 +614,9 @@ class KmsBackend(BaseBackend):
|
||||
)
|
||||
)
|
||||
|
||||
def sign(self, key_id, message, signing_algorithm):
|
||||
def sign(
|
||||
self, key_id: str, message: bytes, signing_algorithm: str
|
||||
) -> Tuple[str, bytes, str]:
|
||||
"""Sign message using generated private key.
|
||||
|
||||
- signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256
|
||||
@ -607,7 +639,9 @@ class KmsBackend(BaseBackend):
|
||||
|
||||
return key.arn, signature, signing_algorithm
|
||||
|
||||
def verify(self, key_id, message, signature, signing_algorithm):
|
||||
def verify(
|
||||
self, key_id: str, message: bytes, signature: bytes, signing_algorithm: str
|
||||
) -> Tuple[str, bool, str]:
|
||||
"""Verify message using public key from generated private key.
|
||||
|
||||
- signing_algorithm is ignored and hardcoded to RSASSA_PSS_SHA_256
|
||||
|
@ -1,4 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List
|
||||
import json
|
||||
from .models import Key
|
||||
from .exceptions import AccessDeniedException
|
||||
@ -8,7 +9,7 @@ ALTERNATIVE_ACTIONS = defaultdict(list)
|
||||
ALTERNATIVE_ACTIONS["kms:DescribeKey"] = ["kms:*", "kms:Describe*", "kms:DescribeKey"]
|
||||
|
||||
|
||||
def validate_policy(key: Key, action: str):
|
||||
def validate_policy(key: Key, action: str) -> None:
|
||||
"""
|
||||
Relevant docs:
|
||||
- https://docs.aws.amazon.com/kms/latest/developerguide/key-policy-default.html
|
||||
@ -29,20 +30,22 @@ def validate_policy(key: Key, action: str):
|
||||
)
|
||||
|
||||
|
||||
def check_statement(statement, resource, action):
|
||||
def check_statement(statement: Dict[str, Any], resource: str, action: str) -> bool:
|
||||
return action_matches(statement.get("Action", []), action) and resource_matches(
|
||||
statement.get("Resource", ""), resource
|
||||
)
|
||||
|
||||
|
||||
def action_matches(applicable_actions, action):
|
||||
def action_matches(applicable_actions: List[str], action: str) -> bool:
|
||||
alternatives = ALTERNATIVE_ACTIONS[action]
|
||||
if any(alt in applicable_actions for alt in alternatives):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def resource_matches(applicable_resources, resource): # pylint: disable=unused-argument
|
||||
def resource_matches(
|
||||
applicable_resources: str, resource: str # pylint: disable=unused-argument
|
||||
) -> bool:
|
||||
if applicable_resources == "*":
|
||||
return True
|
||||
return False
|
||||
|
@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.kms.utils import RESERVED_ALIASES, RESERVED_ALIASE_TARGET_KEY_IDS
|
||||
@ -17,24 +18,23 @@ from .exceptions import (
|
||||
|
||||
|
||||
class KmsResponse(BaseResponse):
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(service_name="kms")
|
||||
|
||||
@property
|
||||
def parameters(self):
|
||||
def _get_param(self, param_name: str, if_none: Any = None) -> Any: # type: ignore
|
||||
params = json.loads(self.body)
|
||||
|
||||
for key in ("Plaintext", "CiphertextBlob"):
|
||||
if key in params:
|
||||
params[key] = base64.b64decode(params[key].encode("utf-8"))
|
||||
|
||||
return params
|
||||
return params.get(param_name, if_none)
|
||||
|
||||
@property
|
||||
def kms_backend(self) -> KmsBackend:
|
||||
return kms_backends[self.current_account][self.region]
|
||||
|
||||
def _display_arn(self, key_id):
|
||||
def _display_arn(self, key_id: str) -> str:
|
||||
if key_id.startswith("arn:"):
|
||||
return key_id
|
||||
|
||||
@ -45,7 +45,7 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return f"arn:aws:kms:{self.region}:{self.current_account}:{id_type}{key_id}"
|
||||
|
||||
def _validate_cmk_id(self, key_id):
|
||||
def _validate_cmk_id(self, key_id: str) -> None:
|
||||
"""Determine whether a CMK ID exists.
|
||||
|
||||
- raw key ID
|
||||
@ -69,7 +69,7 @@ class KmsResponse(BaseResponse):
|
||||
if cmk_id not in self.kms_backend.keys:
|
||||
raise NotFoundException(f"Key '{self._display_arn(key_id)}' does not exist")
|
||||
|
||||
def _validate_alias(self, key_id):
|
||||
def _validate_alias(self, key_id: str) -> None:
|
||||
"""Determine whether an alias exists.
|
||||
|
||||
- alias name
|
||||
@ -88,8 +88,8 @@ class KmsResponse(BaseResponse):
|
||||
if cmk_id is None:
|
||||
raise error
|
||||
|
||||
def _validate_key_id(self, key_id):
|
||||
"""Determine whether or not a key ID exists.
|
||||
def _validate_key_id(self, key_id: str) -> None:
|
||||
"""Determine whether a key ID exists.
|
||||
|
||||
- raw key ID
|
||||
- key ARN
|
||||
@ -105,77 +105,77 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
def _validate_key_policy(self, key_id, action):
|
||||
def _validate_key_policy(self, key_id: str, action: str) -> None:
|
||||
"""
|
||||
Validate whether the specified action is allowed, given the key policy
|
||||
"""
|
||||
key = self.kms_backend.describe_key(self.kms_backend.get_key_id(key_id))
|
||||
validate_policy(key, action)
|
||||
|
||||
def create_key(self):
|
||||
def create_key(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html"""
|
||||
policy = self.parameters.get("Policy")
|
||||
key_usage = self.parameters.get("KeyUsage")
|
||||
key_spec = self.parameters.get("KeySpec") or self.parameters.get(
|
||||
policy = self._get_param("Policy")
|
||||
key_usage = self._get_param("KeyUsage")
|
||||
key_spec = self._get_param("KeySpec") or self._get_param(
|
||||
"CustomerMasterKeySpec"
|
||||
)
|
||||
description = self.parameters.get("Description")
|
||||
tags = self.parameters.get("Tags")
|
||||
multi_region = self.parameters.get("MultiRegion")
|
||||
description = self._get_param("Description")
|
||||
tags = self._get_param("Tags")
|
||||
multi_region = self._get_param("MultiRegion")
|
||||
|
||||
key = self.kms_backend.create_key(
|
||||
policy, key_usage, key_spec, description, tags, multi_region
|
||||
)
|
||||
return json.dumps(key.to_dict())
|
||||
|
||||
def replicate_key(self):
|
||||
key_id = self.parameters.get("KeyId")
|
||||
def replicate_key(self) -> None:
|
||||
key_id = self._get_param("KeyId")
|
||||
self._validate_key_id(key_id)
|
||||
replica_region = self.parameters.get("ReplicaRegion")
|
||||
replica_region = self._get_param("ReplicaRegion")
|
||||
self.kms_backend.replicate_key(key_id, replica_region)
|
||||
|
||||
def update_key_description(self):
|
||||
def update_key_description(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
description = self.parameters.get("Description")
|
||||
key_id = self._get_param("KeyId")
|
||||
description = self._get_param("Description")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
self.kms_backend.update_key_description(key_id, description)
|
||||
return json.dumps(None)
|
||||
|
||||
def tag_resource(self):
|
||||
def tag_resource(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
tags = self.parameters.get("Tags")
|
||||
key_id = self._get_param("KeyId")
|
||||
tags = self._get_param("Tags")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
result = self.kms_backend.tag_resource(key_id, tags)
|
||||
return json.dumps(result)
|
||||
self.kms_backend.tag_resource(key_id, tags)
|
||||
return "{}"
|
||||
|
||||
def untag_resource(self):
|
||||
def untag_resource(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UntagResource.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
tag_names = self.parameters.get("TagKeys")
|
||||
key_id = self._get_param("KeyId")
|
||||
tag_names = self._get_param("TagKeys")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
result = self.kms_backend.untag_resource(key_id, tag_names)
|
||||
return json.dumps(result)
|
||||
self.kms_backend.untag_resource(key_id, tag_names)
|
||||
return "{}"
|
||||
|
||||
def list_resource_tags(self):
|
||||
def list_resource_tags(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
tags = self.kms_backend.list_resource_tags(key_id)
|
||||
tags: Dict[str, Any] = self.kms_backend.list_resource_tags(key_id)
|
||||
tags.update({"NextMarker": None, "Truncated": False})
|
||||
return json.dumps(tags)
|
||||
|
||||
def describe_key(self):
|
||||
def describe_key(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
|
||||
self._validate_key_id(key_id)
|
||||
self._validate_key_policy(key_id, "kms:DescribeKey")
|
||||
@ -184,7 +184,7 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps(key.to_dict())
|
||||
|
||||
def list_keys(self):
|
||||
def list_keys(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html"""
|
||||
keys = self.kms_backend.list_keys()
|
||||
|
||||
@ -196,17 +196,17 @@ class KmsResponse(BaseResponse):
|
||||
}
|
||||
)
|
||||
|
||||
def create_alias(self):
|
||||
def create_alias(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html"""
|
||||
return self._set_alias()
|
||||
|
||||
def update_alias(self):
|
||||
def update_alias(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateAlias.html"""
|
||||
return self._set_alias(update=True)
|
||||
|
||||
def _set_alias(self, update=False):
|
||||
alias_name = self.parameters["AliasName"]
|
||||
target_key_id = self.parameters["TargetKeyId"]
|
||||
def _set_alias(self, update: bool = False) -> str:
|
||||
alias_name = self._get_param("AliasName")
|
||||
target_key_id = self._get_param("TargetKeyId")
|
||||
|
||||
if not alias_name.startswith("alias/"):
|
||||
raise ValidationException("Invalid identifier")
|
||||
@ -243,9 +243,9 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps(None)
|
||||
|
||||
def delete_alias(self):
|
||||
def delete_alias(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html"""
|
||||
alias_name = self.parameters["AliasName"]
|
||||
alias_name = self._get_param("AliasName")
|
||||
|
||||
if not alias_name.startswith("alias/"):
|
||||
raise ValidationException("Invalid identifier")
|
||||
@ -256,10 +256,10 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps(None)
|
||||
|
||||
def list_aliases(self):
|
||||
def list_aliases(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html"""
|
||||
region = self.region
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
if key_id is not None:
|
||||
self._validate_key_id(key_id)
|
||||
key_id = self.kms_backend.get_key_id(key_id)
|
||||
@ -298,13 +298,13 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"Truncated": False, "Aliases": response_aliases})
|
||||
|
||||
def create_grant(self):
|
||||
key_id = self.parameters.get("KeyId")
|
||||
grantee_principal = self.parameters.get("GranteePrincipal")
|
||||
retiring_principal = self.parameters.get("RetiringPrincipal")
|
||||
operations = self.parameters.get("Operations")
|
||||
name = self.parameters.get("Name")
|
||||
constraints = self.parameters.get("Constraints")
|
||||
def create_grant(self) -> str:
|
||||
key_id = self._get_param("KeyId")
|
||||
grantee_principal = self._get_param("GranteePrincipal")
|
||||
retiring_principal = self._get_param("RetiringPrincipal")
|
||||
operations = self._get_param("Operations")
|
||||
name = self._get_param("Name")
|
||||
constraints = self._get_param("Constraints")
|
||||
|
||||
grant_id, grant_token = self.kms_backend.create_grant(
|
||||
key_id,
|
||||
@ -316,9 +316,9 @@ class KmsResponse(BaseResponse):
|
||||
)
|
||||
return json.dumps({"GrantId": grant_id, "GrantToken": grant_token})
|
||||
|
||||
def list_grants(self):
|
||||
key_id = self.parameters.get("KeyId")
|
||||
grant_id = self.parameters.get("GrantId")
|
||||
def list_grants(self) -> str:
|
||||
key_id = self._get_param("KeyId")
|
||||
grant_id = self._get_param("GrantId")
|
||||
|
||||
grants = self.kms_backend.list_grants(key_id=key_id, grant_id=grant_id)
|
||||
return json.dumps(
|
||||
@ -329,8 +329,8 @@ class KmsResponse(BaseResponse):
|
||||
}
|
||||
)
|
||||
|
||||
def list_retirable_grants(self):
|
||||
retiring_principal = self.parameters.get("RetiringPrincipal")
|
||||
def list_retirable_grants(self) -> str:
|
||||
retiring_principal = self._get_param("RetiringPrincipal")
|
||||
|
||||
grants = self.kms_backend.list_retirable_grants(retiring_principal)
|
||||
return json.dumps(
|
||||
@ -341,24 +341,24 @@ class KmsResponse(BaseResponse):
|
||||
}
|
||||
)
|
||||
|
||||
def revoke_grant(self):
|
||||
key_id = self.parameters.get("KeyId")
|
||||
grant_id = self.parameters.get("GrantId")
|
||||
def revoke_grant(self) -> str:
|
||||
key_id = self._get_param("KeyId")
|
||||
grant_id = self._get_param("GrantId")
|
||||
|
||||
self.kms_backend.revoke_grant(key_id, grant_id)
|
||||
return "{}"
|
||||
|
||||
def retire_grant(self):
|
||||
key_id = self.parameters.get("KeyId")
|
||||
grant_id = self.parameters.get("GrantId")
|
||||
grant_token = self.parameters.get("GrantToken")
|
||||
def retire_grant(self) -> str:
|
||||
key_id = self._get_param("KeyId")
|
||||
grant_id = self._get_param("GrantId")
|
||||
grant_token = self._get_param("GrantToken")
|
||||
|
||||
self.kms_backend.retire_grant(key_id, grant_id, grant_token)
|
||||
return "{}"
|
||||
|
||||
def enable_key_rotation(self):
|
||||
def enable_key_rotation(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
@ -366,9 +366,9 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps(None)
|
||||
|
||||
def disable_key_rotation(self):
|
||||
def disable_key_rotation(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
@ -376,9 +376,9 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps(None)
|
||||
|
||||
def get_key_rotation_status(self):
|
||||
def get_key_rotation_status(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
@ -386,11 +386,11 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"KeyRotationEnabled": rotation_enabled})
|
||||
|
||||
def put_key_policy(self):
|
||||
def put_key_policy(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
policy_name = self.parameters.get("PolicyName")
|
||||
policy = self.parameters.get("Policy")
|
||||
key_id = self._get_param("KeyId")
|
||||
policy_name = self._get_param("PolicyName")
|
||||
policy = self._get_param("Policy")
|
||||
_assert_default_policy(policy_name)
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
@ -399,10 +399,10 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps(None)
|
||||
|
||||
def get_key_policy(self):
|
||||
def get_key_policy(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
policy_name = self.parameters.get("PolicyName")
|
||||
key_id = self._get_param("KeyId")
|
||||
policy_name = self._get_param("PolicyName")
|
||||
_assert_default_policy(policy_name)
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
@ -410,9 +410,9 @@ class KmsResponse(BaseResponse):
|
||||
policy = self.kms_backend.get_key_policy(key_id) or "{}"
|
||||
return json.dumps({"Policy": policy})
|
||||
|
||||
def list_key_policies(self):
|
||||
def list_key_policies(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
@ -420,11 +420,11 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"Truncated": False, "PolicyNames": ["default"]})
|
||||
|
||||
def encrypt(self):
|
||||
def encrypt(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
encryption_context = self.parameters.get("EncryptionContext", {})
|
||||
plaintext = self.parameters.get("Plaintext")
|
||||
key_id = self._get_param("KeyId")
|
||||
encryption_context = self._get_param("EncryptionContext", {})
|
||||
plaintext = self._get_param("Plaintext")
|
||||
|
||||
self._validate_key_id(key_id)
|
||||
|
||||
@ -438,10 +438,10 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn})
|
||||
|
||||
def decrypt(self):
|
||||
def decrypt(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html"""
|
||||
ciphertext_blob = self.parameters.get("CiphertextBlob")
|
||||
encryption_context = self.parameters.get("EncryptionContext", {})
|
||||
ciphertext_blob = self._get_param("CiphertextBlob")
|
||||
encryption_context = self._get_param("EncryptionContext", {})
|
||||
|
||||
plaintext, arn = self.kms_backend.decrypt(
|
||||
ciphertext_blob=ciphertext_blob, encryption_context=encryption_context
|
||||
@ -451,12 +451,12 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"Plaintext": plaintext_response, "KeyId": arn})
|
||||
|
||||
def re_encrypt(self):
|
||||
def re_encrypt(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html"""
|
||||
ciphertext_blob = self.parameters.get("CiphertextBlob")
|
||||
source_encryption_context = self.parameters.get("SourceEncryptionContext", {})
|
||||
destination_key_id = self.parameters.get("DestinationKeyId")
|
||||
destination_encryption_context = self.parameters.get(
|
||||
ciphertext_blob = self._get_param("CiphertextBlob")
|
||||
source_encryption_context = self._get_param("SourceEncryptionContext", {})
|
||||
destination_key_id = self._get_param("DestinationKeyId")
|
||||
destination_encryption_context = self._get_param(
|
||||
"DestinationEncryptionContext", {}
|
||||
)
|
||||
|
||||
@ -483,9 +483,9 @@ class KmsResponse(BaseResponse):
|
||||
}
|
||||
)
|
||||
|
||||
def disable_key(self):
|
||||
def disable_key(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
@ -493,9 +493,9 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps(None)
|
||||
|
||||
def enable_key(self):
|
||||
def enable_key(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
@ -503,9 +503,9 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps(None)
|
||||
|
||||
def cancel_key_deletion(self):
|
||||
def cancel_key_deletion(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
key_id = self._get_param("KeyId")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
@ -513,13 +513,13 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"KeyId": key_id})
|
||||
|
||||
def schedule_key_deletion(self):
|
||||
def schedule_key_deletion(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
if self.parameters.get("PendingWindowInDays") is None:
|
||||
key_id = self._get_param("KeyId")
|
||||
if self._get_param("PendingWindowInDays") is None:
|
||||
pending_window_in_days = 30
|
||||
else:
|
||||
pending_window_in_days = self.parameters.get("PendingWindowInDays")
|
||||
pending_window_in_days = self._get_param("PendingWindowInDays")
|
||||
|
||||
self._validate_cmk_id(key_id)
|
||||
|
||||
@ -532,12 +532,12 @@ class KmsResponse(BaseResponse):
|
||||
}
|
||||
)
|
||||
|
||||
def generate_data_key(self):
|
||||
def generate_data_key(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
encryption_context = self.parameters.get("EncryptionContext", {})
|
||||
number_of_bytes = self.parameters.get("NumberOfBytes")
|
||||
key_spec = self.parameters.get("KeySpec")
|
||||
key_id = self._get_param("KeyId")
|
||||
encryption_context = self._get_param("EncryptionContext", {})
|
||||
number_of_bytes = self._get_param("NumberOfBytes")
|
||||
key_spec = self._get_param("KeySpec")
|
||||
|
||||
# Param validation
|
||||
self._validate_key_id(key_id)
|
||||
@ -587,16 +587,16 @@ class KmsResponse(BaseResponse):
|
||||
}
|
||||
)
|
||||
|
||||
def generate_data_key_without_plaintext(self):
|
||||
def generate_data_key_without_plaintext(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html"""
|
||||
result = json.loads(self.generate_data_key())
|
||||
del result["Plaintext"]
|
||||
|
||||
return json.dumps(result)
|
||||
|
||||
def generate_random(self):
|
||||
def generate_random(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html"""
|
||||
number_of_bytes = self.parameters.get("NumberOfBytes")
|
||||
number_of_bytes = self._get_param("NumberOfBytes")
|
||||
|
||||
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1):
|
||||
raise ValidationException(
|
||||
@ -613,13 +613,13 @@ class KmsResponse(BaseResponse):
|
||||
|
||||
return json.dumps({"Plaintext": response_entropy})
|
||||
|
||||
def sign(self):
|
||||
def sign(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Sign.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
message = self.parameters.get("Message")
|
||||
message_type = self.parameters.get("MessageType")
|
||||
grant_tokens = self.parameters.get("GrantTokens")
|
||||
signing_algorithm = self.parameters.get("SigningAlgorithm")
|
||||
key_id = self._get_param("KeyId")
|
||||
message = self._get_param("Message")
|
||||
message_type = self._get_param("MessageType")
|
||||
grant_tokens = self._get_param("GrantTokens")
|
||||
signing_algorithm = self._get_param("SigningAlgorithm")
|
||||
|
||||
self._validate_key_id(key_id)
|
||||
|
||||
@ -660,14 +660,14 @@ class KmsResponse(BaseResponse):
|
||||
}
|
||||
)
|
||||
|
||||
def verify(self):
|
||||
def verify(self) -> str:
|
||||
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Verify.html"""
|
||||
key_id = self.parameters.get("KeyId")
|
||||
message = self.parameters.get("Message")
|
||||
message_type = self.parameters.get("MessageType")
|
||||
signature = self.parameters.get("Signature")
|
||||
signing_algorithm = self.parameters.get("SigningAlgorithm")
|
||||
grant_tokens = self.parameters.get("GrantTokens")
|
||||
key_id = self._get_param("KeyId")
|
||||
message = self._get_param("Message")
|
||||
message_type = self._get_param("MessageType")
|
||||
signature = self._get_param("Signature")
|
||||
signing_algorithm = self._get_param("SigningAlgorithm")
|
||||
grant_tokens = self._get_param("GrantTokens")
|
||||
|
||||
self._validate_key_id(key_id)
|
||||
|
||||
@ -722,6 +722,6 @@ class KmsResponse(BaseResponse):
|
||||
)
|
||||
|
||||
|
||||
def _assert_default_policy(policy_name):
|
||||
def _assert_default_policy(policy_name: str) -> None:
|
||||
if policy_name != "default":
|
||||
raise NotFoundException("No such policy exists")
|
||||
|
@ -1,4 +1,5 @@
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, Tuple
|
||||
import io
|
||||
import os
|
||||
import struct
|
||||
@ -47,7 +48,7 @@ RESERVED_ALIASE_TARGET_KEY_IDS = {
|
||||
RESERVED_ALIASES = list(RESERVED_ALIASE_TARGET_KEY_IDS.keys())
|
||||
|
||||
|
||||
def generate_key_id(multi_region=False):
|
||||
def generate_key_id(multi_region: bool = False) -> str:
|
||||
key = str(mock_random.uuid4())
|
||||
# https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html
|
||||
# "Notice that multi-Region keys have a distinctive key ID that begins with mrk-. You can use the mrk- prefix to
|
||||
@ -58,17 +59,17 @@ def generate_key_id(multi_region=False):
|
||||
return key
|
||||
|
||||
|
||||
def generate_data_key(number_of_bytes):
|
||||
def generate_data_key(number_of_bytes: int) -> bytes:
|
||||
"""Generate a data key."""
|
||||
return os.urandom(number_of_bytes)
|
||||
|
||||
|
||||
def generate_master_key():
|
||||
def generate_master_key() -> bytes:
|
||||
"""Generate a master key."""
|
||||
return generate_data_key(MASTER_KEY_LEN)
|
||||
|
||||
|
||||
def generate_private_key():
|
||||
def generate_private_key() -> rsa.RSAPrivateKey:
|
||||
"""Generate a private key to be used on asymmetric sign/verify.
|
||||
|
||||
NOTE: KeySpec is not taken into consideration and the key is always RSA_2048
|
||||
@ -80,7 +81,7 @@ def generate_private_key():
|
||||
)
|
||||
|
||||
|
||||
def _serialize_ciphertext_blob(ciphertext):
|
||||
def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes:
|
||||
"""Serialize Ciphertext object into a ciphertext blob.
|
||||
|
||||
NOTE: This is just a simple binary format. It is not what KMS actually does.
|
||||
@ -94,7 +95,7 @@ def _serialize_ciphertext_blob(ciphertext):
|
||||
return header + ciphertext.ciphertext
|
||||
|
||||
|
||||
def _deserialize_ciphertext_blob(ciphertext_blob):
|
||||
def _deserialize_ciphertext_blob(ciphertext_blob: bytes) -> Ciphertext:
|
||||
"""Deserialize ciphertext blob into a Ciphertext object.
|
||||
|
||||
NOTE: This is just a simple binary format. It is not what KMS actually does.
|
||||
@ -107,7 +108,7 @@ def _deserialize_ciphertext_blob(ciphertext_blob):
|
||||
)
|
||||
|
||||
|
||||
def _serialize_encryption_context(encryption_context):
|
||||
def _serialize_encryption_context(encryption_context: Dict[str, str]) -> bytes:
|
||||
"""Serialize encryption context for use a AAD.
|
||||
|
||||
NOTE: This is not necessarily what KMS does, but it retains the same properties.
|
||||
@ -119,7 +120,12 @@ def _serialize_encryption_context(encryption_context):
|
||||
return aad.getvalue()
|
||||
|
||||
|
||||
def encrypt(master_keys, key_id, plaintext, encryption_context):
|
||||
def encrypt(
|
||||
master_keys: Dict[str, Any],
|
||||
key_id: str,
|
||||
plaintext: bytes,
|
||||
encryption_context: Dict[str, str],
|
||||
) -> bytes:
|
||||
"""Encrypt data using a master key material.
|
||||
|
||||
NOTE: This is not necessarily what KMS does, but it retains the same properties.
|
||||
@ -159,7 +165,11 @@ def encrypt(master_keys, key_id, plaintext, encryption_context):
|
||||
)
|
||||
|
||||
|
||||
def decrypt(master_keys, ciphertext_blob, encryption_context):
|
||||
def decrypt(
|
||||
master_keys: Dict[str, Any],
|
||||
ciphertext_blob: bytes,
|
||||
encryption_context: Dict[str, str],
|
||||
) -> Tuple[bytes, str]:
|
||||
"""Decrypt a ciphertext blob using a master key material.
|
||||
|
||||
NOTE: This is not necessarily what KMS does, but it retains the same properties.
|
||||
|
@ -9,7 +9,7 @@ SESSION_TOKEN_PREFIX = "FQoGZXIvYXdzEBYaD"
|
||||
DEFAULT_STS_SESSION_DURATION = 3600
|
||||
|
||||
|
||||
def random_session_token():
|
||||
def random_session_token() -> str:
|
||||
return (
|
||||
SESSION_TOKEN_PREFIX
|
||||
+ base64.b64encode(os.urandom(266))[len(SESSION_TOKEN_PREFIX) :].decode()
|
||||
|
@ -235,7 +235,7 @@ disable = W,C,R,E
|
||||
enable = anomalous-backslash-in-string, arguments-renamed, dangerous-default-value, deprecated-module, function-redefined, import-self, redefined-builtin, redefined-outer-name, reimported, pointless-statement, super-with-arguments, unused-argument, unused-import, unused-variable, useless-else-on-loop, wildcard-import
|
||||
|
||||
[mypy]
|
||||
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/moto_api,moto/neptune
|
||||
files= moto/a*,moto/b*,moto/c*,moto/d*,moto/e*,moto/f*,moto/g*,moto/i*,moto/k*,moto/moto_api,moto/neptune
|
||||
show_column_numbers=True
|
||||
show_error_codes = True
|
||||
disable_error_code=abstract
|
||||
|
Loading…
Reference in New Issue
Block a user