Glue: Default CatalogId should be the AccountID (#6864)
This commit is contained in:
parent
5563e62f21
commit
06982582b7
@ -131,7 +131,9 @@ class GlueBackend(BaseBackend):
|
||||
if database_name in self.databases:
|
||||
raise DatabaseAlreadyExistsException()
|
||||
|
||||
database = FakeDatabase(database_name, database_input)
|
||||
database = FakeDatabase(
|
||||
database_name, database_input, catalog_id=self.account_id
|
||||
)
|
||||
self.databases[database_name] = database
|
||||
return database
|
||||
|
||||
@ -165,7 +167,9 @@ class GlueBackend(BaseBackend):
|
||||
if table_name in database.tables:
|
||||
raise TableAlreadyExistsException()
|
||||
|
||||
table = FakeTable(database_name, table_name, table_input)
|
||||
table = FakeTable(
|
||||
database_name, table_name, table_input, catalog_id=self.account_id
|
||||
)
|
||||
database.tables[table_name] = table
|
||||
return table
|
||||
|
||||
@ -1041,9 +1045,12 @@ class GlueBackend(BaseBackend):
|
||||
|
||||
|
||||
class FakeDatabase(BaseModel):
|
||||
def __init__(self, database_name: str, database_input: Dict[str, Any]):
|
||||
def __init__(
|
||||
self, database_name: str, database_input: Dict[str, Any], catalog_id: str
|
||||
):
|
||||
self.name = database_name
|
||||
self.input = database_input
|
||||
self.catalog_id = catalog_id
|
||||
self.created_time = utcnow()
|
||||
self.tables: Dict[str, FakeTable] = OrderedDict()
|
||||
|
||||
@ -1058,16 +1065,21 @@ class FakeDatabase(BaseModel):
|
||||
"CreateTableDefaultPermissions"
|
||||
),
|
||||
"TargetDatabase": self.input.get("TargetDatabase"),
|
||||
"CatalogId": self.input.get("CatalogId"),
|
||||
"CatalogId": self.input.get("CatalogId") or self.catalog_id,
|
||||
}
|
||||
|
||||
|
||||
class FakeTable(BaseModel):
|
||||
def __init__(
|
||||
self, database_name: str, table_name: str, table_input: Dict[str, Any]
|
||||
self,
|
||||
database_name: str,
|
||||
table_name: str,
|
||||
table_input: Dict[str, Any],
|
||||
catalog_id: str,
|
||||
):
|
||||
self.database_name = database_name
|
||||
self.name = table_name
|
||||
self.catalog_id = catalog_id
|
||||
self.partitions: Dict[str, FakePartition] = OrderedDict()
|
||||
self.created_time = utcnow()
|
||||
self.updated_time: Optional[datetime] = None
|
||||
@ -1104,6 +1116,7 @@ class FakeTable(BaseModel):
|
||||
**self.get_version(str(version)),
|
||||
# Add VersionId after we get the version-details, just to make sure that it's a valid version (int)
|
||||
"VersionId": str(version),
|
||||
"CatalogId": self.catalog_id,
|
||||
}
|
||||
if self.updated_time is not None:
|
||||
obj["UpdateTime"] = unix_time(self.updated_time)
|
||||
|
@ -27,7 +27,8 @@ def test_create_database():
|
||||
response = helpers.get_database(client, database_name)
|
||||
database = response["Database"]
|
||||
|
||||
assert database.get("Name") == database_name
|
||||
assert database["Name"] == database_name
|
||||
assert database["CatalogId"] == ACCOUNT_ID
|
||||
assert database.get("Description") == database_input.get("Description")
|
||||
assert database.get("LocationUri") == database_input.get("LocationUri")
|
||||
assert database.get("Parameters") == database_input.get("Parameters")
|
||||
@ -67,14 +68,11 @@ def test_get_database_not_exits():
|
||||
|
||||
|
||||
@mock_glue
|
||||
def test_get_databases_empty():
|
||||
def test_get_databases():
|
||||
client = boto3.client("glue", region_name="us-east-1")
|
||||
response = client.get_databases()
|
||||
assert len(response["DatabaseList"]) == 0
|
||||
|
||||
|
||||
@mock_glue
|
||||
def test_get_databases_several_items():
|
||||
client = boto3.client("glue", region_name="us-east-1")
|
||||
database_name_1, database_name_2 = "firstdatabase", "seconddatabase"
|
||||
|
||||
@ -86,7 +84,9 @@ def test_get_databases_several_items():
|
||||
)
|
||||
assert len(database_list) == 2
|
||||
assert database_list[0]["Name"] == database_name_1
|
||||
assert database_list[0]["CatalogId"] == ACCOUNT_ID
|
||||
assert database_list[1]["Name"] == database_name_2
|
||||
assert database_list[1]["CatalogId"] == ACCOUNT_ID
|
||||
|
||||
|
||||
@mock_glue
|
||||
@ -222,6 +222,7 @@ def test_get_tables():
|
||||
table["StorageDescriptor"] == table_inputs[table_name]["StorageDescriptor"]
|
||||
)
|
||||
assert table["PartitionKeys"] == table_inputs[table_name]["PartitionKeys"]
|
||||
assert table["CatalogId"] == ACCOUNT_ID
|
||||
|
||||
|
||||
@mock_glue
|
||||
@ -319,6 +320,7 @@ def test_get_table_versions():
|
||||
table = client.get_table(DatabaseName=database_name, Name=table_name)["Table"]
|
||||
assert table["StorageDescriptor"]["Columns"] == []
|
||||
assert table["VersionId"] == "1"
|
||||
assert table["CatalogId"] == ACCOUNT_ID
|
||||
|
||||
columns = [{"Name": "country", "Type": "string"}]
|
||||
table_input = helpers.create_table_input(database_name, table_name, columns=columns)
|
||||
|
Loading…
Reference in New Issue
Block a user