# moto/moto/dynamodb/models/__init__.py
from collections import defaultdict
import copy
import datetime
import decimal
import json
import re
import uuid
from collections import OrderedDict
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.utils import unix_time, unix_time_millis, BackendDict
from moto.core.exceptions import JsonRESTError
from moto.dynamodb.comparisons import get_filter_expression
from moto.dynamodb.comparisons import get_expected
from moto.dynamodb.exceptions import (
InvalidIndexNameError,
ItemSizeTooLarge,
ItemSizeToUpdateTooLarge,
HashKeyTooLong,
RangeKeyTooLong,
ConditionalCheckFailed,
TransactionCanceledException,
EmptyKeyAttributeException,
InvalidAttributeTypeError,
MultipleTransactionsException,
TooManyTransactionsException,
TableNotFoundException,
ResourceNotFoundException,
SourceTableNotFoundException,
TableAlreadyExistsException,
BackupNotFoundException,
ResourceInUseException,
StreamAlreadyEnabledException,
MockValidationException,
InvalidConversion,
)
from moto.dynamodb.models.utilities import bytesize
from moto.dynamodb.models.dynamo_type import DynamoType
from moto.dynamodb.parsing.executors import UpdateExpressionExecutor
from moto.dynamodb.parsing.expressions import UpdateExpressionParser
from moto.dynamodb.parsing.validators import UpdateExpressionValidator
from moto.dynamodb.limits import HASH_KEY_MAX_LENGTH, RANGE_KEY_MAX_LENGTH
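# This module contains the in-memory models behind moto's DynamoDB mock:
# Item (backed by DynamoType values) for record data, Table (together with its
# secondary indexes, streams and backups) for table state, and DynamoDBBackend
# as the per-account, per-region backend that ties them together.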
class DynamoJsonEncoder(json.JSONEncoder):
def default(self, o):
if hasattr(o, "to_json"):
return o.to_json()
def dynamo_json_dump(dynamo_object):
return json.dumps(dynamo_object, cls=DynamoJsonEncoder)
# https://github.com/spulec/moto/issues/1874
# Ensure that the total size of an item does not exceed 400kb
class LimitedSizeDict(dict):
def __init__(self, *args, **kwargs):
self.update(*args, **kwargs)
def __setitem__(self, key, value):
current_item_size = sum(
[
item.size() if type(item) == DynamoType else bytesize(str(item))
for item in (list(self.keys()) + list(self.values()))
]
)
new_item_size = bytesize(key) + (
value.size() if type(value) == DynamoType else bytesize(str(value))
)
# Official limit is set to 400000 (400KB)
# Manual testing confirms that the actual limit is between 409 and 410KB
# We'll set the limit to something in between to be safe
if (current_item_size + new_item_size) > 405000:
raise ItemSizeTooLarge
super().__setitem__(key, value)
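# Illustrative use of the size check above (attribute names and sizes are
# hypothetical):
#
#     d = LimitedSizeDict()
#     d["blob"] = DynamoType({"S": "x" * 500_000})  # > 405000 bytes -> ItemSizeTooLarge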
class Item(BaseModel):
def __init__(self, hash_key, range_key, attrs):
self.hash_key = hash_key
self.range_key = range_key
self.attrs = LimitedSizeDict()
for key, value in attrs.items():
self.attrs[key] = DynamoType(value)
def __eq__(self, other):
return all(
[
self.hash_key == other.hash_key,
self.range_key == other.range_key,
self.attrs == other.attrs,
]
)
def __repr__(self):
return "Item: {0}".format(self.to_json())
def size(self):
return sum(bytesize(key) + value.size() for key, value in self.attrs.items())
def to_json(self):
attributes = {}
for attribute_key, attribute in self.attrs.items():
attributes[attribute_key] = {attribute.type: attribute.value}
return {"Attributes": attributes}
def describe_attrs(self, attributes):
if attributes:
included = {}
for key, value in self.attrs.items():
if key in attributes:
included[key] = value
else:
included = self.attrs
return {"Item": included}
def validate_no_empty_key_values(self, attribute_updates, key_attributes):
for attribute_name, update_action in attribute_updates.items():
action = update_action.get("Action") or "PUT" # PUT is default
if action == "DELETE":
continue
new_value = next(iter(update_action["Value"].values()))
if action == "PUT" and new_value == "" and attribute_name in key_attributes:
raise EmptyKeyAttributeException
def update_with_attribute_updates(self, attribute_updates):
for attribute_name, update_action in attribute_updates.items():
# Use default Action value, if no explicit Action is passed.
# Default value is 'Put', according to
# Boto3 DynamoDB.Client.update_item documentation.
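            # Illustrative shape of attribute_updates (legacy UpdateItem API);
            # the attribute names and values here are hypothetical:
            #   {"Spam": {"Action": "PUT", "Value": {"S": "eggs"}},
            #    "Count": {"Action": "ADD", "Value": {"N": "1"}}}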
action = update_action.get("Action", "PUT")
if action == "DELETE" and "Value" not in update_action:
if attribute_name in self.attrs:
del self.attrs[attribute_name]
continue
new_value = list(update_action["Value"].values())[0]
if action == "PUT":
# TODO deal with other types
if set(update_action["Value"].keys()) == set(["SS"]):
self.attrs[attribute_name] = DynamoType({"SS": new_value})
elif isinstance(new_value, list):
self.attrs[attribute_name] = DynamoType({"L": new_value})
elif isinstance(new_value, dict):
self.attrs[attribute_name] = DynamoType({"M": new_value})
elif set(update_action["Value"].keys()) == set(["N"]):
self.attrs[attribute_name] = DynamoType({"N": new_value})
elif set(update_action["Value"].keys()) == set(["NULL"]):
if attribute_name in self.attrs:
del self.attrs[attribute_name]
else:
self.attrs[attribute_name] = DynamoType({"S": new_value})
elif action == "ADD":
if set(update_action["Value"].keys()) == set(["N"]):
existing = self.attrs.get(attribute_name, DynamoType({"N": "0"}))
self.attrs[attribute_name] = DynamoType(
{
"N": str(
decimal.Decimal(existing.value)
+ decimal.Decimal(new_value)
)
}
)
elif set(update_action["Value"].keys()) == set(["SS"]):
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
new_set = set(existing.value).union(set(new_value))
self.attrs[attribute_name] = DynamoType({"SS": list(new_set)})
elif set(update_action["Value"].keys()) == {"L"}:
existing = self.attrs.get(attribute_name, DynamoType({"L": []}))
new_list = existing.value + new_value
self.attrs[attribute_name] = DynamoType({"L": new_list})
else:
# TODO: implement other data types
raise NotImplementedError(
"ADD not supported for %s"
% ", ".join(update_action["Value"].keys())
)
elif action == "DELETE":
if set(update_action["Value"].keys()) == set(["SS"]):
existing = self.attrs.get(attribute_name, DynamoType({"SS": {}}))
new_set = set(existing.value).difference(set(new_value))
self.attrs[attribute_name] = DynamoType({"SS": list(new_set)})
else:
raise NotImplementedError(
                        "DELETE not supported for %s"
% ", ".join(update_action["Value"].keys())
)
else:
raise NotImplementedError(
                "%s action not supported for update_with_attribute_updates" % action
)
# Filter using projection_expression
# Ensure a deep copy is used to filter, otherwise actual data will be removed
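    # e.g. filter("Id, Data.Nested") keeps the top-level attribute "Id" and, within
    # the "Data" map, only the nested key "Nested" (attribute names are illustrative).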
def filter(self, projection_expression):
expressions = [x.strip() for x in projection_expression.split(",")]
top_level_expressions = [
expr[0 : expr.index(".")] for expr in expressions if "." in expr
]
for attr in list(self.attrs):
if attr not in expressions and attr not in top_level_expressions:
self.attrs.pop(attr)
if attr in top_level_expressions:
relevant_expressions = [
expr[len(attr + ".") :]
for expr in expressions
if expr.startswith(attr + ".")
]
self.attrs[attr].filter(relevant_expressions)
class StreamRecord(BaseModel):
def __init__(self, table, stream_type, event_name, old, new, seq):
old_a = old.to_json()["Attributes"] if old is not None else {}
new_a = new.to_json()["Attributes"] if new is not None else {}
rec = old if old is not None else new
keys = {table.hash_key_attr: rec.hash_key.to_json()}
if table.range_key_attr is not None:
keys[table.range_key_attr] = rec.range_key.to_json()
self.record = {
"eventID": uuid.uuid4().hex,
"eventName": event_name,
"eventSource": "aws:dynamodb",
"eventVersion": "1.0",
"awsRegion": "us-east-1",
"dynamodb": {
"StreamViewType": stream_type,
"ApproximateCreationDateTime": datetime.datetime.utcnow().isoformat(),
"SequenceNumber": str(seq),
"SizeBytes": 1,
"Keys": keys,
},
}
if stream_type in ("NEW_IMAGE", "NEW_AND_OLD_IMAGES"):
self.record["dynamodb"]["NewImage"] = new_a
if stream_type in ("OLD_IMAGE", "NEW_AND_OLD_IMAGES"):
self.record["dynamodb"]["OldImage"] = old_a
# This is a substantial overestimate but it's the easiest to do now
self.record["dynamodb"]["SizeBytes"] = len(
dynamo_json_dump(self.record["dynamodb"])
)
def to_json(self):
return self.record
class StreamShard(BaseModel):
def __init__(self, account_id, table):
self.account_id = account_id
self.table = table
self.id = "shardId-00000001541626099285-f35f62ef"
self.starting_sequence_number = 1100000000017454423009
self.items = []
self.created_on = datetime.datetime.utcnow()
def to_json(self):
return {
"ShardId": self.id,
"SequenceNumberRange": {
"StartingSequenceNumber": str(self.starting_sequence_number)
},
}
def add(self, old, new):
t = self.table.stream_specification["StreamViewType"]
if old is None:
event_name = "INSERT"
elif new is None:
event_name = "REMOVE"
else:
event_name = "MODIFY"
seq = len(self.items) + self.starting_sequence_number
self.items.append(StreamRecord(self.table, t, event_name, old, new, seq))
result = None
from moto.awslambda import lambda_backends
for arn, esm in self.table.lambda_event_source_mappings.items():
region = arn[
len("arn:aws:lambda:") : arn.index(":", len("arn:aws:lambda:"))
]
result = lambda_backends[self.account_id][region].send_dynamodb_items(
arn, self.items, esm.event_source_arn
)
if result:
self.items = []
def get(self, start, quantity):
start -= self.starting_sequence_number
assert start >= 0
end = start + quantity
return [i.to_json() for i in self.items[start:end]]
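# Stream plumbing: Table.put_item/delete_item call StreamShard.add(old, new),
# which classifies the change as INSERT (old is None), REMOVE (new is None) or
# MODIFY, wraps it in a StreamRecord, and forwards the accumulated records to
# any registered Lambda event source mappings.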
class SecondaryIndex(BaseModel):
def project(self, item):
"""
Enforces the ProjectionType of this Index (LSI/GSI)
Removes any non-wanted attributes from the item
:param item:
:return:
"""
if self.projection:
projection_type = self.projection.get("ProjectionType", None)
key_attributes = self.table_key_attrs + [
key["AttributeName"] for key in self.schema
]
if projection_type == "KEYS_ONLY":
item.filter(",".join(key_attributes))
elif projection_type == "INCLUDE":
allowed_attributes = key_attributes + self.projection.get(
"NonKeyAttributes", []
)
item.filter(",".join(allowed_attributes))
# ALL is handled implicitly by not filtering
return item
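# Projection example (illustrative): with ProjectionType "INCLUDE" and
# NonKeyAttributes ["Colour"], project() strips every attribute except the
# table/index key attributes and "Colour"; "ALL" leaves the item untouched.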
class LocalSecondaryIndex(SecondaryIndex):
def __init__(self, index_name, schema, projection, table_key_attrs):
self.name = index_name
self.schema = schema
self.projection = projection
self.table_key_attrs = table_key_attrs
def describe(self):
return {
"IndexName": self.name,
"KeySchema": self.schema,
"Projection": self.projection,
}
@staticmethod
def create(dct, table_key_attrs):
return LocalSecondaryIndex(
index_name=dct["IndexName"],
schema=dct["KeySchema"],
projection=dct["Projection"],
table_key_attrs=table_key_attrs,
)
class GlobalSecondaryIndex(SecondaryIndex):
def __init__(
self,
index_name,
schema,
projection,
table_key_attrs,
status="ACTIVE",
throughput=None,
):
self.name = index_name
self.schema = schema
self.projection = projection
self.table_key_attrs = table_key_attrs
self.status = status
self.throughput = throughput or {
"ReadCapacityUnits": 0,
"WriteCapacityUnits": 0,
}
def describe(self):
return {
"IndexName": self.name,
"KeySchema": self.schema,
"Projection": self.projection,
"IndexStatus": self.status,
"ProvisionedThroughput": self.throughput,
}
@staticmethod
def create(dct, table_key_attrs):
return GlobalSecondaryIndex(
index_name=dct["IndexName"],
schema=dct["KeySchema"],
projection=dct["Projection"],
table_key_attrs=table_key_attrs,
throughput=dct.get("ProvisionedThroughput", None),
)
def update(self, u):
self.name = u.get("IndexName", self.name)
self.schema = u.get("KeySchema", self.schema)
self.projection = u.get("Projection", self.projection)
self.throughput = u.get("ProvisionedThroughput", self.throughput)
class Table(CloudFormationModel):
def __init__(
self,
table_name,
account_id,
region,
schema=None,
attr=None,
throughput=None,
billing_mode=None,
indexes=None,
global_indexes=None,
streams=None,
sse_specification=None,
tags=None,
):
self.name = table_name
self.account_id = account_id
self.attr = attr
self.schema = schema
self.range_key_attr = None
self.hash_key_attr = None
self.range_key_type = None
self.hash_key_type = None
for elem in schema:
attr_type = [
a["AttributeType"]
for a in attr
if a["AttributeName"] == elem["AttributeName"]
][0]
if elem["KeyType"] == "HASH":
self.hash_key_attr = elem["AttributeName"]
self.hash_key_type = attr_type
else:
self.range_key_attr = elem["AttributeName"]
self.range_key_type = attr_type
self.table_key_attrs = [
key for key in (self.hash_key_attr, self.range_key_attr) if key
]
self.billing_mode = billing_mode
if throughput is None:
self.throughput = {"WriteCapacityUnits": 0, "ReadCapacityUnits": 0}
else:
self.throughput = throughput
self.throughput["NumberOfDecreasesToday"] = 0
self.indexes = [
LocalSecondaryIndex.create(i, self.table_key_attrs)
for i in (indexes if indexes else [])
]
self.global_indexes = [
GlobalSecondaryIndex.create(i, self.table_key_attrs)
for i in (global_indexes if global_indexes else [])
]
self.created_at = datetime.datetime.utcnow()
self.items = defaultdict(dict)
self.table_arn = self._generate_arn(table_name)
self.tags = tags or []
self.ttl = {
"TimeToLiveStatus": "DISABLED" # One of 'ENABLING'|'DISABLING'|'ENABLED'|'DISABLED',
# 'AttributeName': 'string' # Can contain this
}
self.stream_specification = {"StreamEnabled": False}
self.latest_stream_label = None
self.stream_shard = None
self.set_stream_specification(streams)
self.lambda_event_source_mappings = {}
self.continuous_backups = {
"ContinuousBackupsStatus": "ENABLED", # One of 'ENABLED'|'DISABLED', it's enabled by default
"PointInTimeRecoveryDescription": {
"PointInTimeRecoveryStatus": "DISABLED" # One of 'ENABLED'|'DISABLED'
},
}
self.sse_specification = sse_specification
if sse_specification and "KMSMasterKeyId" not in self.sse_specification:
self.sse_specification["KMSMasterKeyId"] = self._get_default_encryption_key(
account_id, region
)
def _get_default_encryption_key(self, account_id, region):
from moto.kms import kms_backends
# https://aws.amazon.com/kms/features/#AWS_Service_Integration
# An AWS managed CMK is created automatically when you first create
# an encrypted resource using an AWS service integrated with KMS.
kms = kms_backends[account_id][region]
ddb_alias = "alias/aws/dynamodb"
if not kms.alias_exists(ddb_alias):
key = kms.create_key(
policy="",
key_usage="ENCRYPT_DECRYPT",
key_spec="SYMMETRIC_DEFAULT",
description="Default master key that protects my DynamoDB table storage",
tags=None,
)
kms.add_alias(key.id, ddb_alias)
ebs_key = kms.describe_key(ddb_alias)
return ebs_key.arn
@classmethod
def has_cfn_attr(cls, attr):
return attr in ["Arn", "StreamArn"]
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn":
return self.table_arn
elif attribute_name == "StreamArn" and self.stream_specification:
return self.describe()["TableDescription"]["LatestStreamArn"]
raise UnformattedGetAttTemplateException()
@property
def physical_resource_id(self):
return self.name
@property
def attribute_keys(self):
# A set of all the hash or range attributes for all indexes
def keys_from_index(idx):
schema = idx.schema
return [attr["AttributeName"] for attr in schema]
fieldnames = copy.copy(self.table_key_attrs)
for idx in self.indexes + self.global_indexes:
fieldnames += keys_from_index(idx)
return fieldnames
@staticmethod
def cloudformation_name_type():
return "TableName"
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html
return "AWS::DynamoDB::Table"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name, **kwargs
):
properties = cloudformation_json["Properties"]
params = {}
if "KeySchema" in properties:
params["schema"] = properties["KeySchema"]
if "AttributeDefinitions" in properties:
params["attr"] = properties["AttributeDefinitions"]
if "GlobalSecondaryIndexes" in properties:
params["global_indexes"] = properties["GlobalSecondaryIndexes"]
if "ProvisionedThroughput" in properties:
params["throughput"] = properties["ProvisionedThroughput"]
if "LocalSecondaryIndexes" in properties:
params["indexes"] = properties["LocalSecondaryIndexes"]
if "StreamSpecification" in properties:
params["streams"] = properties["StreamSpecification"]
table = dynamodb_backends[account_id][region_name].create_table(
name=resource_name, **params
2019-10-31 08:44:26 -07:00
)
return table
@classmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, account_id, region_name
):
table = dynamodb_backends[account_id][region_name].delete_table(
name=resource_name
)
return table
def _generate_arn(self, name):
return f"arn:aws:dynamodb:us-east-1:{self.account_id}:table/{name}"
def set_stream_specification(self, streams):
self.stream_specification = streams
if streams and (streams.get("StreamEnabled") or streams.get("StreamViewType")):
self.stream_specification["StreamEnabled"] = True
self.latest_stream_label = datetime.datetime.utcnow().isoformat()
self.stream_shard = StreamShard(self.account_id, self)
else:
self.stream_specification = {"StreamEnabled": False}
def describe(self, base_key="TableDescription"):
results = {
base_key: {
"AttributeDefinitions": self.attr,
"ProvisionedThroughput": self.throughput,
"BillingModeSummary": {"BillingMode": self.billing_mode},
"TableSizeBytes": 0,
"TableName": self.name,
"TableStatus": "ACTIVE",
"TableArn": self.table_arn,
"KeySchema": self.schema,
"ItemCount": len(self),
"CreationDateTime": unix_time(self.created_at),
"GlobalSecondaryIndexes": [
index.describe() for index in self.global_indexes
],
"LocalSecondaryIndexes": [index.describe() for index in self.indexes],
}
}
if self.latest_stream_label:
results[base_key]["LatestStreamLabel"] = self.latest_stream_label
results[base_key][
"LatestStreamArn"
] = f"{self.table_arn}/stream/{self.latest_stream_label}"
if self.stream_specification and self.stream_specification["StreamEnabled"]:
results[base_key]["StreamSpecification"] = self.stream_specification
if self.sse_specification and self.sse_specification.get("Enabled") is True:
results[base_key]["SSEDescription"] = {
"Status": "ENABLED",
"SSEType": "KMS",
"KMSMasterKeyArn": self.sse_specification.get("KMSMasterKeyId"),
}
return results
def __len__(self):
return sum(
[(len(value) if self.has_range_key else 1) for value in self.items.values()]
)
@property
def hash_key_names(self):
keys = [self.hash_key_attr]
for index in self.global_indexes:
hash_key = None
for key in index.schema:
if key["KeyType"] == "HASH":
hash_key = key["AttributeName"]
keys.append(hash_key)
return keys
@property
def range_key_names(self):
keys = [self.range_key_attr]
for index in self.global_indexes:
range_key = None
for key in index.schema:
if key["KeyType"] == "RANGE":
                    range_key = key["AttributeName"]
keys.append(range_key)
return keys
def _validate_key_sizes(self, item_attrs):
for hash_name in self.hash_key_names:
hash_value = item_attrs.get(hash_name)
if hash_value:
if DynamoType(hash_value).size() > HASH_KEY_MAX_LENGTH:
raise HashKeyTooLong
for range_name in self.range_key_names:
range_value = item_attrs.get(range_name)
if range_value:
if DynamoType(range_value).size() > RANGE_KEY_MAX_LENGTH:
raise RangeKeyTooLong
def _validate_item_types(self, item_attrs):
for key, value in item_attrs.items():
if type(value) == dict:
self._validate_item_types(value)
elif type(value) == int and key == "N":
raise InvalidConversion
def put_item(
self,
item_attrs,
expected=None,
condition_expression=None,
expression_attribute_names=None,
expression_attribute_values=None,
overwrite=False,
):
if self.hash_key_attr not in item_attrs.keys():
raise MockValidationException(
"One or more parameter values were invalid: Missing the key "
+ self.hash_key_attr
+ " in the item"
)
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
if self.has_range_key:
if self.range_key_attr not in item_attrs.keys():
raise MockValidationException(
"One or more parameter values were invalid: Missing the key "
+ self.range_key_attr
+ " in the item"
)
range_value = DynamoType(item_attrs.get(self.range_key_attr))
else:
range_value = None
if hash_value.type != self.hash_key_type:
raise InvalidAttributeTypeError(
self.hash_key_attr,
expected_type=self.hash_key_type,
actual_type=hash_value.type,
)
if range_value and range_value.type != self.range_key_type:
raise InvalidAttributeTypeError(
self.range_key_attr,
expected_type=self.range_key_type,
actual_type=range_value.type,
)
self._validate_key_sizes(item_attrs)
self._validate_item_types(item_attrs)
if expected is None:
expected = {}
lookup_range_value = range_value
else:
expected_range_value = expected.get(self.range_key_attr, {}).get("Value")
if expected_range_value is None:
lookup_range_value = range_value
else:
lookup_range_value = DynamoType(expected_range_value)
current = self.get_item(hash_value, lookup_range_value)
item = Item(hash_value, range_value, item_attrs)
2013-12-05 13:16:56 +02:00
if not overwrite:
if not get_expected(expected).expr(current):
raise ConditionalCheckFailed
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values,
)
if not condition_op.expr(current):
raise ConditionalCheckFailed
if range_value:
self.items[hash_value][range_value] = item
else:
self.items[hash_value] = item
if self.stream_shard is not None:
self.stream_shard.add(current, item)
return item
def __nonzero__(self):
return True
def __bool__(self):
return self.__nonzero__()
@property
def has_range_key(self):
return self.range_key_attr is not None
def get_item(self, hash_key, range_key=None, projection_expression=None):
if self.has_range_key and not range_key:
raise MockValidationException(
"Table has a range key, but no range key was passed into get_item"
)
try:
result = None
if range_key:
result = self.items[hash_key][range_key]
elif hash_key in self.items:
result = self.items[hash_key]
if projection_expression and result:
result = copy.deepcopy(result)
result.filter(projection_expression)
if not result:
raise KeyError
return result
except KeyError:
return None
def delete_item(self, hash_key, range_key):
try:
if range_key:
item = self.items[hash_key].pop(range_key)
else:
item = self.items.pop(hash_key)
if self.stream_shard is not None:
self.stream_shard.add(item, None)
return item
except KeyError:
return None
def query(
self,
hash_key,
range_comparison,
range_objs,
limit,
exclusive_start_key,
scan_index_forward,
projection_expression,
index_name=None,
filter_expression=None,
**filter_kwargs,
):
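        # Rough flow: collect candidate items (via the index key schema when an
        # index_name is given, otherwise by table hash key), apply the range-key
        # comparison or legacy filter_kwargs, sort by (index) range key, trim to
        # the page limit, then apply FilterExpression and ProjectionExpression.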
results = []
if index_name:
all_indexes = self.all_indexes()
indexes_by_name = dict((i.name, i) for i in all_indexes)
if index_name not in indexes_by_name:
raise MockValidationException(
"Invalid index: %s for table: %s. Available indexes are: %s"
% (index_name, self.name, ", ".join(indexes_by_name.keys()))
)
index = indexes_by_name[index_name]
try:
index_hash_key = [
key for key in index.schema if key["KeyType"] == "HASH"
][0]
except IndexError:
raise MockValidationException(
"Missing Hash Key. KeySchema: %s" % index.name
)
try:
index_range_key = [
key for key in index.schema if key["KeyType"] == "RANGE"
][0]
except IndexError:
index_range_key = None
possible_results = []
for item in self.all_items():
if not isinstance(item, Item):
continue
item_hash_key = item.attrs.get(index_hash_key["AttributeName"])
if index_range_key is None:
if item_hash_key and item_hash_key == hash_key:
possible_results.append(item)
else:
item_range_key = item.attrs.get(index_range_key["AttributeName"])
if item_hash_key and item_hash_key == hash_key and item_range_key:
possible_results.append(item)
else:
possible_results = [
item
for item in list(self.all_items())
if isinstance(item, Item) and item.hash_key == hash_key
]
if range_comparison:
if index_name and not index_range_key:
raise ValueError(
"Range Key comparison but no range key found for index: %s"
% index_name
)
elif index_name:
for result in possible_results:
if result.attrs.get(index_range_key["AttributeName"]).compare(
range_comparison, range_objs
):
results.append(result)
else:
for result in possible_results:
if result.range_key.compare(range_comparison, range_objs):
results.append(result)
if filter_kwargs:
for result in possible_results:
for field, value in filter_kwargs.items():
dynamo_types = [
DynamoType(ele) for ele in value["AttributeValueList"]
]
if result.attrs.get(field).compare(
value["ComparisonOperator"], dynamo_types
):
results.append(result)
if not range_comparison and not filter_kwargs:
# If we're not filtering on range key or on an index return all
# values
results = possible_results
if index_name:
if index_range_key:
# Convert to float if necessary to ensure proper ordering
def conv(x):
return float(x.value) if x.type == "N" else x.value
results.sort(
key=lambda item: conv(item.attrs[index_range_key["AttributeName"]])
if item.attrs.get(index_range_key["AttributeName"])
else None
)
else:
results.sort(key=lambda item: item.range_key)
if scan_index_forward is False:
results.reverse()
scanned_count = len(list(self.all_items()))
results = copy.deepcopy(results)
if index_name:
index = self.get_index(index_name)
for result in results:
index.project(result)
results, last_evaluated_key = self._trim_results(
results, limit, exclusive_start_key, scanned_index=index_name
)
if filter_expression is not None:
results = [item for item in results if filter_expression.expr(item)]
if projection_expression:
for result in results:
result.filter(projection_expression)
return results, scanned_count, last_evaluated_key
def all_items(self):
for hash_set in self.items.values():
if self.range_key_attr:
for item in hash_set.values():
yield item
else:
yield hash_set
def all_indexes(self):
return (self.global_indexes or []) + (self.indexes or [])
def get_index(self, index_name, error_if_not=False):
all_indexes = self.all_indexes()
indexes_by_name = dict((i.name, i) for i in all_indexes)
if error_if_not and index_name not in indexes_by_name:
raise InvalidIndexNameError(
"The table does not have the specified index: %s" % index_name
)
return indexes_by_name[index_name]
def has_idx_items(self, index_name):
idx = self.get_index(index_name)
idx_col_set = set([i["AttributeName"] for i in idx.schema])
for hash_set in self.items.values():
if self.range_key_attr:
for item in hash_set.values():
if idx_col_set.issubset(set(item.attrs)):
yield item
else:
if idx_col_set.issubset(set(hash_set.attrs)):
yield hash_set
def scan(
self,
filters,
limit,
exclusive_start_key,
filter_expression=None,
index_name=None,
projection_expression=None,
):
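        # Scan walks every item (or only items materialised in the given index),
        # applies the legacy ComparisonOperator filters, then trims the page and
        # applies FilterExpression / ProjectionExpression, mirroring query().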
results = []
scanned_count = 0
if index_name:
self.get_index(index_name, error_if_not=True)
items = self.has_idx_items(index_name)
else:
items = self.all_items()
for item in items:
2013-12-05 13:16:56 +02:00
scanned_count += 1
passes_all_conditions = True
for (
attribute_name,
(comparison_operator, comparison_objs),
) in filters.items():
attribute = item.attrs.get(attribute_name)
if attribute:
# Attribute found
if not attribute.compare(comparison_operator, comparison_objs):
passes_all_conditions = False
break
elif comparison_operator == "NULL":
# Comparison is NULL and we don't have the attribute
continue
else:
                    # No attribute found and comparison is not NULL. This item
# fails
passes_all_conditions = False
break
if passes_all_conditions:
results.append(item)
results, last_evaluated_key = self._trim_results(
results, limit, exclusive_start_key, scanned_index=index_name
)
if filter_expression is not None:
results = [item for item in results if filter_expression.expr(item)]
if projection_expression:
results = copy.deepcopy(results)
for result in results:
result.filter(projection_expression)
return results, scanned_count, last_evaluated_key
def _trim_results(self, results, limit, exclusive_start_key, scanned_index=None):
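        # Resume after ExclusiveStartKey, cap the page at roughly 1MB and at
        # `limit` items, and report a LastEvaluatedKey (including index key
        # attributes when paging an index) if the page was cut short.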
if exclusive_start_key is not None:
hash_key = DynamoType(exclusive_start_key.get(self.hash_key_attr))
range_key = exclusive_start_key.get(self.range_key_attr)
if range_key is not None:
range_key = DynamoType(range_key)
for i in range(len(results)):
if (
results[i].hash_key == hash_key
and results[i].range_key == range_key
):
results = results[i + 1 :]
break
last_evaluated_key = None
size_limit = 1000000 # DynamoDB has a 1MB size limit
item_size = sum(res.size() for res in results)
if item_size > size_limit:
item_size = idx = 0
while item_size + results[idx].size() < size_limit:
item_size += results[idx].size()
idx += 1
limit = min(limit, idx) if limit else idx
if limit and len(results) > limit:
results = results[:limit]
last_evaluated_key = {self.hash_key_attr: results[-1].hash_key}
if results[-1].range_key is not None:
last_evaluated_key[self.range_key_attr] = results[-1].range_key
if scanned_index:
idx = self.get_index(scanned_index)
idx_col_list = [i["AttributeName"] for i in idx.schema]
for col in idx_col_list:
last_evaluated_key[col] = results[-1].attrs[col]
return results, last_evaluated_key
def delete(self, account_id, region_name):
dynamodb_backends[account_id][region_name].delete_table(self.name)
class RestoredTable(Table):
def __init__(self, name, account_id, region, backup):
params = self._parse_params_from_backup(backup)
super().__init__(name, account_id=account_id, region=region, **params)
self.indexes = copy.deepcopy(backup.table.indexes)
self.global_indexes = copy.deepcopy(backup.table.global_indexes)
self.items = copy.deepcopy(backup.table.items)
# Restore Attrs
self.source_backup_arn = backup.arn
self.source_table_arn = backup.table.table_arn
self.restore_date_time = self.created_at
@staticmethod
def _parse_params_from_backup(backup):
params = {
"schema": copy.deepcopy(backup.table.schema),
"attr": copy.deepcopy(backup.table.attr),
"throughput": copy.deepcopy(backup.table.throughput),
}
return params
def describe(self, base_key="TableDescription"):
result = super().describe(base_key=base_key)
result[base_key]["RestoreSummary"] = {
"SourceBackupArn": self.source_backup_arn,
"SourceTableArn": self.source_table_arn,
"RestoreDateTime": unix_time(self.restore_date_time),
"RestoreInProgress": False,
}
return result
class RestoredPITTable(Table):
def __init__(self, name, account_id, region, source):
params = self._parse_params_from_table(source)
super().__init__(name, account_id=account_id, region=region, **params)
self.indexes = copy.deepcopy(source.indexes)
self.global_indexes = copy.deepcopy(source.global_indexes)
self.items = copy.deepcopy(source.items)
# Restore Attrs
self.source_table_arn = source.table_arn
self.restore_date_time = self.created_at
@staticmethod
def _parse_params_from_table(table):
params = {
"schema": copy.deepcopy(table.schema),
"attr": copy.deepcopy(table.attr),
"throughput": copy.deepcopy(table.throughput),
}
return params
def describe(self, base_key="TableDescription"):
result = super().describe(base_key=base_key)
result[base_key]["RestoreSummary"] = {
"SourceTableArn": self.source_table_arn,
"RestoreDateTime": unix_time(self.restore_date_time),
"RestoreInProgress": False,
}
return result
class Backup(object):
def __init__(self, backend, name, table, status=None, type_=None):
self.backend = backend
self.name = name
self.table = copy.deepcopy(table)
self.status = status or "AVAILABLE"
self.type = type_ or "USER"
self.creation_date_time = datetime.datetime.utcnow()
self.identifier = self._make_identifier()
def _make_identifier(self):
timestamp = int(unix_time_millis(self.creation_date_time))
timestamp_padded = str("0" + str(timestamp))[-16:16]
guid = str(uuid.uuid4())
guid_shortened = guid[:8]
return "{}-{}".format(timestamp_padded, guid_shortened)
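    # Produces identifiers like "01646919123456-1a2b3c4d" (illustrative): the
    # creation time in milliseconds followed by the first 8 hex characters of a
    # random UUID.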
@property
def arn(self):
return "arn:aws:dynamodb:{region}:{account}:table/{table_name}/backup/{identifier}".format(
region=self.backend.region_name,
account=self.backend.account_id,
table_name=self.table.name,
identifier=self.identifier,
)
@property
def details(self):
details = {
"BackupArn": self.arn,
"BackupName": self.name,
"BackupSizeBytes": 123,
"BackupStatus": self.status,
"BackupType": self.type,
"BackupCreationDateTime": unix_time(self.creation_date_time),
}
return details
@property
def summary(self):
summary = {
"TableName": self.table.name,
# 'TableId': 'string',
"TableArn": self.table.table_arn,
"BackupArn": self.arn,
"BackupName": self.name,
"BackupCreationDateTime": unix_time(self.creation_date_time),
# 'BackupExpiryDateTime': datetime(2015, 1, 1),
"BackupStatus": self.status,
"BackupType": self.type,
"BackupSizeBytes": 123,
}
return summary
@property
def description(self):
source_table_details = self.table.describe()["TableDescription"]
source_table_details["TableCreationDateTime"] = source_table_details[
"CreationDateTime"
]
description = {
"BackupDetails": self.details,
"SourceTableDetails": source_table_details,
}
return description
class DynamoDBBackend(BaseBackend):
def __init__(self, region_name, account_id):
super().__init__(region_name, account_id)
self.tables = OrderedDict()
self.backups = OrderedDict()
@staticmethod
def default_vpc_endpoint_service(service_region, zones):
"""Default VPC endpoint service."""
# No 'vpce' in the base endpoint DNS name
return BaseBackend.default_vpc_endpoint_service_factory(
service_region,
zones,
"dynamodb",
"Gateway",
private_dns_names=False,
base_endpoint_dns_names=[f"dynamodb.{service_region}.amazonaws.com"],
)
def create_table(self, name, **params):
if name in self.tables:
raise ResourceInUseException
table = Table(
name, account_id=self.account_id, region=self.region_name, **params
)
self.tables[name] = table
return table
def delete_table(self, name):
if name not in self.tables:
raise ResourceNotFoundException
2013-12-05 13:16:56 +02:00
return self.tables.pop(name, None)
def describe_endpoints(self):
return [
{
"Address": "dynamodb.{}.amazonaws.com".format(self.region_name),
"CachePeriodInMinutes": 1440,
}
]
def tag_resource(self, table_arn, tags):
for table in self.tables:
if self.tables[table].table_arn == table_arn:
self.tables[table].tags.extend(tags)
def untag_resource(self, table_arn, tag_keys):
for table in self.tables:
if self.tables[table].table_arn == table_arn:
self.tables[table].tags = [
tag for tag in self.tables[table].tags if tag["Key"] not in tag_keys
]
def list_tags_of_resource(self, table_arn):
for table in self.tables:
if self.tables[table].table_arn == table_arn:
return self.tables[table].tags
raise ResourceNotFoundException
def list_tables(self, limit, exclusive_start_table_name):
all_tables = list(self.tables.keys())
if exclusive_start_table_name:
try:
last_table_index = all_tables.index(exclusive_start_table_name)
except ValueError:
start = len(all_tables)
else:
start = last_table_index + 1
else:
start = 0
if limit:
tables = all_tables[start : start + limit]
else:
tables = all_tables[start:]
if limit and len(all_tables) > start + limit:
return tables, tables[-1]
return tables, None
def describe_table(self, name):
table = self.get_table(name)
return table.describe(base_key="Table")
def update_table(
self,
name,
attr_definitions,
global_index,
throughput,
billing_mode,
stream_spec,
):
table = self.get_table(name)
if attr_definitions:
table.attr = attr_definitions
if global_index:
table = self.update_table_global_indexes(name, global_index)
if throughput:
table = self.update_table_throughput(name, throughput)
if billing_mode:
table = self.update_table_billing_mode(name, billing_mode)
if stream_spec:
table = self.update_table_streams(name, stream_spec)
return table
def update_table_throughput(self, name, throughput):
table = self.tables[name]
table.throughput = throughput
return table
def update_table_billing_mode(self, name, billing_mode):
table = self.tables[name]
table.billing_mode = billing_mode
return table
def update_table_streams(self, name, stream_specification):
table = self.tables[name]
if (
stream_specification.get("StreamEnabled")
or stream_specification.get("StreamViewType")
) and table.latest_stream_label:
raise StreamAlreadyEnabledException
table.set_stream_specification(stream_specification)
return table
def update_table_global_indexes(self, name, global_index_updates):
table = self.tables[name]
gsis_by_name = dict((i.name, i) for i in table.global_indexes)
for gsi_update in global_index_updates:
gsi_to_create = gsi_update.get("Create")
gsi_to_update = gsi_update.get("Update")
gsi_to_delete = gsi_update.get("Delete")
if gsi_to_delete:
index_name = gsi_to_delete["IndexName"]
if index_name not in gsis_by_name:
raise ValueError(
"Global Secondary Index does not exist, but tried to delete: %s"
% gsi_to_delete["IndexName"]
)
del gsis_by_name[index_name]
if gsi_to_update:
index_name = gsi_to_update["IndexName"]
if index_name not in gsis_by_name:
raise ValueError(
"Global Secondary Index does not exist, but tried to update: %s"
% index_name
)
gsis_by_name[index_name].update(gsi_to_update)
if gsi_to_create:
if gsi_to_create["IndexName"] in gsis_by_name:
raise ValueError(
"Global Secondary Index already exists: %s"
% gsi_to_create["IndexName"]
)
gsis_by_name[gsi_to_create["IndexName"]] = GlobalSecondaryIndex.create(
gsi_to_create, table.table_key_attrs
)
# in python 3.6, dict.values() returns a dict_values object, but we expect it to be a list in other
# parts of the codebase
table.global_indexes = list(gsis_by_name.values())
return table
def put_item(
self,
table_name,
item_attrs,
expected=None,
condition_expression=None,
expression_attribute_names=None,
expression_attribute_values=None,
overwrite=False,
):
table = self.get_table(table_name)
return table.put_item(
item_attrs,
expected,
condition_expression,
expression_attribute_names,
expression_attribute_values,
overwrite,
)
def get_table_keys_name(self, table_name, keys):
"""
Given a set of keys, extracts the key and range key
"""
table = self.tables.get(table_name)
if not table:
return None, None
else:
if len(keys) == 1:
for key in keys:
if key in table.hash_key_names:
return key, None
# for potential_hash, potential_range in zip(table.hash_key_names, table.range_key_names):
# if set([potential_hash, potential_range]) == set(keys):
# return potential_hash, potential_range
potential_hash, potential_range = None, None
for key in set(keys):
if key in table.hash_key_names:
potential_hash = key
elif key in table.range_key_names:
potential_range = key
return potential_hash, potential_range
def get_keys_value(self, table, keys):
if table.hash_key_attr not in keys or (
table.has_range_key and table.range_key_attr not in keys
):
# "Table has a range key, but no range key was passed into get_item"
raise MockValidationException("Validation Exception")
hash_key = DynamoType(keys[table.hash_key_attr])
range_key = (
DynamoType(keys[table.range_key_attr]) if table.has_range_key else None
)
return hash_key, range_key
def get_schema(self, table_name, index_name):
table = self.get_table(table_name)
if index_name:
all_indexes = (table.global_indexes or []) + (table.indexes or [])
indexes_by_name = dict((i.name, i) for i in all_indexes)
if index_name not in indexes_by_name:
raise ResourceNotFoundException(
"Invalid index: {} for table: {}. Available indexes are: {}".format(
index_name, table_name, ", ".join(indexes_by_name.keys())
)
)
return indexes_by_name[index_name].schema
else:
return table.schema
def get_table(self, table_name):
if table_name not in self.tables:
raise ResourceNotFoundException()
return self.tables.get(table_name)
def get_item(self, table_name, keys, projection_expression=None):
table = self.get_table(table_name)
hash_key, range_key = self.get_keys_value(table, keys)
return table.get_item(hash_key, range_key, projection_expression)
def query(
self,
table_name,
hash_key_dict,
range_comparison,
range_value_dicts,
limit,
exclusive_start_key,
scan_index_forward,
projection_expression,
index_name=None,
expr_names=None,
expr_values=None,
filter_expression=None,
**filter_kwargs,
):
table = self.get_table(table_name)
hash_key = DynamoType(hash_key_dict)
range_values = [DynamoType(range_value) for range_value in range_value_dicts]
filter_expression = get_filter_expression(
filter_expression, expr_names, expr_values
)
return table.query(
hash_key,
range_comparison,
range_values,
limit,
exclusive_start_key,
scan_index_forward,
projection_expression,
index_name,
filter_expression,
**filter_kwargs,
)
def scan(
self,
table_name,
filters,
limit,
exclusive_start_key,
filter_expression,
expr_names,
expr_values,
index_name,
projection_expression,
):
table = self.get_table(table_name)
scan_filters = {}
for key, (comparison_operator, comparison_values) in filters.items():
dynamo_types = [DynamoType(value) for value in comparison_values]
scan_filters[key] = (comparison_operator, dynamo_types)
filter_expression = get_filter_expression(
filter_expression, expr_names, expr_values
)
return table.scan(
scan_filters,
limit,
exclusive_start_key,
filter_expression,
index_name,
projection_expression,
)
def update_item(
self,
table_name,
key,
update_expression,
expression_attribute_names,
expression_attribute_values,
attribute_updates=None,
expected=None,
condition_expression=None,
):
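        # Flow: parse and validate the UpdateExpression AST, check Expected and
        # ConditionExpression against the current item, create an empty item if
        # the key does not exist yet, then apply the validated AST via
        # UpdateExpressionExecutor (or fall back to legacy AttributeUpdates).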
table = self.get_table(table_name)
# Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c`
if update_expression:
# Parse expression to get validation errors
update_expression_ast = UpdateExpressionParser.make(update_expression)
update_expression = re.sub(r"\s*([=\+-])\s*", "\\1", update_expression)
update_expression_ast.validate()
if all([table.hash_key_attr in key, table.range_key_attr in key]):
# Covers cases where table has hash and range keys, ``key`` param
# will be a dict
hash_value = DynamoType(key[table.hash_key_attr])
range_value = DynamoType(key[table.range_key_attr])
elif table.hash_key_attr in key:
# Covers tables that have a range key where ``key`` param is a dict
hash_value = DynamoType(key[table.hash_key_attr])
range_value = None
else:
# Covers other cases
hash_value = DynamoType(key)
range_value = None
item = table.get_item(hash_value, range_value)
orig_item = copy.deepcopy(item)
if not expected:
expected = {}
if not get_expected(expected).expr(item):
raise ConditionalCheckFailed
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values,
)
if not condition_op.expr(item):
raise ConditionalCheckFailed
# Update does not fail on new items, so create one
if item is None:
if update_expression:
# Validate AST before creating anything
item = Item(hash_value, range_value, attrs={})
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
item=item,
table=table,
).validate()
data = {table.hash_key_attr: {hash_value.type: hash_value.value}}
if range_value:
data.update(
{table.range_key_attr: {range_value.type: range_value.value}}
)
table.put_item(data)
item = table.get_item(hash_value, range_value)
if attribute_updates:
item.validate_no_empty_key_values(attribute_updates, table.attribute_keys)
if update_expression:
validator = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
item=item,
table=table,
)
validated_ast = validator.validate()
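# Apply the validated update actions (SET / REMOVE / ADD / DELETE) to the
# item. If the resulting item exceeds the maximum allowed size, surface the
# update-specific error rather than the generic one.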
try:
UpdateExpressionExecutor(
validated_ast, item, expression_attribute_names
).execute()
except ItemSizeTooLarge:
raise ItemSizeToUpdateTooLarge()
else:
item.update_with_attribute_updates(attribute_updates)
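# If the table has a stream enabled, record the old and new images of the
# item on the stream shard.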
if table.stream_shard is not None:
table.stream_shard.add(orig_item, item)
return item
def delete_item(
self,
table_name,
key,
expression_attribute_names=None,
expression_attribute_values=None,
condition_expression=None,
):
table = self.get_table(table_name)
hash_value, range_value = self.get_keys_value(table, key)
item = table.get_item(hash_value, range_value)
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values,
)
if not condition_op.expr(item):
raise ConditionalCheckFailed
return table.delete_item(hash_value, range_value)
def update_time_to_live(self, table_name, ttl_spec):
table = self.tables.get(table_name)
if table is None:
raise JsonRESTError("ResourceNotFound", "Table not found")
if "Enabled" not in ttl_spec or "AttributeName" not in ttl_spec:
raise JsonRESTError(
"InvalidParameterValue",
"TimeToLiveSpecification does not contain Enabled and AttributeName",
)
if ttl_spec["Enabled"]:
table.ttl["TimeToLiveStatus"] = "ENABLED"
else:
table.ttl["TimeToLiveStatus"] = "DISABLED"
table.ttl["AttributeName"] = ttl_spec["AttributeName"]
def describe_time_to_live(self, table_name):
table = self.tables.get(table_name)
if table is None:
raise JsonRESTError("ResourceNotFound", "Table not found")
return table.ttl
def transact_write_items(self, transact_items):
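# Each element of ``transact_items`` mirrors the TransactWriteItems API and is
# expected to hold exactly one of "ConditionCheck", "Put", "Delete" or
# "Update". A minimal sketch with illustrative table and key names:
#   {"Put": {"TableName": "my-table", "Item": {"id": {"S": "foo"}}}}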
if len(transact_items) > 25:
raise TooManyTransactionsException()
# Create a backup in case any of the transactions fail
original_table_state = copy.deepcopy(self.tables)
target_items = set()
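# DynamoDB rejects a transaction that contains more than one operation on the
# same item; keys are stringified so they can be stored in a set for this check.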
def check_unicity(table_name, key):
item = (str(table_name), str(key))
if item in target_items:
raise MultipleTransactionsException()
target_items.add(item)
errors = []
for item in transact_items:
try:
if "ConditionCheck" in item:
item = item["ConditionCheck"]
key = item["Key"]
table_name = item["TableName"]
check_unicity(table_name, key)
condition_expression = item.get("ConditionExpression", None)
expression_attribute_names = item.get(
"ExpressionAttributeNames", None
)
expression_attribute_values = item.get(
"ExpressionAttributeValues", None
)
current = self.get_item(table_name, key)
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values,
)
if not condition_op.expr(current):
raise ConditionalCheckFailed()
elif "Put" in item:
item = item["Put"]
attrs = item["Item"]
table_name = item["TableName"]
condition_expression = item.get("ConditionExpression", None)
expression_attribute_names = item.get(
"ExpressionAttributeNames", None
)
expression_attribute_values = item.get(
"ExpressionAttributeValues", None
)
self.put_item(
table_name,
attrs,
condition_expression=condition_expression,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
)
elif "Delete" in item:
item = item["Delete"]
key = item["Key"]
table_name = item["TableName"]
check_unicity(table_name, key)
condition_expression = item.get("ConditionExpression", None)
expression_attribute_names = item.get(
"ExpressionAttributeNames", None
)
expression_attribute_values = item.get(
"ExpressionAttributeValues", None
)
self.delete_item(
table_name,
key,
condition_expression=condition_expression,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
)
elif "Update" in item:
item = item["Update"]
key = item["Key"]
table_name = item["TableName"]
check_unicity(table_name, key)
update_expression = item["UpdateExpression"]
condition_expression = item.get("ConditionExpression", None)
expression_attribute_names = item.get(
"ExpressionAttributeNames", None
)
expression_attribute_values = item.get(
"ExpressionAttributeValues", None
)
self.update_item(
table_name,
key,
update_expression=update_expression,
condition_expression=condition_expression,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
)
else:
raise ValueError
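# This request succeeded: append a placeholder so the list of cancellation
# reasons stays aligned with the order of the requests.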
errors.append((None, None))
except MultipleTransactionsException:
# Rollback to the original state, and reraise the error
self.tables = original_table_state
raise MultipleTransactionsException()
except Exception as e:
    # Moto's DynamoDB exceptions expose a ``message`` attribute; fall back to
    # str(e) so an unexpected exception type cannot escape the handler.
    errors.append((type(e).__name__, getattr(e, "message", str(e))))
if set(errors) != set([(None, None)]):
# Rollback to the original state, and reraise the errors
self.tables = original_table_state
raise TransactionCanceledException(errors)
def describe_continuous_backups(self, table_name):
try:
table = self.get_table(table_name)
except ResourceNotFoundException:
raise TableNotFoundException(table_name)
return table.continuous_backups
def update_continuous_backups(self, table_name, point_in_time_spec):
try:
table = self.get_table(table_name)
except ResourceNotFoundException:
raise TableNotFoundException(table_name)
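# A minimal sketch of the PointInTimeRecoverySpecification shape this method
# expects:
#   {"PointInTimeRecoveryEnabled": True}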
if (
point_in_time_spec["PointInTimeRecoveryEnabled"]
and table.continuous_backups["PointInTimeRecoveryDescription"][
"PointInTimeRecoveryStatus"
]
== "DISABLED"
):
table.continuous_backups["PointInTimeRecoveryDescription"] = {
"PointInTimeRecoveryStatus": "ENABLED",
"EarliestRestorableDateTime": unix_time(),
"LatestRestorableDateTime": unix_time(),
}
elif not point_in_time_spec["PointInTimeRecoveryEnabled"]:
table.continuous_backups["PointInTimeRecoveryDescription"] = {
"PointInTimeRecoveryStatus": "DISABLED"
}
return table.continuous_backups
def get_backup(self, backup_arn):
if backup_arn not in self.backups:
raise BackupNotFoundException(backup_arn)
return self.backups.get(backup_arn)
def list_backups(self, table_name):
backups = list(self.backups.values())
if table_name is not None:
backups = [backup for backup in backups if backup.table.name == table_name]
return backups
def create_backup(self, table_name, backup_name):
try:
table = self.get_table(table_name)
except ResourceNotFoundException:
raise TableNotFoundException(table_name)
backup = Backup(self, backup_name, table)
self.backups[backup.arn] = backup
return backup
def delete_backup(self, backup_arn):
backup = self.get_backup(backup_arn)
if backup is None:
raise KeyError()
backup_deleted = self.backups.pop(backup_arn)
backup_deleted.status = "DELETED"
return backup_deleted
def describe_backup(self, backup_arn):
backup = self.get_backup(backup_arn)
if backup is None:
raise KeyError()
return backup
def restore_table_from_backup(self, target_table_name, backup_arn):
backup = self.get_backup(backup_arn)
if target_table_name in self.tables:
raise TableAlreadyExistsException(target_table_name)
new_table = RestoredTable(
target_table_name,
account_id=self.account_id,
region=self.region_name,
backup=backup,
)
self.tables[target_table_name] = new_table
return new_table
def restore_table_to_point_in_time(self, target_table_name, source_table_name):
"""
Currently this only accepts the source and target table names; it copies all
items from the source table and ignores any other arguments.
"""
try:
source = self.get_table(source_table_name)
except ResourceNotFoundException:
raise SourceTableNotFoundException(source_table_name)
if target_table_name in self.tables:
raise TableAlreadyExistsException(target_table_name)
new_table = RestoredPITTable(
target_table_name,
account_id=self.account_id,
region=self.region_name,
source=source,
)
self.tables[target_table_name] = new_table
return new_table
######################
# LIST of methods where the logic completely resides in responses.py
# Duplicated here so that the implementation coverage script is aware of them
# TODO: Move logic here
######################
def batch_get_item(self):
pass
def batch_write_item(self):
pass
def transact_get_items(self):
pass
dynamodb_backends = BackendDict(DynamoDBBackend, "dynamodb")