Glue: Default CatalogId should be the AccountID (#6864)

This commit is contained in:
Bert Blommers 2023-09-29 12:08:22 +00:00 committed by GitHub
parent 5563e62f21
commit 06982582b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 10 deletions

View File

@ -131,7 +131,9 @@ class GlueBackend(BaseBackend):
if database_name in self.databases: if database_name in self.databases:
raise DatabaseAlreadyExistsException() raise DatabaseAlreadyExistsException()
database = FakeDatabase(database_name, database_input) database = FakeDatabase(
database_name, database_input, catalog_id=self.account_id
)
self.databases[database_name] = database self.databases[database_name] = database
return database return database
@ -165,7 +167,9 @@ class GlueBackend(BaseBackend):
if table_name in database.tables: if table_name in database.tables:
raise TableAlreadyExistsException() raise TableAlreadyExistsException()
table = FakeTable(database_name, table_name, table_input) table = FakeTable(
database_name, table_name, table_input, catalog_id=self.account_id
)
database.tables[table_name] = table database.tables[table_name] = table
return table return table
@ -1041,9 +1045,12 @@ class GlueBackend(BaseBackend):
class FakeDatabase(BaseModel): class FakeDatabase(BaseModel):
def __init__(self, database_name: str, database_input: Dict[str, Any]): def __init__(
self, database_name: str, database_input: Dict[str, Any], catalog_id: str
):
self.name = database_name self.name = database_name
self.input = database_input self.input = database_input
self.catalog_id = catalog_id
self.created_time = utcnow() self.created_time = utcnow()
self.tables: Dict[str, FakeTable] = OrderedDict() self.tables: Dict[str, FakeTable] = OrderedDict()
@ -1058,16 +1065,21 @@ class FakeDatabase(BaseModel):
"CreateTableDefaultPermissions" "CreateTableDefaultPermissions"
), ),
"TargetDatabase": self.input.get("TargetDatabase"), "TargetDatabase": self.input.get("TargetDatabase"),
"CatalogId": self.input.get("CatalogId"), "CatalogId": self.input.get("CatalogId") or self.catalog_id,
} }
class FakeTable(BaseModel): class FakeTable(BaseModel):
def __init__( def __init__(
self, database_name: str, table_name: str, table_input: Dict[str, Any] self,
database_name: str,
table_name: str,
table_input: Dict[str, Any],
catalog_id: str,
): ):
self.database_name = database_name self.database_name = database_name
self.name = table_name self.name = table_name
self.catalog_id = catalog_id
self.partitions: Dict[str, FakePartition] = OrderedDict() self.partitions: Dict[str, FakePartition] = OrderedDict()
self.created_time = utcnow() self.created_time = utcnow()
self.updated_time: Optional[datetime] = None self.updated_time: Optional[datetime] = None
@ -1104,6 +1116,7 @@ class FakeTable(BaseModel):
**self.get_version(str(version)), **self.get_version(str(version)),
# Add VersionId after we get the version-details, just to make sure that it's a valid version (int) # Add VersionId after we get the version-details, just to make sure that it's a valid version (int)
"VersionId": str(version), "VersionId": str(version),
"CatalogId": self.catalog_id,
} }
if self.updated_time is not None: if self.updated_time is not None:
obj["UpdateTime"] = unix_time(self.updated_time) obj["UpdateTime"] = unix_time(self.updated_time)

View File

@ -27,7 +27,8 @@ def test_create_database():
response = helpers.get_database(client, database_name) response = helpers.get_database(client, database_name)
database = response["Database"] database = response["Database"]
assert database.get("Name") == database_name assert database["Name"] == database_name
assert database["CatalogId"] == ACCOUNT_ID
assert database.get("Description") == database_input.get("Description") assert database.get("Description") == database_input.get("Description")
assert database.get("LocationUri") == database_input.get("LocationUri") assert database.get("LocationUri") == database_input.get("LocationUri")
assert database.get("Parameters") == database_input.get("Parameters") assert database.get("Parameters") == database_input.get("Parameters")
@ -67,14 +68,11 @@ def test_get_database_not_exits():
@mock_glue @mock_glue
def test_get_databases_empty(): def test_get_databases():
client = boto3.client("glue", region_name="us-east-1") client = boto3.client("glue", region_name="us-east-1")
response = client.get_databases() response = client.get_databases()
assert len(response["DatabaseList"]) == 0 assert len(response["DatabaseList"]) == 0
@mock_glue
def test_get_databases_several_items():
client = boto3.client("glue", region_name="us-east-1") client = boto3.client("glue", region_name="us-east-1")
database_name_1, database_name_2 = "firstdatabase", "seconddatabase" database_name_1, database_name_2 = "firstdatabase", "seconddatabase"
@ -86,7 +84,9 @@ def test_get_databases_several_items():
) )
assert len(database_list) == 2 assert len(database_list) == 2
assert database_list[0]["Name"] == database_name_1 assert database_list[0]["Name"] == database_name_1
assert database_list[0]["CatalogId"] == ACCOUNT_ID
assert database_list[1]["Name"] == database_name_2 assert database_list[1]["Name"] == database_name_2
assert database_list[1]["CatalogId"] == ACCOUNT_ID
@mock_glue @mock_glue
@ -222,6 +222,7 @@ def test_get_tables():
table["StorageDescriptor"] == table_inputs[table_name]["StorageDescriptor"] table["StorageDescriptor"] == table_inputs[table_name]["StorageDescriptor"]
) )
assert table["PartitionKeys"] == table_inputs[table_name]["PartitionKeys"] assert table["PartitionKeys"] == table_inputs[table_name]["PartitionKeys"]
assert table["CatalogId"] == ACCOUNT_ID
@mock_glue @mock_glue
@ -319,6 +320,7 @@ def test_get_table_versions():
table = client.get_table(DatabaseName=database_name, Name=table_name)["Table"] table = client.get_table(DatabaseName=database_name, Name=table_name)["Table"]
assert table["StorageDescriptor"]["Columns"] == [] assert table["StorageDescriptor"]["Columns"] == []
assert table["VersionId"] == "1" assert table["VersionId"] == "1"
assert table["CatalogId"] == ACCOUNT_ID
columns = [{"Name": "country", "Type": "string"}] columns = [{"Name": "country", "Type": "string"}]
table_input = helpers.create_table_input(database_name, table_name, columns=columns) table_input = helpers.create_table_input(database_name, table_name, columns=columns)