Glue: Default CatalogId should be the AccountID (#6864)
This commit is contained in:
parent
5563e62f21
commit
06982582b7
@ -131,7 +131,9 @@ class GlueBackend(BaseBackend):
|
|||||||
if database_name in self.databases:
|
if database_name in self.databases:
|
||||||
raise DatabaseAlreadyExistsException()
|
raise DatabaseAlreadyExistsException()
|
||||||
|
|
||||||
database = FakeDatabase(database_name, database_input)
|
database = FakeDatabase(
|
||||||
|
database_name, database_input, catalog_id=self.account_id
|
||||||
|
)
|
||||||
self.databases[database_name] = database
|
self.databases[database_name] = database
|
||||||
return database
|
return database
|
||||||
|
|
||||||
@ -165,7 +167,9 @@ class GlueBackend(BaseBackend):
|
|||||||
if table_name in database.tables:
|
if table_name in database.tables:
|
||||||
raise TableAlreadyExistsException()
|
raise TableAlreadyExistsException()
|
||||||
|
|
||||||
table = FakeTable(database_name, table_name, table_input)
|
table = FakeTable(
|
||||||
|
database_name, table_name, table_input, catalog_id=self.account_id
|
||||||
|
)
|
||||||
database.tables[table_name] = table
|
database.tables[table_name] = table
|
||||||
return table
|
return table
|
||||||
|
|
||||||
@ -1041,9 +1045,12 @@ class GlueBackend(BaseBackend):
|
|||||||
|
|
||||||
|
|
||||||
class FakeDatabase(BaseModel):
|
class FakeDatabase(BaseModel):
|
||||||
def __init__(self, database_name: str, database_input: Dict[str, Any]):
|
def __init__(
|
||||||
|
self, database_name: str, database_input: Dict[str, Any], catalog_id: str
|
||||||
|
):
|
||||||
self.name = database_name
|
self.name = database_name
|
||||||
self.input = database_input
|
self.input = database_input
|
||||||
|
self.catalog_id = catalog_id
|
||||||
self.created_time = utcnow()
|
self.created_time = utcnow()
|
||||||
self.tables: Dict[str, FakeTable] = OrderedDict()
|
self.tables: Dict[str, FakeTable] = OrderedDict()
|
||||||
|
|
||||||
@ -1058,16 +1065,21 @@ class FakeDatabase(BaseModel):
|
|||||||
"CreateTableDefaultPermissions"
|
"CreateTableDefaultPermissions"
|
||||||
),
|
),
|
||||||
"TargetDatabase": self.input.get("TargetDatabase"),
|
"TargetDatabase": self.input.get("TargetDatabase"),
|
||||||
"CatalogId": self.input.get("CatalogId"),
|
"CatalogId": self.input.get("CatalogId") or self.catalog_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class FakeTable(BaseModel):
|
class FakeTable(BaseModel):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, database_name: str, table_name: str, table_input: Dict[str, Any]
|
self,
|
||||||
|
database_name: str,
|
||||||
|
table_name: str,
|
||||||
|
table_input: Dict[str, Any],
|
||||||
|
catalog_id: str,
|
||||||
):
|
):
|
||||||
self.database_name = database_name
|
self.database_name = database_name
|
||||||
self.name = table_name
|
self.name = table_name
|
||||||
|
self.catalog_id = catalog_id
|
||||||
self.partitions: Dict[str, FakePartition] = OrderedDict()
|
self.partitions: Dict[str, FakePartition] = OrderedDict()
|
||||||
self.created_time = utcnow()
|
self.created_time = utcnow()
|
||||||
self.updated_time: Optional[datetime] = None
|
self.updated_time: Optional[datetime] = None
|
||||||
@ -1104,6 +1116,7 @@ class FakeTable(BaseModel):
|
|||||||
**self.get_version(str(version)),
|
**self.get_version(str(version)),
|
||||||
# Add VersionId after we get the version-details, just to make sure that it's a valid version (int)
|
# Add VersionId after we get the version-details, just to make sure that it's a valid version (int)
|
||||||
"VersionId": str(version),
|
"VersionId": str(version),
|
||||||
|
"CatalogId": self.catalog_id,
|
||||||
}
|
}
|
||||||
if self.updated_time is not None:
|
if self.updated_time is not None:
|
||||||
obj["UpdateTime"] = unix_time(self.updated_time)
|
obj["UpdateTime"] = unix_time(self.updated_time)
|
||||||
|
@ -27,7 +27,8 @@ def test_create_database():
|
|||||||
response = helpers.get_database(client, database_name)
|
response = helpers.get_database(client, database_name)
|
||||||
database = response["Database"]
|
database = response["Database"]
|
||||||
|
|
||||||
assert database.get("Name") == database_name
|
assert database["Name"] == database_name
|
||||||
|
assert database["CatalogId"] == ACCOUNT_ID
|
||||||
assert database.get("Description") == database_input.get("Description")
|
assert database.get("Description") == database_input.get("Description")
|
||||||
assert database.get("LocationUri") == database_input.get("LocationUri")
|
assert database.get("LocationUri") == database_input.get("LocationUri")
|
||||||
assert database.get("Parameters") == database_input.get("Parameters")
|
assert database.get("Parameters") == database_input.get("Parameters")
|
||||||
@ -67,14 +68,11 @@ def test_get_database_not_exits():
|
|||||||
|
|
||||||
|
|
||||||
@mock_glue
|
@mock_glue
|
||||||
def test_get_databases_empty():
|
def test_get_databases():
|
||||||
client = boto3.client("glue", region_name="us-east-1")
|
client = boto3.client("glue", region_name="us-east-1")
|
||||||
response = client.get_databases()
|
response = client.get_databases()
|
||||||
assert len(response["DatabaseList"]) == 0
|
assert len(response["DatabaseList"]) == 0
|
||||||
|
|
||||||
|
|
||||||
@mock_glue
|
|
||||||
def test_get_databases_several_items():
|
|
||||||
client = boto3.client("glue", region_name="us-east-1")
|
client = boto3.client("glue", region_name="us-east-1")
|
||||||
database_name_1, database_name_2 = "firstdatabase", "seconddatabase"
|
database_name_1, database_name_2 = "firstdatabase", "seconddatabase"
|
||||||
|
|
||||||
@ -86,7 +84,9 @@ def test_get_databases_several_items():
|
|||||||
)
|
)
|
||||||
assert len(database_list) == 2
|
assert len(database_list) == 2
|
||||||
assert database_list[0]["Name"] == database_name_1
|
assert database_list[0]["Name"] == database_name_1
|
||||||
|
assert database_list[0]["CatalogId"] == ACCOUNT_ID
|
||||||
assert database_list[1]["Name"] == database_name_2
|
assert database_list[1]["Name"] == database_name_2
|
||||||
|
assert database_list[1]["CatalogId"] == ACCOUNT_ID
|
||||||
|
|
||||||
|
|
||||||
@mock_glue
|
@mock_glue
|
||||||
@ -222,6 +222,7 @@ def test_get_tables():
|
|||||||
table["StorageDescriptor"] == table_inputs[table_name]["StorageDescriptor"]
|
table["StorageDescriptor"] == table_inputs[table_name]["StorageDescriptor"]
|
||||||
)
|
)
|
||||||
assert table["PartitionKeys"] == table_inputs[table_name]["PartitionKeys"]
|
assert table["PartitionKeys"] == table_inputs[table_name]["PartitionKeys"]
|
||||||
|
assert table["CatalogId"] == ACCOUNT_ID
|
||||||
|
|
||||||
|
|
||||||
@mock_glue
|
@mock_glue
|
||||||
@ -319,6 +320,7 @@ def test_get_table_versions():
|
|||||||
table = client.get_table(DatabaseName=database_name, Name=table_name)["Table"]
|
table = client.get_table(DatabaseName=database_name, Name=table_name)["Table"]
|
||||||
assert table["StorageDescriptor"]["Columns"] == []
|
assert table["StorageDescriptor"]["Columns"] == []
|
||||||
assert table["VersionId"] == "1"
|
assert table["VersionId"] == "1"
|
||||||
|
assert table["CatalogId"] == ACCOUNT_ID
|
||||||
|
|
||||||
columns = [{"Name": "country", "Type": "string"}]
|
columns = [{"Name": "country", "Type": "string"}]
|
||||||
table_input = helpers.create_table_input(database_name, table_name, columns=columns)
|
table_input = helpers.create_table_input(database_name, table_name, columns=columns)
|
||||||
|
Loading…
Reference in New Issue
Block a user