Support tags in Glue create_database (#7317)
This commit is contained in:
parent
b98c17552d
commit
ad63e3966b
@ -123,7 +123,10 @@ class GlueBackend(BaseBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def create_database(
|
def create_database(
|
||||||
self, database_name: str, database_input: Dict[str, Any]
|
self,
|
||||||
|
database_name: str,
|
||||||
|
database_input: Dict[str, Any],
|
||||||
|
tags: Optional[Dict[str, str]] = None,
|
||||||
) -> "FakeDatabase":
|
) -> "FakeDatabase":
|
||||||
if database_name in self.databases:
|
if database_name in self.databases:
|
||||||
raise DatabaseAlreadyExistsException()
|
raise DatabaseAlreadyExistsException()
|
||||||
@ -132,6 +135,8 @@ class GlueBackend(BaseBackend):
|
|||||||
database_name, database_input, catalog_id=self.account_id
|
database_name, database_input, catalog_id=self.account_id
|
||||||
)
|
)
|
||||||
self.databases[database_name] = database
|
self.databases[database_name] = database
|
||||||
|
resource_arn = f"arn:aws:glue:{self.region_name}:{self.account_id}:database/{database_name}"
|
||||||
|
self.tag_resource(resource_arn, tags)
|
||||||
return database
|
return database
|
||||||
|
|
||||||
def get_database(self, database_name: str) -> "FakeDatabase":
|
def get_database(self, database_name: str) -> "FakeDatabase":
|
||||||
@ -429,7 +434,7 @@ class GlueBackend(BaseBackend):
|
|||||||
def get_tags(self, resource_id: str) -> Dict[str, str]:
|
def get_tags(self, resource_id: str) -> Dict[str, str]:
|
||||||
return self.tagger.get_tag_dict_for_resource(resource_id)
|
return self.tagger.get_tag_dict_for_resource(resource_id)
|
||||||
|
|
||||||
def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None:
|
def tag_resource(self, resource_arn: str, tags: Optional[Dict[str, str]]) -> None:
|
||||||
tag_list = TaggingService.convert_dict_to_tags_input(tags or {})
|
tag_list = TaggingService.convert_dict_to_tags_input(tags or {})
|
||||||
self.tagger.tag_resource(resource_arn, tag_list)
|
self.tagger.tag_resource(resource_arn, tag_list)
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ class GlueResponse(BaseResponse):
|
|||||||
database_name = database_input.get("Name") # type: ignore
|
database_name = database_input.get("Name") # type: ignore
|
||||||
if "CatalogId" in self.parameters:
|
if "CatalogId" in self.parameters:
|
||||||
database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore
|
database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore
|
||||||
self.glue_backend.create_database(database_name, database_input) # type: ignore[arg-type]
|
self.glue_backend.create_database(database_name, database_input, self.parameters.get("Tags")) # type: ignore[arg-type]
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def get_database(self) -> str:
|
def get_database(self) -> str:
|
||||||
|
@ -19,13 +19,17 @@ def create_database_input(database_name):
|
|||||||
return database_input
|
return database_input
|
||||||
|
|
||||||
|
|
||||||
def create_database(client, database_name, database_input=None, catalog_id=None):
|
def create_database(
|
||||||
|
client, database_name, database_input=None, catalog_id=None, tags=None
|
||||||
|
):
|
||||||
if database_input is None:
|
if database_input is None:
|
||||||
database_input = create_database_input(database_name)
|
database_input = create_database_input(database_name)
|
||||||
|
|
||||||
database_kwargs = {"DatabaseInput": database_input}
|
database_kwargs = {"DatabaseInput": database_input}
|
||||||
if catalog_id is not None:
|
if catalog_id is not None:
|
||||||
database_kwargs["CatalogId"] = catalog_id
|
database_kwargs["CatalogId"] = catalog_id
|
||||||
|
if tags is not None:
|
||||||
|
database_kwargs["Tags"] = tags
|
||||||
return client.create_database(**database_kwargs)
|
return client.create_database(**database_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,6 +40,23 @@ def test_create_database():
|
|||||||
assert database.get("CatalogId") == database_catalog_id
|
assert database.get("CatalogId") == database_catalog_id
|
||||||
|
|
||||||
|
|
||||||
|
@mock_aws
|
||||||
|
def test_create_database_with_tags():
|
||||||
|
client = boto3.client("glue", region_name="us-east-1")
|
||||||
|
database_name = "myspecialdatabase"
|
||||||
|
database_catalog_id = ACCOUNT_ID
|
||||||
|
database_input = helpers.create_database_input(database_name)
|
||||||
|
database_tags = {"key": "value"}
|
||||||
|
helpers.create_database(
|
||||||
|
client, database_name, database_input, database_catalog_id, tags=database_tags
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get_tags(
|
||||||
|
ResourceArn=f"arn:aws:glue:us-east-1:{ACCOUNT_ID}:database/{database_name}"
|
||||||
|
)
|
||||||
|
assert response["Tags"] == database_tags
|
||||||
|
|
||||||
|
|
||||||
@mock_aws
|
@mock_aws
|
||||||
def test_create_database_already_exists():
|
def test_create_database_already_exists():
|
||||||
client = boto3.client("glue", region_name="us-east-1")
|
client = boto3.client("glue", region_name="us-east-1")
|
||||||
|
Loading…
Reference in New Issue
Block a user