From a9372c2fbce8b12419bfad203824fae68e980e4c Mon Sep 17 00:00:00 2001 From: tim Date: Wed, 27 Jul 2022 14:38:08 +0200 Subject: [PATCH] Support CatalogId in create_database (#5339) --- moto/glue/responses.py | 2 ++ tests/test_glue/helpers.py | 8 ++++++-- tests/test_glue/test_datacatalog.py | 6 ++++-- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/moto/glue/responses.py b/moto/glue/responses.py index ea71c7c60..8b9266876 100644 --- a/moto/glue/responses.py +++ b/moto/glue/responses.py @@ -21,6 +21,8 @@ class GlueResponse(BaseResponse): def create_database(self): database_input = self.parameters.get("DatabaseInput") database_name = database_input.get("Name") + if "CatalogId" in self.parameters: + database_input["CatalogId"] = self.parameters.get("CatalogId") self.glue_backend.create_database(database_name, database_input) return "" diff --git a/tests/test_glue/helpers.py b/tests/test_glue/helpers.py index 7293b99bf..2e7c2de34 100644 --- a/tests/test_glue/helpers.py +++ b/tests/test_glue/helpers.py @@ -12,10 +12,14 @@ def create_database_input(database_name): return database_input -def create_database(client, database_name, database_input=None): +def create_database(client, database_name, database_input=None, catalog_id=None): if database_input is None: database_input = create_database_input(database_name) - return client.create_database(DatabaseInput=database_input) + + database_kwargs = {"DatabaseInput": database_input} + if catalog_id is not None: + database_kwargs["CatalogId"] = catalog_id + return client.create_database(**database_kwargs) def get_database(client, database_name): diff --git a/tests/test_glue/test_datacatalog.py b/tests/test_glue/test_datacatalog.py index 81c8f13e2..d0ce39741 100644 --- a/tests/test_glue/test_datacatalog.py +++ b/tests/test_glue/test_datacatalog.py @@ -11,6 +11,7 @@ import pytz from freezegun import freeze_time from moto import mock_glue, settings +from moto.core import ACCOUNT_ID from . import helpers @@ -22,8 +23,9 @@ FROZEN_CREATE_TIME = datetime(2015, 1, 1, 0, 0, 0) def test_create_database(): client = boto3.client("glue", region_name="us-east-1") database_name = "myspecialdatabase" + database_catalog_id = ACCOUNT_ID database_input = helpers.create_database_input(database_name) - helpers.create_database(client, database_name, database_input) + helpers.create_database(client, database_name, database_input, database_catalog_id) response = helpers.get_database(client, database_name) database = response["Database"] @@ -38,7 +40,7 @@ def test_create_database(): database_input.get("CreateTableDefaultPermissions") ) database.get("TargetDatabase").should.equal(database_input.get("TargetDatabase")) - database.get("CatalogId").should.equal(database_input.get("CatalogId")) + database.get("CatalogId").should.equal(database_catalog_id) @mock_glue