From a2246df1a47d9c98d6139b7179cf01c8cbb44585 Mon Sep 17 00:00:00 2001 From: aaronfowles-emis <69793448+aaronfowles-emis@users.noreply.github.com> Date: Mon, 11 Jan 2021 13:10:18 +0000 Subject: [PATCH] Implement full Database object for Glue get_database() - fix for #3571. (#3572) * implement potential fix for #3571. * freeze_time decorator not used in TEST_SERVER_MODE --- moto/glue/models.py | 23 ++++++++++++++++--- moto/glue/responses.py | 9 ++++---- tests/test_glue/fixtures/datacatalog.py | 13 +++++++++++ tests/test_glue/helpers.py | 17 +++++++++++--- tests/test_glue/test_datacatalog.py | 30 +++++++++++++++++++------ 5 files changed, 75 insertions(+), 17 deletions(-) diff --git a/moto/glue/models.py b/moto/glue/models.py index cf930cfb2..a434628b7 100644 --- a/moto/glue/models.py +++ b/moto/glue/models.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals import time +from datetime import datetime from moto.core import BaseBackend, BaseModel from moto.compat import OrderedDict @@ -20,11 +21,11 @@ class GlueBackend(BaseBackend): def __init__(self): self.databases = OrderedDict() - def create_database(self, database_name): + def create_database(self, database_name, database_input): if database_name in self.databases: raise DatabaseAlreadyExistsException() - database = FakeDatabase(database_name) + database = FakeDatabase(database_name, database_input) self.databases[database_name] = database return database @@ -68,10 +69,26 @@ class GlueBackend(BaseBackend): class FakeDatabase(BaseModel): - def __init__(self, database_name): + def __init__(self, database_name, database_input): self.name = database_name + self.input = database_input + self.created_time = datetime.utcnow() self.tables = OrderedDict() + def as_dict(self): + return { + "Name": self.name, + "Description": self.input.get("Description"), + "LocationUri": self.input.get("LocationUri"), + "Parameters": self.input.get("Parameters"), + "CreateTime": self.created_time.isoformat(), + "CreateTableDefaultPermissions": self.input.get( + "CreateTableDefaultPermissions" + ), + "TargetDatabase": self.input.get("TargetDatabase"), + "CatalogId": self.input.get("CatalogId"), + } + class FakeTable(BaseModel): def __init__(self, database_name, table_name, table_input): diff --git a/moto/glue/responses.py b/moto/glue/responses.py index ba9cee8fc..e3ec08dee 100644 --- a/moto/glue/responses.py +++ b/moto/glue/responses.py @@ -21,19 +21,20 @@ class GlueResponse(BaseResponse): return json.loads(self.body) def create_database(self): - database_name = self.parameters["DatabaseInput"]["Name"] - self.glue_backend.create_database(database_name) + database_input = self.parameters.get("DatabaseInput") + database_name = database_input.get("Name") + self.glue_backend.create_database(database_name, database_input) return "" def get_database(self): database_name = self.parameters.get("Name") database = self.glue_backend.get_database(database_name) - return json.dumps({"Database": {"Name": database.name}}) + return json.dumps({"Database": database.as_dict()}) def get_databases(self): database_list = self.glue_backend.get_databases() return json.dumps( - {"DatabaseList": [{"Name": database.name} for database in database_list]} + {"DatabaseList": [database.as_dict() for database in database_list]} ) def create_table(self): diff --git a/tests/test_glue/fixtures/datacatalog.py b/tests/test_glue/fixtures/datacatalog.py index 11cb30ca9..9e5d8d229 100644 --- a/tests/test_glue/fixtures/datacatalog.py +++ b/tests/test_glue/fixtures/datacatalog.py @@ -53,3 +53,16 @@ PARTITION_INPUT = { # 'TableName': 'source_table', # 'Values': ['2018-06-26'], } + +DATABASE_INPUT = { + "Name": "testdatabase", + "Description": "a testdatabase", + "LocationUri": "", + "Parameters": {}, + "CreateTableDefaultPermissions": [ + { + "Principal": {"DataLakePrincipalIdentifier": "a_fake_owner"}, + "Permissions": ["ALL"], + }, + ], +} diff --git a/tests/test_glue/helpers.py b/tests/test_glue/helpers.py index 9003a1358..b0a602c75 100644 --- a/tests/test_glue/helpers.py +++ b/tests/test_glue/helpers.py @@ -2,11 +2,22 @@ from __future__ import unicode_literals import copy -from .fixtures.datacatalog import TABLE_INPUT, PARTITION_INPUT +from .fixtures.datacatalog import TABLE_INPUT, PARTITION_INPUT, DATABASE_INPUT -def create_database(client, database_name): - return client.create_database(DatabaseInput={"Name": database_name}) +def create_database_input(database_name): + database_input = copy.deepcopy(DATABASE_INPUT) + database_input["Name"] = database_name + database_input["LocationUri"] = "s3://my-bucket/{database_name}".format( + database_name=database_name + ) + return database_input + + +def create_database(client, database_name, database_input=None): + if database_input is None: + database_input = create_database_input(database_name) + return client.create_database(DatabaseInput=database_input) def get_database(client, database_name): diff --git a/tests/test_glue/test_datacatalog.py b/tests/test_glue/test_datacatalog.py index 46ef910b7..62b5cc443 100644 --- a/tests/test_glue/test_datacatalog.py +++ b/tests/test_glue/test_datacatalog.py @@ -9,21 +9,37 @@ from botocore.client import ClientError from datetime import datetime import pytz +from freezegun import freeze_time -from moto import mock_glue +from moto import mock_glue, settings from . import helpers +FROZEN_CREATE_TIME = datetime(2015, 1, 1, 0, 0, 0) + + @mock_glue +@freeze_time(FROZEN_CREATE_TIME) def test_create_database(): client = boto3.client("glue", region_name="us-east-1") database_name = "myspecialdatabase" - helpers.create_database(client, database_name) + database_input = helpers.create_database_input(database_name) + helpers.create_database(client, database_name, database_input) response = helpers.get_database(client, database_name) database = response["Database"] - database.should.equal({"Name": database_name}) + database.get("Name").should.equal(database_name) + database.get("Description").should.equal(database_input.get("Description")) + database.get("LocationUri").should.equal(database_input.get("LocationUri")) + database.get("Parameters").should.equal(database_input.get("Parameters")) + if not settings.TEST_SERVER_MODE: + database.get("CreateTime").should.equal(FROZEN_CREATE_TIME) + database.get("CreateTableDefaultPermissions").should.equal( + database_input.get("CreateTableDefaultPermissions") + ) + database.get("TargetDatabase").should.equal(database_input.get("TargetDatabase")) + database.get("CatalogId").should.equal(database_input.get("CatalogId")) @mock_glue @@ -64,15 +80,15 @@ def test_get_databases_several_items(): client = boto3.client("glue", region_name="us-east-1") database_name_1, database_name_2 = "firstdatabase", "seconddatabase" - helpers.create_database(client, database_name_1) - helpers.create_database(client, database_name_2) + helpers.create_database(client, database_name_1, {"Name": database_name_1}) + helpers.create_database(client, database_name_2, {"Name": database_name_2}) database_list = sorted( client.get_databases()["DatabaseList"], key=lambda x: x["Name"] ) database_list.should.have.length_of(2) - database_list[0].should.equal({"Name": database_name_1}) - database_list[1].should.equal({"Name": database_name_2}) + database_list[0]["Name"].should.equal(database_name_1) + database_list[1]["Name"].should.equal(database_name_2) @mock_glue