
from collections import defaultdict
import copy
import datetime
from typing import Any, Dict, Optional, List, Tuple, Iterator, Sequence
from moto.core import BaseModel, CloudFormationModel
from moto.core.utils import unix_time, unix_time_millis
from moto.dynamodb.comparisons import get_filter_expression, get_expected
from moto.dynamodb.exceptions import (
InvalidIndexNameError,
HashKeyTooLong,
RangeKeyTooLong,
ConditionalCheckFailed,
InvalidAttributeTypeError,
MockValidationException,
InvalidConversion,
SerializationException,
)
from moto.dynamodb.models.utilities import dynamo_json_dump
from moto.dynamodb.models.dynamo_type import DynamoType, Item
from moto.dynamodb.limits import HASH_KEY_MAX_LENGTH, RANGE_KEY_MAX_LENGTH
from moto.moto_api._internal import mock_random


class SecondaryIndex(BaseModel):
def __init__(
self,
index_name: str,
schema: List[Dict[str, str]],
projection: Dict[str, Any],
table_key_attrs: List[str],
):
self.name = index_name
self.schema = schema
self.table_key_attrs = table_key_attrs
self.projection = projection
self.schema_key_attrs = [k["AttributeName"] for k in schema]
def project(self, item: Item) -> Item:
"""
Enforces the ProjectionType of this Index (LSI/GSI)
Removes any non-wanted attributes from the item
:param item:
:return:
"""
if self.projection:
projection_type = self.projection.get("ProjectionType", None)
key_attributes = self.table_key_attrs + [
key["AttributeName"] for key in self.schema
]
if projection_type == "KEYS_ONLY":
item.filter(",".join(key_attributes))
elif projection_type == "INCLUDE":
allowed_attributes = key_attributes + self.projection.get(
"NonKeyAttributes", []
)
item.filter(",".join(allowed_attributes))
# ALL is handled implicitly by not filtering
return item
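
    # Illustrative sketch (not executed here): projecting through a KEYS_ONLY
    # index keeps only the table and index key attributes. The index and
    # attribute names below are made up for the example:
    #
    #   idx = LocalSecondaryIndex.create(
    #       {
    #           "IndexName": "by-sort",
    #           "KeySchema": [{"AttributeName": "sort", "KeyType": "RANGE"}],
    #           "Projection": {"ProjectionType": "KEYS_ONLY"},
    #       },
    #       table_key_attrs=["id"],
    #   )
    #   idx.project(item)  # strips every attribute except "id" and "sort"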


class LocalSecondaryIndex(SecondaryIndex):
def describe(self) -> Dict[str, Any]:
return {
"IndexName": self.name,
"KeySchema": self.schema,
"Projection": self.projection,
}
@staticmethod
def create(dct: Dict[str, Any], table_key_attrs: List[str]) -> "LocalSecondaryIndex": # type: ignore[misc]
return LocalSecondaryIndex(
index_name=dct["IndexName"],
schema=dct["KeySchema"],
projection=dct["Projection"],
table_key_attrs=table_key_attrs,
)


class GlobalSecondaryIndex(SecondaryIndex):
def __init__(
self,
index_name: str,
schema: List[Dict[str, str]],
projection: Dict[str, Any],
table_key_attrs: List[str],
status: str = "ACTIVE",
throughput: Optional[Dict[str, Any]] = None,
):
super().__init__(index_name, schema, projection, table_key_attrs)
self.status = status
self.throughput = throughput or {
"ReadCapacityUnits": 0,
"WriteCapacityUnits": 0,
}
def describe(self) -> Dict[str, Any]:
return {
"IndexName": self.name,
"KeySchema": self.schema,
"Projection": self.projection,
"IndexStatus": self.status,
"ProvisionedThroughput": self.throughput,
}
@staticmethod
def create(dct: Dict[str, Any], table_key_attrs: List[str]) -> "GlobalSecondaryIndex": # type: ignore[misc]
return GlobalSecondaryIndex(
index_name=dct["IndexName"],
schema=dct["KeySchema"],
projection=dct["Projection"],
table_key_attrs=table_key_attrs,
throughput=dct.get("ProvisionedThroughput", None),
)
def update(self, u: Dict[str, Any]) -> None:
self.name = u.get("IndexName", self.name)
self.schema = u.get("KeySchema", self.schema)
self.projection = u.get("Projection", self.projection)
self.throughput = u.get("ProvisionedThroughput", self.throughput)
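
    # Sketch of the payload shape `update` consumes; it mirrors the "Update"
    # entry of GlobalSecondaryIndexUpdates in the UpdateTable API. Values are
    # illustrative:
    #
    #   gsi.update({"ProvisionedThroughput": {"ReadCapacityUnits": 5,
    #                                         "WriteCapacityUnits": 5}})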


class StreamRecord(BaseModel):
def __init__(
self,
table: "Table",
stream_type: str,
event_name: str,
old: Optional[Item],
new: Optional[Item],
seq: int,
):
old_a = old.to_json()["Attributes"] if old is not None else {}
new_a = new.to_json()["Attributes"] if new is not None else {}
rec = old if old is not None else new
keys = {table.hash_key_attr: rec.hash_key.to_json()} # type: ignore[union-attr]
if table.range_key_attr is not None and rec is not None:
keys[table.range_key_attr] = rec.range_key.to_json() # type: ignore
self.record: Dict[str, Any] = {
"eventID": mock_random.uuid4().hex,
"eventName": event_name,
"eventSource": "aws:dynamodb",
"eventVersion": "1.0",
"awsRegion": "us-east-1",
"dynamodb": {
"StreamViewType": stream_type,
"ApproximateCreationDateTime": datetime.datetime.utcnow().isoformat(),
"SequenceNumber": str(seq),
"SizeBytes": 1,
"Keys": keys,
},
}
if stream_type in ("NEW_IMAGE", "NEW_AND_OLD_IMAGES"):
self.record["dynamodb"]["NewImage"] = new_a
if stream_type in ("OLD_IMAGE", "NEW_AND_OLD_IMAGES"):
self.record["dynamodb"]["OldImage"] = old_a
# This is a substantial overestimate but it's the easiest to do now
self.record["dynamodb"]["SizeBytes"] = len(
dynamo_json_dump(self.record["dynamodb"])
)
def to_json(self) -> Dict[str, Any]:
return self.record


class StreamShard(BaseModel):
def __init__(self, account_id: str, table: "Table"):
self.account_id = account_id
self.table = table
self.id = "shardId-00000001541626099285-f35f62ef"
self.starting_sequence_number = 1100000000017454423009
self.items: List[StreamRecord] = []
self.created_on = datetime.datetime.utcnow()
def to_json(self) -> Dict[str, Any]:
return {
"ShardId": self.id,
"SequenceNumberRange": {
"StartingSequenceNumber": str(self.starting_sequence_number)
},
}
def add(self, old: Optional[Item], new: Optional[Item]) -> None:
t = self.table.stream_specification["StreamViewType"] # type: ignore
if old is None:
event_name = "INSERT"
elif new is None:
event_name = "REMOVE"
else:
event_name = "MODIFY"
seq = len(self.items) + self.starting_sequence_number
self.items.append(StreamRecord(self.table, t, event_name, old, new, seq))
result = None
from moto.awslambda import lambda_backends
for arn, esm in self.table.lambda_event_source_mappings.items():
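            # Extract the region, the fourth component of the Lambda ARN:
            # arn:aws:lambda:<region>:<account>:function:<name>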
region = arn[
len("arn:aws:lambda:") : arn.index(":", len("arn:aws:lambda:"))
]
result = lambda_backends[self.account_id][region].send_dynamodb_items(
arn, self.items, esm.event_source_arn
)
if result:
self.items = []
def get(self, start: int, quantity: int) -> List[Dict[str, Any]]:
start -= self.starting_sequence_number
assert start >= 0
end = start + quantity
return [i.to_json() for i in self.items[start:end]]
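
    # Usage sketch (illustrative): sequence numbers are absolute, so callers
    # pass the shard's starting sequence number plus an offset:
    #
    #   shard.get(shard.starting_sequence_number, 10)  # first 10 records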


class Table(CloudFormationModel):
def __init__(
self,
table_name: str,
account_id: str,
region: str,
schema: List[Dict[str, Any]],
attr: List[Dict[str, str]],
throughput: Optional[Dict[str, int]] = None,
billing_mode: Optional[str] = None,
indexes: Optional[List[Dict[str, Any]]] = None,
global_indexes: Optional[List[Dict[str, Any]]] = None,
streams: Optional[Dict[str, Any]] = None,
sse_specification: Optional[Dict[str, Any]] = None,
tags: Optional[List[Dict[str, str]]] = None,
):
self.name = table_name
self.account_id = account_id
self.region_name = region
self.attr = attr
self.schema = schema
self.range_key_attr: Optional[str] = None
self.hash_key_attr: str = ""
self.range_key_type: Optional[str] = None
self.hash_key_type: str = ""
for elem in schema:
attr_type = [
a["AttributeType"]
for a in attr
if a["AttributeName"] == elem["AttributeName"]
][0]
if elem["KeyType"] == "HASH":
self.hash_key_attr = elem["AttributeName"]
self.hash_key_type = attr_type
else:
self.range_key_attr = elem["AttributeName"]
self.range_key_type = attr_type
self.table_key_attrs = [
key for key in (self.hash_key_attr, self.range_key_attr) if key is not None
]
self.billing_mode = billing_mode
if throughput is None:
self.throughput = {"WriteCapacityUnits": 0, "ReadCapacityUnits": 0}
else:
self.throughput = throughput
self.throughput["NumberOfDecreasesToday"] = 0
self.indexes = [
LocalSecondaryIndex.create(i, self.table_key_attrs)
for i in (indexes if indexes else [])
]
self.global_indexes = [
GlobalSecondaryIndex.create(i, self.table_key_attrs)
for i in (global_indexes if global_indexes else [])
]
self.created_at = datetime.datetime.utcnow()
self.items = defaultdict(dict) # type: ignore # [hash: DynamoType] or [hash: [range: DynamoType]]
self.table_arn = self._generate_arn(table_name)
self.tags = tags or []
self.ttl = {
"TimeToLiveStatus": "DISABLED" # One of 'ENABLING'|'DISABLING'|'ENABLED'|'DISABLED',
# 'AttributeName': 'string' # Can contain this
}
self.stream_specification: Optional[Dict[str, Any]] = {"StreamEnabled": False}
self.latest_stream_label: Optional[str] = None
self.stream_shard: Optional[StreamShard] = None
self.set_stream_specification(streams)
self.lambda_event_source_mappings: Dict[str, Any] = {}
self.continuous_backups: Dict[str, Any] = {
"ContinuousBackupsStatus": "ENABLED", # One of 'ENABLED'|'DISABLED', it's enabled by default
"PointInTimeRecoveryDescription": {
"PointInTimeRecoveryStatus": "DISABLED" # One of 'ENABLED'|'DISABLED'
},
}
self.sse_specification = sse_specification
if self.sse_specification and "KMSMasterKeyId" not in self.sse_specification:
self.sse_specification["KMSMasterKeyId"] = self._get_default_encryption_key(
account_id, region
)
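
    # Construction sketch (illustrative names, mirroring the CreateTable API
    # shapes that `schema` and `attr` expect):
    #
    #   table = Table(
    #       "my-table",
    #       account_id="123456789012",
    #       region="us-east-1",
    #       schema=[
    #           {"AttributeName": "id", "KeyType": "HASH"},
    #           {"AttributeName": "sort", "KeyType": "RANGE"},
    #       ],
    #       attr=[
    #           {"AttributeName": "id", "AttributeType": "S"},
    #           {"AttributeName": "sort", "AttributeType": "N"},
    #       ],
    #   )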
def _get_default_encryption_key(self, account_id: str, region: str) -> str:
from moto.kms import kms_backends
# https://aws.amazon.com/kms/features/#AWS_Service_Integration
# An AWS managed CMK is created automatically when you first create
# an encrypted resource using an AWS service integrated with KMS.
kms = kms_backends[account_id][region]
ddb_alias = "alias/aws/dynamodb"
if not kms.alias_exists(ddb_alias):
key = kms.create_key(
policy="",
key_usage="ENCRYPT_DECRYPT",
key_spec="SYMMETRIC_DEFAULT",
description="Default master key that protects my DynamoDB table storage",
tags=None,
)
kms.add_alias(key.id, ddb_alias)
        ddb_key = kms.describe_key(ddb_alias)
        return ddb_key.arn
@classmethod
def has_cfn_attr(cls, attr: str) -> bool:
return attr in ["Arn", "StreamArn"]
def get_cfn_attribute(self, attribute_name: str) -> Any: # type: ignore[misc]
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn":
return self.table_arn
elif attribute_name == "StreamArn" and self.stream_specification:
return self.describe()["TableDescription"]["LatestStreamArn"]
raise UnformattedGetAttTemplateException()
@property
def physical_resource_id(self) -> str:
return self.name
@property
def attribute_keys(self) -> List[str]:
        # All hash and range key attributes of the table and all of its indexes
def keys_from_index(idx: SecondaryIndex) -> List[str]:
schema = idx.schema
return [attr["AttributeName"] for attr in schema]
fieldnames = copy.copy(self.table_key_attrs)
for idx in self.indexes + self.global_indexes:
fieldnames += keys_from_index(idx)
return fieldnames
@staticmethod
def cloudformation_name_type() -> str:
return "TableName"
@staticmethod
def cloudformation_type() -> str:
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html
return "AWS::DynamoDB::Table"
@classmethod
def create_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
**kwargs: Any,
) -> "Table":
from moto.dynamodb.models import dynamodb_backends
properties = cloudformation_json["Properties"]
params = {}
if "KeySchema" in properties:
params["schema"] = properties["KeySchema"]
if "AttributeDefinitions" in properties:
params["attr"] = properties["AttributeDefinitions"]
if "GlobalSecondaryIndexes" in properties:
params["global_indexes"] = properties["GlobalSecondaryIndexes"]
if "ProvisionedThroughput" in properties:
params["throughput"] = properties["ProvisionedThroughput"]
if "LocalSecondaryIndexes" in properties:
params["indexes"] = properties["LocalSecondaryIndexes"]
if "StreamSpecification" in properties:
params["streams"] = properties["StreamSpecification"]
table = dynamodb_backends[account_id][region_name].create_table(
name=resource_name, **params
)
return table
@classmethod
def delete_from_cloudformation_json( # type: ignore[misc]
cls,
resource_name: str,
cloudformation_json: Dict[str, Any],
account_id: str,
region_name: str,
) -> None:
from moto.dynamodb.models import dynamodb_backends
dynamodb_backends[account_id][region_name].delete_table(name=resource_name)
def _generate_arn(self, name: str) -> str:
return f"arn:aws:dynamodb:{self.region_name}:{self.account_id}:table/{name}"
def set_stream_specification(self, streams: Optional[Dict[str, Any]]) -> None:
self.stream_specification = streams
if (
self.stream_specification
and streams
and (streams.get("StreamEnabled") or streams.get("StreamViewType"))
):
self.stream_specification["StreamEnabled"] = True
self.latest_stream_label = datetime.datetime.utcnow().isoformat()
self.stream_shard = StreamShard(self.account_id, self)
else:
self.stream_specification = {"StreamEnabled": False}
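
    # Enabling a stream creates a single shard and stamps a new stream label
    # (illustrative call):
    #
    #   table.set_stream_specification(
    #       {"StreamEnabled": True, "StreamViewType": "NEW_AND_OLD_IMAGES"}
    #   )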
def describe(self, base_key: str = "TableDescription") -> Dict[str, Any]:
results: Dict[str, Any] = {
base_key: {
"AttributeDefinitions": self.attr,
"ProvisionedThroughput": self.throughput,
"BillingModeSummary": {"BillingMode": self.billing_mode},
"TableSizeBytes": 0,
"TableName": self.name,
"TableStatus": "ACTIVE",
"TableArn": self.table_arn,
"KeySchema": self.schema,
"ItemCount": len(self),
"CreationDateTime": unix_time(self.created_at),
"GlobalSecondaryIndexes": [
index.describe() for index in self.global_indexes
],
"LocalSecondaryIndexes": [index.describe() for index in self.indexes],
}
}
if self.latest_stream_label:
results[base_key]["LatestStreamLabel"] = self.latest_stream_label
results[base_key][
"LatestStreamArn"
] = f"{self.table_arn}/stream/{self.latest_stream_label}"
if self.stream_specification and self.stream_specification["StreamEnabled"]:
results[base_key]["StreamSpecification"] = self.stream_specification
if self.sse_specification and self.sse_specification.get("Enabled") is True:
results[base_key]["SSEDescription"] = {
"Status": "ENABLED",
"SSEType": "KMS",
"KMSMasterKeyArn": self.sse_specification.get("KMSMasterKeyId"),
}
return results
def __len__(self) -> int:
return sum(
[(len(value) if self.has_range_key else 1) for value in self.items.values()]
)
@property
def hash_key_names(self) -> List[str]:
keys = [self.hash_key_attr]
for index in self.global_indexes:
for key in index.schema:
if key["KeyType"] == "HASH":
keys.append(key["AttributeName"])
return keys
@property
def range_key_names(self) -> List[str]:
keys = [self.range_key_attr]
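        # Note: the first entry is None when the table has no range key.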
for index in self.global_indexes:
for key in index.schema:
if key["KeyType"] == "RANGE":
keys.append(key["AttributeName"])
return keys # type: ignore[return-value]
def _validate_key_sizes(self, item_attrs: Dict[str, Any]) -> None:
for hash_name in self.hash_key_names:
hash_value = item_attrs.get(hash_name)
if hash_value:
if DynamoType(hash_value).size() > HASH_KEY_MAX_LENGTH:
raise HashKeyTooLong
for range_name in self.range_key_names:
range_value = item_attrs.get(range_name)
if range_value:
if DynamoType(range_value).size() > RANGE_KEY_MAX_LENGTH:
raise RangeKeyTooLong
def _validate_item_types(self, item_attrs: Dict[str, Any]) -> None:
for key, value in item_attrs.items():
if type(value) == dict:
self._validate_item_types(value)
elif type(value) == int and key == "N":
raise InvalidConversion
if key == "S":
                # This scenario is usually caught by boto3, but the user can disable
                # parameter validation, which is why we need to catch it 'server-side' as well
if type(value) == int:
raise SerializationException(
"NUMBER_VALUE cannot be converted to String"
)
if type(value) == dict:
raise SerializationException(
"Start of structure or map found where not expected"
)
def put_item(
self,
item_attrs: Dict[str, Any],
expected: Optional[Dict[str, Any]] = None,
condition_expression: Optional[str] = None,
expression_attribute_names: Optional[Dict[str, str]] = None,
expression_attribute_values: Optional[Dict[str, Any]] = None,
overwrite: bool = False,
) -> Item:
        if self.hash_key_attr not in item_attrs:
            raise MockValidationException(
                f"One or more parameter values were invalid: Missing the key {self.hash_key_attr} in the item"
            )
hash_value = DynamoType(item_attrs[self.hash_key_attr])
if self.range_key_attr is not None:
            if self.range_key_attr not in item_attrs:
raise MockValidationException(
f"One or more parameter values were invalid: Missing the key {self.range_key_attr} in the item"
)
range_value = DynamoType(item_attrs[self.range_key_attr])
else:
range_value = None
if hash_value.type != self.hash_key_type:
raise InvalidAttributeTypeError(
self.hash_key_attr,
expected_type=self.hash_key_type,
actual_type=hash_value.type,
)
if range_value and range_value.type != self.range_key_type:
raise InvalidAttributeTypeError(
self.range_key_attr,
expected_type=self.range_key_type,
actual_type=range_value.type,
)
self._validate_item_types(item_attrs)
self._validate_key_sizes(item_attrs)
if expected is None:
expected = {}
lookup_range_value = range_value
else:
expected_range_value = expected.get(self.range_key_attr, {}).get("Value") # type: ignore
if expected_range_value is None:
lookup_range_value = range_value
else:
lookup_range_value = DynamoType(expected_range_value)
current = self.get_item(hash_value, lookup_range_value)
item = Item(hash_value, range_value, item_attrs)
if not overwrite:
if not get_expected(expected).expr(current):
raise ConditionalCheckFailed
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values,
)
if not condition_op.expr(current):
raise ConditionalCheckFailed
if range_value:
self.items[hash_value][range_value] = item
else:
self.items[hash_value] = item # type: ignore[assignment]
if self.stream_shard is not None:
self.stream_shard.add(current, item)
return item
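
    # Usage sketch (illustrative attribute values): a conditional put that only
    # succeeds when the item does not exist yet:
    #
    #   table.put_item(
    #       {"id": {"S": "foo"}, "sort": {"N": "1"}},
    #       condition_expression="attribute_not_exists(id)",
    #   )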
def __nonzero__(self) -> bool:
return True
def __bool__(self) -> bool:
return self.__nonzero__()
@property
def has_range_key(self) -> bool:
return self.range_key_attr is not None
def get_item(
self,
hash_key: DynamoType,
range_key: Optional[DynamoType] = None,
projection_expression: Optional[str] = None,
) -> Optional[Item]:
if self.has_range_key and not range_key:
raise MockValidationException(
"Table has a range key, but no range key was passed into get_item"
)
try:
result = None
if range_key:
result = self.items[hash_key][range_key]
elif hash_key in self.items:
result = self.items[hash_key]
if projection_expression and result:
result = copy.deepcopy(result)
result.filter(projection_expression)
if not result:
raise KeyError
return result
except KeyError:
return None
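
    # Usage sketch: keys are DynamoType instances, not raw attribute dicts
    # (illustrative values):
    #
    #   table.get_item(DynamoType({"S": "foo"}), DynamoType({"N": "1"}),
    #                  projection_expression="id,sort")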
def delete_item(
self, hash_key: DynamoType, range_key: Optional[DynamoType]
) -> Optional[Item]:
try:
if range_key:
item = self.items[hash_key].pop(range_key)
else:
item = self.items.pop(hash_key)
if self.stream_shard is not None:
self.stream_shard.add(item, None)
return item
except KeyError:
return None
def query(
self,
hash_key: DynamoType,
range_comparison: Optional[str],
range_objs: List[DynamoType],
limit: int,
exclusive_start_key: Dict[str, Any],
scan_index_forward: bool,
projection_expression: str,
index_name: Optional[str] = None,
filter_expression: Any = None,
**filter_kwargs: Any,
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
results = []
if index_name:
all_indexes = self.all_indexes()
            indexes_by_name = {i.name: i for i in all_indexes}
if index_name not in indexes_by_name:
all_names = ", ".join(indexes_by_name.keys())
raise MockValidationException(
f"Invalid index: {index_name} for table: {self.name}. Available indexes are: {all_names}"
)
index = indexes_by_name[index_name]
try:
index_hash_key = [
key for key in index.schema if key["KeyType"] == "HASH"
][0]
except IndexError:
raise MockValidationException(
f"Missing Hash Key. KeySchema: {index.name}"
)
try:
index_range_key = [
key for key in index.schema if key["KeyType"] == "RANGE"
][0]
except IndexError:
index_range_key = None
possible_results = []
for item in self.all_items():
if not isinstance(item, Item):
continue
item_hash_key = item.attrs.get(index_hash_key["AttributeName"])
if index_range_key is None:
if item_hash_key and item_hash_key == hash_key:
possible_results.append(item)
else:
item_range_key = item.attrs.get(index_range_key["AttributeName"])
if item_hash_key and item_hash_key == hash_key and item_range_key:
possible_results.append(item)
else:
possible_results = [
item
for item in list(self.all_items())
if isinstance(item, Item) and item.hash_key == hash_key
]
if range_comparison:
            if index_name and not index_range_key:
                raise ValueError(
                    f"Range Key comparison but no range key found for index: {index_name}"
                )
elif index_name:
for result in possible_results:
if result.attrs.get(index_range_key["AttributeName"]).compare( # type: ignore
range_comparison, range_objs
):
results.append(result)
else:
for result in possible_results:
if result.range_key.compare(range_comparison, range_objs): # type: ignore[union-attr]
results.append(result)
if filter_kwargs:
for result in possible_results:
for field, value in filter_kwargs.items():
dynamo_types = [
DynamoType(ele) for ele in value["AttributeValueList"]
]
if result.attrs.get(field).compare( # type: ignore[union-attr]
value["ComparisonOperator"], dynamo_types
):
results.append(result)
if not range_comparison and not filter_kwargs:
            # If we're not filtering on the range key or on an index,
            # return all values
results = possible_results
if index_name:
if index_range_key:
# Convert to float if necessary to ensure proper ordering
def conv(x: DynamoType) -> Any:
return float(x.value) if x.type == "N" else x.value
results.sort(
key=lambda item: conv(item.attrs[index_range_key["AttributeName"]]) # type: ignore
if item.attrs.get(index_range_key["AttributeName"]) # type: ignore
else None
)
else:
results.sort(key=lambda item: item.range_key) # type: ignore
if scan_index_forward is False:
results.reverse()
scanned_count = len(list(self.all_items()))
results = copy.deepcopy(results)
if index_name:
index = self.get_index(index_name)
for result in results:
index.project(result)
results, last_evaluated_key = self._trim_results(
results, limit, exclusive_start_key, scanned_index=index_name
)
if filter_expression is not None:
results = [item for item in results if filter_expression.expr(item)]
if projection_expression:
for result in results:
result.filter(projection_expression)
return results, scanned_count, last_evaluated_key
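
    # Query sketch (illustrative): items with the given hash key and a range
    # key below 5, in descending range-key order:
    #
    #   items, scanned, last_key = table.query(
    #       DynamoType({"S": "foo"}),
    #       range_comparison="LT",
    #       range_objs=[DynamoType({"N": "5"})],
    #       limit=10,
    #       exclusive_start_key=None,
    #       scan_index_forward=False,
    #       projection_expression=None,
    #   )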
def all_items(self) -> Iterator[Item]:
for hash_set in self.items.values():
if self.range_key_attr:
for item in hash_set.values():
yield item
else:
yield hash_set # type: ignore
def all_indexes(self) -> Sequence[SecondaryIndex]:
return (self.global_indexes or []) + (self.indexes or []) # type: ignore
def get_index(self, index_name: str, error_if_not: bool = False) -> SecondaryIndex:
all_indexes = self.all_indexes()
        indexes_by_name = {i.name: i for i in all_indexes}
if error_if_not and index_name not in indexes_by_name:
raise InvalidIndexNameError(
f"The table does not have the specified index: {index_name}"
)
return indexes_by_name[index_name]
def has_idx_items(self, index_name: str) -> Iterator[Item]:
idx = self.get_index(index_name)
        idx_col_set = {i["AttributeName"] for i in idx.schema}
for hash_set in self.items.values():
if self.range_key_attr:
for item in hash_set.values():
if idx_col_set.issubset(set(item.attrs)):
yield item
else:
if idx_col_set.issubset(set(hash_set.attrs)): # type: ignore
yield hash_set # type: ignore
def scan(
self,
filters: Dict[str, Any],
limit: int,
exclusive_start_key: Dict[str, Any],
filter_expression: Any = None,
index_name: Optional[str] = None,
projection_expression: Optional[str] = None,
) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]:
results = []
scanned_count = 0
if index_name:
self.get_index(index_name, error_if_not=True)
items = self.has_idx_items(index_name)
else:
items = self.all_items()
for item in items:
scanned_count += 1
passes_all_conditions = True
for (
attribute_name,
(comparison_operator, comparison_objs),
) in filters.items():
attribute = item.attrs.get(attribute_name)
if attribute:
# Attribute found
if not attribute.compare(comparison_operator, comparison_objs):
passes_all_conditions = False
break
elif comparison_operator == "NULL":
# Comparison is NULL and we don't have the attribute
continue
else:
                    # No attribute found and the comparison is not NULL;
                    # this item fails.
passes_all_conditions = False
break
if passes_all_conditions:
results.append(item)
results, last_evaluated_key = self._trim_results(
results, limit, exclusive_start_key, scanned_index=index_name
)
if filter_expression is not None:
results = [item for item in results if filter_expression.expr(item)]
if projection_expression:
results = copy.deepcopy(results)
for result in results:
result.filter(projection_expression)
return results, scanned_count, last_evaluated_key
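
    # Scan sketch (illustrative): `filters` maps attribute names to a
    # (ComparisonOperator, AttributeValueList) pair:
    #
    #   items, scanned, last_key = table.scan(
    #       filters={"price": ("GT", [DynamoType({"N": "100"})])},
    #       limit=25,
    #       exclusive_start_key=None,
    #   )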
def _trim_results(
self,
results: List[Item],
limit: int,
exclusive_start_key: Optional[Dict[str, Any]],
scanned_index: Optional[str] = None,
) -> Tuple[List[Item], Optional[Dict[str, Any]]]:
if exclusive_start_key is not None:
hash_key = DynamoType(exclusive_start_key.get(self.hash_key_attr)) # type: ignore[arg-type]
range_key = (
exclusive_start_key.get(self.range_key_attr)
if self.range_key_attr
else None
)
if range_key is not None:
range_key = DynamoType(range_key)
for i in range(len(results)):
if (
results[i].hash_key == hash_key
and results[i].range_key == range_key
):
results = results[i + 1 :]
break
last_evaluated_key = None
size_limit = 1000000 # DynamoDB has a 1MB size limit
item_size = sum(res.size() for res in results)
if item_size > size_limit:
item_size = idx = 0
while item_size + results[idx].size() < size_limit:
item_size += results[idx].size()
idx += 1
limit = min(limit, idx) if limit else idx
if limit and len(results) > limit:
results = results[:limit]
last_evaluated_key = {self.hash_key_attr: results[-1].hash_key}
if self.range_key_attr is not None and results[-1].range_key is not None:
last_evaluated_key[self.range_key_attr] = results[-1].range_key
if scanned_index:
index = self.get_index(scanned_index)
idx_col_list = [i["AttributeName"] for i in index.schema]
for col in idx_col_list:
last_evaluated_key[col] = results[-1].attrs[col]
return results, last_evaluated_key
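
    # Pagination sketch: callers resume a trimmed result set by feeding the
    # returned last_evaluated_key back in as the next exclusive_start_key
    # (illustrative):
    #
    #   page, _, last_key = table.scan({}, limit=25, exclusive_start_key=None)
    #   while last_key:
    #       page, _, last_key = table.scan({}, limit=25,
    #                                      exclusive_start_key=last_key)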
def delete(self, account_id: str, region_name: str) -> None:
from moto.dynamodb.models import dynamodb_backends
dynamodb_backends[account_id][region_name].delete_table(self.name)


class Backup:
def __init__(
self,
account_id: str,
region_name: str,
name: str,
table: Table,
status: Optional[str] = None,
type_: Optional[str] = None,
):
self.region_name = region_name
self.account_id = account_id
self.name = name
self.table = copy.deepcopy(table)
self.status = status or "AVAILABLE"
self.type = type_ or "USER"
self.creation_date_time = datetime.datetime.utcnow()
self.identifier = self._make_identifier()
def _make_identifier(self) -> str:
timestamp = int(unix_time_millis(self.creation_date_time))
timestamp_padded = str("0" + str(timestamp))[-16:16]
guid = str(mock_random.uuid4())
guid_shortened = guid[:8]
return f"{timestamp_padded}-{guid_shortened}"
@property
def arn(self) -> str:
return f"arn:aws:dynamodb:{self.region_name}:{self.account_id}:table/{self.table.name}/backup/{self.identifier}"
@property
def details(self) -> Dict[str, Any]: # type: ignore[misc]
return {
"BackupArn": self.arn,
"BackupName": self.name,
"BackupSizeBytes": 123,
"BackupStatus": self.status,
"BackupType": self.type,
"BackupCreationDateTime": unix_time(self.creation_date_time),
}
@property
def summary(self) -> Dict[str, Any]: # type: ignore[misc]
return {
"TableName": self.table.name,
# 'TableId': 'string',
"TableArn": self.table.table_arn,
"BackupArn": self.arn,
"BackupName": self.name,
"BackupCreationDateTime": unix_time(self.creation_date_time),
# 'BackupExpiryDateTime': datetime(2015, 1, 1),
"BackupStatus": self.status,
"BackupType": self.type,
"BackupSizeBytes": 123,
}
@property
def description(self) -> Dict[str, Any]: # type: ignore[misc]
source_table_details = self.table.describe()["TableDescription"]
source_table_details["TableCreationDateTime"] = source_table_details[
"CreationDateTime"
]
description = {
"BackupDetails": self.details,
"SourceTableDetails": source_table_details,
}
return description


class RestoredTable(Table):
def __init__(self, name: str, account_id: str, region: str, backup: "Backup"):
params = self._parse_params_from_backup(backup)
super().__init__(name, account_id=account_id, region=region, **params)
self.indexes = copy.deepcopy(backup.table.indexes)
self.global_indexes = copy.deepcopy(backup.table.global_indexes)
self.items = copy.deepcopy(backup.table.items)
# Restore Attrs
self.source_backup_arn = backup.arn
self.source_table_arn = backup.table.table_arn
self.restore_date_time = self.created_at
def _parse_params_from_backup(self, backup: "Backup") -> Dict[str, Any]:
return {
"schema": copy.deepcopy(backup.table.schema),
"attr": copy.deepcopy(backup.table.attr),
"throughput": copy.deepcopy(backup.table.throughput),
}
def describe(self, base_key: str = "TableDescription") -> Dict[str, Any]:
result = super().describe(base_key=base_key)
result[base_key]["RestoreSummary"] = {
"SourceBackupArn": self.source_backup_arn,
"SourceTableArn": self.source_table_arn,
"RestoreDateTime": unix_time(self.restore_date_time),
"RestoreInProgress": False,
}
return result


class RestoredPITTable(Table):
def __init__(self, name: str, account_id: str, region: str, source: Table):
params = self._parse_params_from_table(source)
super().__init__(name, account_id=account_id, region=region, **params)
self.indexes = copy.deepcopy(source.indexes)
self.global_indexes = copy.deepcopy(source.global_indexes)
self.items = copy.deepcopy(source.items)
# Restore Attrs
self.source_table_arn = source.table_arn
self.restore_date_time = self.created_at
def _parse_params_from_table(self, table: Table) -> Dict[str, Any]:
return {
"schema": copy.deepcopy(table.schema),
"attr": copy.deepcopy(table.attr),
"throughput": copy.deepcopy(table.throughput),
}
def describe(self, base_key: str = "TableDescription") -> Dict[str, Any]:
result = super().describe(base_key=base_key)
result[base_key]["RestoreSummary"] = {
"SourceTableArn": self.source_table_arn,
"RestoreDateTime": unix_time(self.restore_date_time),
"RestoreInProgress": False,
}
return result