From 8a092c91ae9dcc4961754596b4398a7a2d1cf2ed Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sat, 27 Jun 2020 11:07:15 +0100 Subject: [PATCH] DynamoDB - Add support for GSI's ProjectionType: KEYS_ONLY --- moto/dynamodb2/models/__init__.py | 44 +++++++++++++++++++-------- tests/test_dynamodb2/test_dynamodb.py | 44 +++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 12 deletions(-) diff --git a/moto/dynamodb2/models/__init__.py b/moto/dynamodb2/models/__init__.py index 13ee94948..7e288bb9d 100644 --- a/moto/dynamodb2/models/__init__.py +++ b/moto/dynamodb2/models/__init__.py @@ -331,6 +331,21 @@ class GlobalSecondaryIndex(BaseModel): self.projection = u.get("Projection", self.projection) self.throughput = u.get("ProvisionedThroughput", self.throughput) + def project(self, item): + """ + Enforces the ProjectionType of this GSI + Removes any non-wanted attributes from the item + :param item: + :return: + """ + if self.projection: + if self.projection.get("ProjectionType", None) == "KEYS_ONLY": + allowed_attributes = ",".join( + [key["AttributeName"] for key in self.schema] + ) + item.filter(allowed_attributes) + return item + class Table(BaseModel): def __init__( @@ -719,6 +734,10 @@ class Table(BaseModel): results = [item for item in results if filter_expression.expr(item)] results = copy.deepcopy(results) + if index_name: + index = self.get_index(index_name) + for result in results: + index.project(result) if projection_expression: for result in results: result.filter(projection_expression) @@ -739,11 +758,16 @@ class Table(BaseModel): def all_indexes(self): return (self.global_indexes or []) + (self.indexes or []) - def has_idx_items(self, index_name): - + def get_index(self, index_name, err=None): all_indexes = self.all_indexes() indexes_by_name = dict((i.name, i) for i in all_indexes) - idx = indexes_by_name[index_name] + if err and index_name not in indexes_by_name: + raise err + return indexes_by_name[index_name] + + def has_idx_items(self, index_name): + + idx = self.get_index(index_name) idx_col_set = set([i["AttributeName"] for i in idx.schema]) for hash_set in self.items.values(): @@ -766,14 +790,12 @@ class Table(BaseModel): ): results = [] scanned_count = 0 - all_indexes = self.all_indexes() - indexes_by_name = dict((i.name, i) for i in all_indexes) if index_name: - if index_name not in indexes_by_name: - raise InvalidIndexNameError( - "The table does not have the specified index: %s" % index_name - ) + err = InvalidIndexNameError( + "The table does not have the specified index: %s" % index_name + ) + self.get_index(index_name, err) items = self.has_idx_items(index_name) else: items = self.all_items() @@ -847,9 +869,7 @@ class Table(BaseModel): last_evaluated_key[self.range_key_attr] = results[-1].range_key if scanned_index: - all_indexes = self.all_indexes() - indexes_by_name = dict((i.name, i) for i in all_indexes) - idx = indexes_by_name[scanned_index] + idx = self.get_index(scanned_index) idx_col_list = [i["AttributeName"] for i in idx.schema] for col in idx_col_list: last_evaluated_key[col] = results[-1].attrs[col] diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index 370999116..cf1548e03 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -5316,3 +5316,47 @@ def test_transact_write_items_fails_with_transaction_canceled_exception(): ex.exception.response["Error"]["Message"].should.equal( "Transaction cancelled, please refer cancellation reasons for specific reasons [None, ConditionalCheckFailed]" ) + + +@mock_dynamodb2 +def test_gsi_projection_type_keys_only(): + table_schema = { + "KeySchema": [{"AttributeName": "partitionKey", "KeyType": "HASH"}], + "GlobalSecondaryIndexes": [ + { + "IndexName": "GSI-K1", + "KeySchema": [ + {"AttributeName": "gsiK1PartitionKey", "KeyType": "HASH"}, + {"AttributeName": "gsiK1SortKey", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "KEYS_ONLY",}, + } + ], + "AttributeDefinitions": [ + {"AttributeName": "partitionKey", "AttributeType": "S"}, + {"AttributeName": "gsiK1PartitionKey", "AttributeType": "S"}, + {"AttributeName": "gsiK1SortKey", "AttributeType": "S"}, + ], + } + + item = { + "partitionKey": "pk-1", + "gsiK1PartitionKey": "gsi-pk", + "gsiK1SortKey": "gsi-sk", + "someAttribute": "lore ipsum", + } + + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + dynamodb.create_table( + TableName="test-table", BillingMode="PAY_PER_REQUEST", **table_schema + ) + table = dynamodb.Table("test-table") + table.put_item(Item=item) + + items = table.query( + KeyConditionExpression=Key("gsiK1PartitionKey").eq("gsi-pk"), + IndexName="GSI-K1", + )["Items"] + items.should.have.length_of(1) + # Item should only include GSI Keys, as per the ProjectionType + items[0].should.equal({"gsiK1PartitionKey": "gsi-pk", "gsiK1SortKey": "gsi-sk"})