Glue: Default CatalogId should be the AccountID (#6864)

This commit is contained in:
Bert Blommers 2023-09-29 12:08:22 +00:00 committed by GitHub
parent 5563e62f21
commit 06982582b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 10 deletions

View File

@ -131,7 +131,9 @@ class GlueBackend(BaseBackend):
if database_name in self.databases:
raise DatabaseAlreadyExistsException()
database = FakeDatabase(database_name, database_input)
database = FakeDatabase(
database_name, database_input, catalog_id=self.account_id
)
self.databases[database_name] = database
return database
@ -165,7 +167,9 @@ class GlueBackend(BaseBackend):
if table_name in database.tables:
raise TableAlreadyExistsException()
table = FakeTable(database_name, table_name, table_input)
table = FakeTable(
database_name, table_name, table_input, catalog_id=self.account_id
)
database.tables[table_name] = table
return table
@ -1041,9 +1045,12 @@ class GlueBackend(BaseBackend):
class FakeDatabase(BaseModel):
def __init__(self, database_name: str, database_input: Dict[str, Any]):
def __init__(
self, database_name: str, database_input: Dict[str, Any], catalog_id: str
):
self.name = database_name
self.input = database_input
self.catalog_id = catalog_id
self.created_time = utcnow()
self.tables: Dict[str, FakeTable] = OrderedDict()
@ -1058,16 +1065,21 @@ class FakeDatabase(BaseModel):
"CreateTableDefaultPermissions"
),
"TargetDatabase": self.input.get("TargetDatabase"),
"CatalogId": self.input.get("CatalogId"),
"CatalogId": self.input.get("CatalogId") or self.catalog_id,
}
class FakeTable(BaseModel):
def __init__(
self, database_name: str, table_name: str, table_input: Dict[str, Any]
self,
database_name: str,
table_name: str,
table_input: Dict[str, Any],
catalog_id: str,
):
self.database_name = database_name
self.name = table_name
self.catalog_id = catalog_id
self.partitions: Dict[str, FakePartition] = OrderedDict()
self.created_time = utcnow()
self.updated_time: Optional[datetime] = None
@ -1104,6 +1116,7 @@ class FakeTable(BaseModel):
**self.get_version(str(version)),
# Add VersionId after we get the version-details, just to make sure that it's a valid version (int)
"VersionId": str(version),
"CatalogId": self.catalog_id,
}
if self.updated_time is not None:
obj["UpdateTime"] = unix_time(self.updated_time)

View File

@ -27,7 +27,8 @@ def test_create_database():
response = helpers.get_database(client, database_name)
database = response["Database"]
assert database.get("Name") == database_name
assert database["Name"] == database_name
assert database["CatalogId"] == ACCOUNT_ID
assert database.get("Description") == database_input.get("Description")
assert database.get("LocationUri") == database_input.get("LocationUri")
assert database.get("Parameters") == database_input.get("Parameters")
@ -67,14 +68,11 @@ def test_get_database_not_exits():
@mock_glue
def test_get_databases_empty():
def test_get_databases():
client = boto3.client("glue", region_name="us-east-1")
response = client.get_databases()
assert len(response["DatabaseList"]) == 0
@mock_glue
def test_get_databases_several_items():
client = boto3.client("glue", region_name="us-east-1")
database_name_1, database_name_2 = "firstdatabase", "seconddatabase"
@ -86,7 +84,9 @@ def test_get_databases_several_items():
)
assert len(database_list) == 2
assert database_list[0]["Name"] == database_name_1
assert database_list[0]["CatalogId"] == ACCOUNT_ID
assert database_list[1]["Name"] == database_name_2
assert database_list[1]["CatalogId"] == ACCOUNT_ID
@mock_glue
@ -222,6 +222,7 @@ def test_get_tables():
table["StorageDescriptor"] == table_inputs[table_name]["StorageDescriptor"]
)
assert table["PartitionKeys"] == table_inputs[table_name]["PartitionKeys"]
assert table["CatalogId"] == ACCOUNT_ID
@mock_glue
@ -319,6 +320,7 @@ def test_get_table_versions():
table = client.get_table(DatabaseName=database_name, Name=table_name)["Table"]
assert table["StorageDescriptor"]["Columns"] == []
assert table["VersionId"] == "1"
assert table["CatalogId"] == ACCOUNT_ID
columns = [{"Name": "country", "Type": "string"}]
table_input = helpers.create_table_input(database_name, table_name, columns=columns)