Support tags in Glue create_database (#7317)

This commit is contained in:
MartinAltmayerTMH 2024-02-09 23:21:39 +01:00 committed by GitHub
parent b98c17552d
commit ad63e3966b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 30 additions and 4 deletions

View File

@ -123,7 +123,10 @@ class GlueBackend(BaseBackend):
) )
def create_database( def create_database(
self, database_name: str, database_input: Dict[str, Any] self,
database_name: str,
database_input: Dict[str, Any],
tags: Optional[Dict[str, str]] = None,
) -> "FakeDatabase": ) -> "FakeDatabase":
if database_name in self.databases: if database_name in self.databases:
raise DatabaseAlreadyExistsException() raise DatabaseAlreadyExistsException()
@ -132,6 +135,8 @@ class GlueBackend(BaseBackend):
database_name, database_input, catalog_id=self.account_id database_name, database_input, catalog_id=self.account_id
) )
self.databases[database_name] = database self.databases[database_name] = database
resource_arn = f"arn:aws:glue:{self.region_name}:{self.account_id}:database/{database_name}"
self.tag_resource(resource_arn, tags)
return database return database
def get_database(self, database_name: str) -> "FakeDatabase": def get_database(self, database_name: str) -> "FakeDatabase":
@ -429,7 +434,7 @@ class GlueBackend(BaseBackend):
def get_tags(self, resource_id: str) -> Dict[str, str]: def get_tags(self, resource_id: str) -> Dict[str, str]:
return self.tagger.get_tag_dict_for_resource(resource_id) return self.tagger.get_tag_dict_for_resource(resource_id)
def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None: def tag_resource(self, resource_arn: str, tags: Optional[Dict[str, str]]) -> None:
tag_list = TaggingService.convert_dict_to_tags_input(tags or {}) tag_list = TaggingService.convert_dict_to_tags_input(tags or {})
self.tagger.tag_resource(resource_arn, tag_list) self.tagger.tag_resource(resource_arn, tag_list)

View File

@ -31,7 +31,7 @@ class GlueResponse(BaseResponse):
database_name = database_input.get("Name") # type: ignore database_name = database_input.get("Name") # type: ignore
if "CatalogId" in self.parameters: if "CatalogId" in self.parameters:
database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore
self.glue_backend.create_database(database_name, database_input) # type: ignore[arg-type] self.glue_backend.create_database(database_name, database_input, self.parameters.get("Tags")) # type: ignore[arg-type]
return "" return ""
def get_database(self) -> str: def get_database(self) -> str:

View File

@ -19,13 +19,17 @@ def create_database_input(database_name):
return database_input return database_input
def create_database(client, database_name, database_input=None, catalog_id=None): def create_database(
client, database_name, database_input=None, catalog_id=None, tags=None
):
if database_input is None: if database_input is None:
database_input = create_database_input(database_name) database_input = create_database_input(database_name)
database_kwargs = {"DatabaseInput": database_input} database_kwargs = {"DatabaseInput": database_input}
if catalog_id is not None: if catalog_id is not None:
database_kwargs["CatalogId"] = catalog_id database_kwargs["CatalogId"] = catalog_id
if tags is not None:
database_kwargs["Tags"] = tags
return client.create_database(**database_kwargs) return client.create_database(**database_kwargs)

View File

@ -40,6 +40,23 @@ def test_create_database():
assert database.get("CatalogId") == database_catalog_id assert database.get("CatalogId") == database_catalog_id
@mock_aws
def test_create_database_with_tags():
client = boto3.client("glue", region_name="us-east-1")
database_name = "myspecialdatabase"
database_catalog_id = ACCOUNT_ID
database_input = helpers.create_database_input(database_name)
database_tags = {"key": "value"}
helpers.create_database(
client, database_name, database_input, database_catalog_id, tags=database_tags
)
response = client.get_tags(
ResourceArn=f"arn:aws:glue:us-east-1:{ACCOUNT_ID}:database/{database_name}"
)
assert response["Tags"] == database_tags
@mock_aws @mock_aws
def test_create_database_already_exists(): def test_create_database_already_exists():
client = boto3.client("glue", region_name="us-east-1") client = boto3.client("glue", region_name="us-east-1")