diff --git a/moto/dynamodb2/models/__init__.py b/moto/dynamodb2/models/__init__.py index a5277800f..60bc1b2fe 100644 --- a/moto/dynamodb2/models/__init__.py +++ b/moto/dynamodb2/models/__init__.py @@ -272,6 +272,66 @@ class StreamShard(BaseModel): return [i.to_json() for i in self.items[start:end]] +class LocalSecondaryIndex(BaseModel): + def __init__(self, index_name, schema, projection): + self.name = index_name + self.schema = schema + self.projection = projection + + def describe(self): + return { + "IndexName": self.name, + "KeySchema": self.schema, + "Projection": self.projection, + } + + @staticmethod + def create(dct): + return LocalSecondaryIndex( + index_name=dct["IndexName"], + schema=dct["KeySchema"], + projection=dct["Projection"], + ) + + +class GlobalSecondaryIndex(BaseModel): + def __init__( + self, index_name, schema, projection, status="ACTIVE", throughput=None + ): + self.name = index_name + self.schema = schema + self.projection = projection + self.status = status + self.throughput = throughput or { + "ReadCapacityUnits": 0, + "WriteCapacityUnits": 0, + } + + def describe(self): + return { + "IndexName": self.name, + "KeySchema": self.schema, + "Projection": self.projection, + "IndexStatus": self.status, + "ProvisionedThroughput": self.throughput, + } + + @staticmethod + def create(dct): + return GlobalSecondaryIndex( + index_name=dct["IndexName"], + schema=dct["KeySchema"], + projection=dct["Projection"], + throughput=dct.get("ProvisionedThroughput", None), + ) + + def update(self, u): + self.name = u.get("IndexName", self.name) + self.schema = u.get("KeySchema", self.schema) + self.projection = u.get("Projection", self.projection) + self.throughput = u.get("ProvisionedThroughput", self.throughput) + + class Table(BaseModel): def __init__( self, @@ -302,12 +362,13 @@ class Table(BaseModel): else: self.throughput = throughput self.throughput["NumberOfDecreasesToday"] = 0 - self.indexes = indexes - self.global_indexes = global_indexes if global_indexes else [] - for index in self.global_indexes: - index[ - "IndexStatus" - ] = "ACTIVE" # One of 'CREATING'|'UPDATING'|'DELETING'|'ACTIVE' + self.indexes = [ + LocalSecondaryIndex.create(i) for i in (indexes if indexes else []) + ] + self.global_indexes = [ + GlobalSecondaryIndex.create(i) + for i in (global_indexes if global_indexes else []) + ] self.created_at = datetime.datetime.utcnow() self.items = defaultdict(dict) self.table_arn = self._generate_arn(table_name) @@ -374,8 +435,10 @@ class Table(BaseModel): "KeySchema": self.schema, "ItemCount": len(self), "CreationDateTime": unix_time(self.created_at), - "GlobalSecondaryIndexes": [index for index in self.global_indexes], - "LocalSecondaryIndexes": [index for index in self.indexes], + "GlobalSecondaryIndexes": [ + index.describe() for index in self.global_indexes + ], + "LocalSecondaryIndexes": [index.describe() for index in self.indexes], } } if self.stream_specification and self.stream_specification["StreamEnabled"]: @@ -401,7 +464,7 @@ class Table(BaseModel): keys = [self.hash_key_attr] for index in self.global_indexes: hash_key = None - for key in index["KeySchema"]: + for key in index.schema: if key["KeyType"] == "HASH": hash_key = key["AttributeName"] keys.append(hash_key) @@ -412,7 +475,7 @@ class Table(BaseModel): keys = [self.range_key_attr] for index in self.global_indexes: range_key = None - for key in index["KeySchema"]: + for key in index.schema: if key["KeyType"] == "RANGE": range_key = keys.append(key["AttributeName"]) keys.append(range_key) @@ -545,7 +608,7 @@ class Table(BaseModel): if index_name: all_indexes = self.all_indexes() - indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) + indexes_by_name = dict((i.name, i) for i in all_indexes) if index_name not in indexes_by_name: raise ValueError( "Invalid index: %s for table: %s. Available indexes are: %s" @@ -555,14 +618,14 @@ class Table(BaseModel): index = indexes_by_name[index_name] try: index_hash_key = [ - key for key in index["KeySchema"] if key["KeyType"] == "HASH" + key for key in index.schema if key["KeyType"] == "HASH" ][0] except IndexError: - raise ValueError("Missing Hash Key. KeySchema: %s" % index["KeySchema"]) + raise ValueError("Missing Hash Key. KeySchema: %s" % index.name) try: index_range_key = [ - key for key in index["KeySchema"] if key["KeyType"] == "RANGE" + key for key in index.schema if key["KeyType"] == "RANGE" ][0] except IndexError: index_range_key = None @@ -667,9 +730,9 @@ class Table(BaseModel): def has_idx_items(self, index_name): all_indexes = self.all_indexes() - indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) + indexes_by_name = dict((i.name, i) for i in all_indexes) idx = indexes_by_name[index_name] - idx_col_set = set([i["AttributeName"] for i in idx["KeySchema"]]) + idx_col_set = set([i["AttributeName"] for i in idx.schema]) for hash_set in self.items.values(): if self.range_key_attr: @@ -692,7 +755,7 @@ class Table(BaseModel): results = [] scanned_count = 0 all_indexes = self.all_indexes() - indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) + indexes_by_name = dict((i.name, i) for i in all_indexes) if index_name: if index_name not in indexes_by_name: @@ -773,9 +836,9 @@ class Table(BaseModel): if scanned_index: all_indexes = self.all_indexes() - indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) + indexes_by_name = dict((i.name, i) for i in all_indexes) idx = indexes_by_name[scanned_index] - idx_col_list = [i["AttributeName"] for i in idx["KeySchema"]] + idx_col_list = [i["AttributeName"] for i in idx.schema] for col in idx_col_list: last_evaluated_key[col] = results[-1].attrs[col] @@ -885,7 +948,7 @@ class DynamoDBBackend(BaseBackend): def update_table_global_indexes(self, name, global_index_updates): table = self.tables[name] - gsis_by_name = dict((i["IndexName"], i) for i in table.global_indexes) + gsis_by_name = dict((i.name, i) for i in table.global_indexes) for gsi_update in global_index_updates: gsi_to_create = gsi_update.get("Create") gsi_to_update = gsi_update.get("Update") @@ -906,7 +969,7 @@ class DynamoDBBackend(BaseBackend): if index_name not in gsis_by_name: raise ValueError( "Global Secondary Index does not exist, but tried to update: %s" - % gsi_to_update["IndexName"] + % index_name ) gsis_by_name[index_name].update(gsi_to_update) @@ -917,7 +980,9 @@ class DynamoDBBackend(BaseBackend): % gsi_to_create["IndexName"] ) - gsis_by_name[gsi_to_create["IndexName"]] = gsi_to_create + gsis_by_name[gsi_to_create["IndexName"]] = GlobalSecondaryIndex.create( + gsi_to_create + ) # in python 3.6, dict.values() returns a dict_values object, but we expect it to be a list in other # parts of the codebase diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index aec7c7560..6500a0a63 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -411,7 +411,6 @@ class DynamoHandler(BaseResponse): def query(self): name = self.body["TableName"] - # {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}} key_condition_expression = self.body.get("KeyConditionExpression") projection_expression = self.body.get("ProjectionExpression") expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) @@ -439,7 +438,7 @@ class DynamoHandler(BaseResponse): index_name = self.body.get("IndexName") if index_name: all_indexes = (table.global_indexes or []) + (table.indexes or []) - indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) + indexes_by_name = dict((i.name, i) for i in all_indexes) if index_name not in indexes_by_name: er = "com.amazonaws.dynamodb.v20120810#ResourceNotFoundException" return self.error( @@ -449,7 +448,7 @@ class DynamoHandler(BaseResponse): ), ) - index = indexes_by_name[index_name]["KeySchema"] + index = indexes_by_name[index_name].schema else: index = table.schema diff --git a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py index 33f65d5ec..12e75a73e 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py @@ -931,6 +931,83 @@ boto3 """ +@mock_dynamodb2 +def test_boto3_create_table_with_gsi(): + dynamodb = boto3.client("dynamodb", region_name="us-east-1") + + table = dynamodb.create_table( + TableName="users", + KeySchema=[ + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + GlobalSecondaryIndexes=[ + { + "IndexName": "test_gsi", + "KeySchema": [{"AttributeName": "subject", "KeyType": "HASH"}], + "Projection": {"ProjectionType": "ALL"}, + } + ], + ) + table["TableDescription"]["GlobalSecondaryIndexes"].should.equal( + [ + { + "KeySchema": [{"KeyType": "HASH", "AttributeName": "subject"}], + "IndexName": "test_gsi", + "Projection": {"ProjectionType": "ALL"}, + "IndexStatus": "ACTIVE", + "ProvisionedThroughput": { + "ReadCapacityUnits": 0, + "WriteCapacityUnits": 0, + }, + } + ] + ) + + table = dynamodb.create_table( + TableName="users2", + KeySchema=[ + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + ], + BillingMode="PAY_PER_REQUEST", + GlobalSecondaryIndexes=[ + { + "IndexName": "test_gsi", + "KeySchema": [{"AttributeName": "subject", "KeyType": "HASH"}], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 3, + "WriteCapacityUnits": 5, + }, + } + ], + ) + table["TableDescription"]["GlobalSecondaryIndexes"].should.equal( + [ + { + "KeySchema": [{"KeyType": "HASH", "AttributeName": "subject"}], + "IndexName": "test_gsi", + "Projection": {"ProjectionType": "ALL"}, + "IndexStatus": "ACTIVE", + "ProvisionedThroughput": { + "ReadCapacityUnits": 3, + "WriteCapacityUnits": 5, + }, + } + ] + ) + + @mock_dynamodb2 def test_boto3_conditions(): dynamodb = boto3.resource("dynamodb", region_name="us-east-1")