From ad63e3966bbcb65dc1a7c9ef1c90b31c6d6100f4 Mon Sep 17 00:00:00 2001 From: MartinAltmayerTMH <141163644+MartinAltmayerTMH@users.noreply.github.com> Date: Fri, 9 Feb 2024 23:21:39 +0100 Subject: [PATCH] Support tags in Glue create_database (#7317) --- moto/glue/models.py | 9 +++++++-- moto/glue/responses.py | 2 +- tests/test_glue/helpers.py | 6 +++++- tests/test_glue/test_datacatalog.py | 17 +++++++++++++++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/moto/glue/models.py b/moto/glue/models.py index c350d8c32..74155237b 100644 --- a/moto/glue/models.py +++ b/moto/glue/models.py @@ -123,7 +123,10 @@ class GlueBackend(BaseBackend): ) def create_database( - self, database_name: str, database_input: Dict[str, Any] + self, + database_name: str, + database_input: Dict[str, Any], + tags: Optional[Dict[str, str]] = None, ) -> "FakeDatabase": if database_name in self.databases: raise DatabaseAlreadyExistsException() @@ -132,6 +135,8 @@ class GlueBackend(BaseBackend): database_name, database_input, catalog_id=self.account_id ) self.databases[database_name] = database + resource_arn = f"arn:aws:glue:{self.region_name}:{self.account_id}:database/{database_name}" + self.tag_resource(resource_arn, tags) return database def get_database(self, database_name: str) -> "FakeDatabase": @@ -429,7 +434,7 @@ class GlueBackend(BaseBackend): def get_tags(self, resource_id: str) -> Dict[str, str]: return self.tagger.get_tag_dict_for_resource(resource_id) - def tag_resource(self, resource_arn: str, tags: Dict[str, str]) -> None: + def tag_resource(self, resource_arn: str, tags: Optional[Dict[str, str]]) -> None: tag_list = TaggingService.convert_dict_to_tags_input(tags or {}) self.tagger.tag_resource(resource_arn, tag_list) diff --git a/moto/glue/responses.py b/moto/glue/responses.py index a47ac7012..78a23a0b8 100644 --- a/moto/glue/responses.py +++ b/moto/glue/responses.py @@ -31,7 +31,7 @@ class GlueResponse(BaseResponse): database_name = database_input.get("Name") # type: ignore if "CatalogId" in self.parameters: database_input["CatalogId"] = self.parameters.get("CatalogId") # type: ignore - self.glue_backend.create_database(database_name, database_input) # type: ignore[arg-type] + self.glue_backend.create_database(database_name, database_input, self.parameters.get("Tags")) # type: ignore[arg-type] return "" def get_database(self) -> str: diff --git a/tests/test_glue/helpers.py b/tests/test_glue/helpers.py index e900a5195..0ef161601 100644 --- a/tests/test_glue/helpers.py +++ b/tests/test_glue/helpers.py @@ -19,13 +19,17 @@ def create_database_input(database_name): return database_input -def create_database(client, database_name, database_input=None, catalog_id=None): +def create_database( + client, database_name, database_input=None, catalog_id=None, tags=None +): if database_input is None: database_input = create_database_input(database_name) database_kwargs = {"DatabaseInput": database_input} if catalog_id is not None: database_kwargs["CatalogId"] = catalog_id + if tags is not None: + database_kwargs["Tags"] = tags return client.create_database(**database_kwargs) diff --git a/tests/test_glue/test_datacatalog.py b/tests/test_glue/test_datacatalog.py index 739e44513..c3efd2c98 100644 --- a/tests/test_glue/test_datacatalog.py +++ b/tests/test_glue/test_datacatalog.py @@ -40,6 +40,23 @@ def test_create_database(): assert database.get("CatalogId") == database_catalog_id +@mock_aws +def test_create_database_with_tags(): + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + database_catalog_id = ACCOUNT_ID + database_input = helpers.create_database_input(database_name) + database_tags = {"key": "value"} + helpers.create_database( + client, database_name, database_input, database_catalog_id, tags=database_tags + ) + + response = client.get_tags( + ResourceArn=f"arn:aws:glue:us-east-1:{ACCOUNT_ID}:database/{database_name}" + ) + assert response["Tags"] == database_tags + + @mock_aws def test_create_database_already_exists(): client = boto3.client("glue", region_name="us-east-1")