moto/tests/test_glue/helpers.py

196 lines
6.4 KiB
Python

import copy
from .fixtures.datacatalog import DATABASE_INPUT, PARTITION_INPUT, TABLE_INPUT
from .fixtures.schema_registry import (
TEST_AVRO_DATA_FORMAT,
TEST_AVRO_SCHEMA_DEFINITION,
TEST_BACKWARD_COMPATIBILITY,
TEST_NEW_AVRO_SCHEMA_DEFINITION,
TEST_REGISTRY_NAME,
TEST_SCHEMA_ID,
TEST_SCHEMA_NAME,
)
def create_database_input(database_name):
    """Build a ``DatabaseInput`` payload named *database_name*.

    The payload is a deep copy of the ``DATABASE_INPUT`` fixture, so the
    shared template is never mutated. ``Name`` and ``LocationUri`` are
    overwritten with values derived from *database_name*.
    """
    payload = copy.deepcopy(DATABASE_INPUT)
    payload.update(
        {
            "Name": database_name,
            "LocationUri": f"s3://my-bucket/{database_name}",
        }
    )
    return payload
def create_database(
    client, database_name, database_input=None, catalog_id=None, tags=None
):
    """Call ``client.create_database`` and return its response.

    When *database_input* is ``None`` the payload is built with
    ``create_database_input(database_name)``. ``CatalogId`` and ``Tags`` are
    only included in the request when the matching argument is not ``None``.
    """
    if database_input is None:
        database_input = create_database_input(database_name)

    request = {"DatabaseInput": database_input}
    optional = {"CatalogId": catalog_id, "Tags": tags}
    request.update(
        {key: value for key, value in optional.items() if value is not None}
    )
    return client.create_database(**request)
def get_database(client, database_name):
    """Return ``client.get_database`` for the database *database_name*."""
    response = client.get_database(Name=database_name)
    return response
def create_table_input(database_name, table_name, columns=None, partition_keys=None):
    """Build a ``TableInput`` payload for *table_name*.

    Starts from a deep copy of the ``TABLE_INPUT`` fixture, then sets the
    table name, partition keys, columns and an S3 location derived from the
    database and table names. Falsy *columns* / *partition_keys* become an
    empty list.
    """
    payload = copy.deepcopy(TABLE_INPUT)
    payload["Name"] = table_name
    payload["PartitionKeys"] = partition_keys if partition_keys else []

    storage = payload["StorageDescriptor"]
    storage["Columns"] = columns if columns else []
    storage["Location"] = f"s3://my-bucket/{database_name}/{table_name}"
    return payload
def create_table(client, database_name, table_name, table_input=None, **kwargs):
    """Call ``client.create_table`` and return its response.

    When *table_input* is ``None`` the payload is built with
    ``create_table_input``, forwarding ``**kwargs`` to it; otherwise
    ``**kwargs`` is unused.
    """
    payload = (
        create_table_input(database_name, table_name, **kwargs)
        if table_input is None
        else table_input
    )
    return client.create_table(DatabaseName=database_name, TableInput=payload)
def update_table(client, database_name, table_name, table_input=None, **kwargs):
    """Call ``client.update_table`` and return its response.

    When *table_input* is ``None`` the payload is built with
    ``create_table_input``, forwarding ``**kwargs`` to it; otherwise
    ``**kwargs`` is unused.
    """
    payload = (
        create_table_input(database_name, table_name, **kwargs)
        if table_input is None
        else table_input
    )
    return client.update_table(DatabaseName=database_name, TableInput=payload)
def get_table(client, database_name, table_name):
    """Return ``client.get_table`` for *table_name* in *database_name*."""
    request = {"DatabaseName": database_name, "Name": table_name}
    return client.get_table(**request)
def get_tables(client, database_name, expression=None):
    """Return ``client.get_tables`` for *database_name*.

    ``Expression`` is only sent when *expression* is truthy, so ``None`` and
    the empty string both produce an unfiltered request.
    """
    request = {"DatabaseName": database_name}
    if expression:
        request["Expression"] = expression
    return client.get_tables(**request)
def get_table_versions(client, database_name, table_name):
    """Return ``client.get_table_versions`` for *table_name* in *database_name*."""
    request = {"DatabaseName": database_name, "TableName": table_name}
    return client.get_table_versions(**request)
def get_table_version(client, database_name, table_name, version_id):
    """Return ``client.get_table_version`` for one version of a table."""
    request = {
        "DatabaseName": database_name,
        "TableName": table_name,
        "VersionId": version_id,
    }
    return client.get_table_version(**request)
def create_column(name, type_, comment=None, parameters=None):
    """Return a column dict with ``Name`` and ``Type``.

    ``Comment`` and ``Parameters`` are added, in that order, only when the
    matching argument is not ``None``.
    """
    column = {"Name": name, "Type": type_}
    extras = (("Comment", comment), ("Parameters", parameters))
    for key, value in extras:
        if value is not None:
            column[key] = value
    return column
def create_partition_input(database_name, table_name, values=None, columns=None):
    """Build a ``PartitionInput`` payload for a table partition.

    Starts from a deep copy of the ``PARTITION_INPUT`` fixture, then sets the
    partition values, the storage-descriptor columns and the serde ``path``
    parameter (the table's S3 root). Falsy *values* / *columns* become an
    empty list.
    """
    table_root = f"s3://my-bucket/{database_name}/{table_name}"
    payload = copy.deepcopy(PARTITION_INPUT)
    payload["Values"] = values if values else []

    descriptor = payload["StorageDescriptor"]
    descriptor["Columns"] = columns if columns else []
    descriptor["SerdeInfo"]["Parameters"]["path"] = table_root
    return payload
def create_partition(client, database_name, table_name, partiton_input=None, **kwargs):
    """Call ``client.create_partition`` and return its response.

    When *partiton_input* (spelling kept for keyword callers) is ``None`` the
    payload is built with ``create_partition_input``, forwarding ``**kwargs``
    to it; otherwise ``**kwargs`` is unused.
    """
    payload = partiton_input
    if payload is None:
        payload = create_partition_input(database_name, table_name, **kwargs)

    request = {
        "DatabaseName": database_name,
        "TableName": table_name,
        "PartitionInput": payload,
    }
    return client.create_partition(**request)
def update_partition(
    client, database_name, table_name, old_values=None, partiton_input=None, **kwargs
):
    """Call ``client.update_partition`` and return its response.

    *old_values* is sent as ``PartitionValueList`` (an empty list when
    falsy). When *partiton_input* (spelling kept for keyword callers) is
    ``None`` the new payload is built with ``create_partition_input``,
    forwarding ``**kwargs`` to it.
    """
    payload = partiton_input
    if payload is None:
        payload = create_partition_input(database_name, table_name, **kwargs)

    request = {
        "DatabaseName": database_name,
        "TableName": table_name,
        "PartitionInput": payload,
        "PartitionValueList": old_values if old_values else [],
    }
    return client.update_partition(**request)
def get_partition(client, database_name, table_name, values):
    """Return ``client.get_partition`` for the partition identified by *values*."""
    request = {
        "DatabaseName": database_name,
        "TableName": table_name,
        "PartitionValues": values,
    }
    return client.get_partition(**request)
def create_crawler(
    client, crawler_name, crawler_role=None, crawler_targets=None, **kwargs
):
    """Call ``client.create_crawler`` and return its response.

    *crawler_role* defaults to a fixed IAM role ARN and *crawler_targets* to
    a dict of empty target lists. Recognised snake_case keyword arguments
    are translated to their boto3 parameter names and forwarded when their
    value is not ``None``; any other keyword argument is ignored.
    """
    # (helper keyword, boto3 parameter) pairs, in the order they are forwarded.
    optional_params = (
        ("database_name", "DatabaseName"),
        ("description", "Description"),
        ("schedule", "Schedule"),
        ("classifiers", "Classifiers"),
        ("table_prefix", "TablePrefix"),
        ("schema_change_policy", "SchemaChangePolicy"),
        ("recrawl_policy", "RecrawlPolicy"),
        ("lineage_configuration", "LineageConfiguration"),
        ("configuration", "Configuration"),
        ("crawler_security_configuration", "CrawlerSecurityConfiguration"),
        ("tags", "Tags"),
    )
    params = {}
    for key, boto3_key in optional_params:
        value = kwargs.get(key)
        if value is not None:
            params[boto3_key] = value

    role = (
        "arn:aws:iam::123456789012:role/Glue/Role"
        if crawler_role is None
        else crawler_role
    )
    if crawler_targets is None:
        targets = {
            "S3Targets": [],
            "JdbcTargets": [],
            "MongoDBTargets": [],
            "DynamoDBTargets": [],
            "CatalogTargets": [],
        }
    else:
        targets = crawler_targets

    return client.create_crawler(
        Name=crawler_name, Role=role, Targets=targets, **params
    )
def create_registry(client, registry_name=TEST_REGISTRY_NAME):
    """Call ``client.create_registry`` for *registry_name* and return the response.

    *registry_name* defaults to the ``TEST_REGISTRY_NAME`` fixture constant.
    """
    response = client.create_registry(RegistryName=registry_name)
    return response
def create_schema(
    client,
    registry_id,
    data_format=TEST_AVRO_DATA_FORMAT,
    compatibility=TEST_BACKWARD_COMPATIBILITY,
    schema_definition=TEST_AVRO_SCHEMA_DEFINITION,
    schema_name=TEST_SCHEMA_NAME,
):
    """Call ``client.create_schema`` and return its response.

    Args:
        client: Object exposing a ``create_schema`` method that accepts the
            keyword arguments used below.
        registry_id: Value sent as ``RegistryId``.
        data_format: Value sent as ``DataFormat``.
        compatibility: Value sent as ``Compatibility``.
        schema_definition: Value sent as ``SchemaDefinition``.
        schema_name: Value sent as ``SchemaName``. Defaults to
            ``TEST_SCHEMA_NAME``, which was previously hard-coded, so
            existing callers are unaffected.
    """
    return client.create_schema(
        RegistryId=registry_id,
        SchemaName=schema_name,
        DataFormat=data_format,
        Compatibility=compatibility,
        SchemaDefinition=schema_definition,
    )
def register_schema_version(client, schema_id=None, schema_definition=None):
    """Call ``client.register_schema_version`` and return its response.

    Args:
        client: Object exposing a ``register_schema_version`` method.
        schema_id: Value sent as ``SchemaId``. ``None`` (the default) uses
            the ``TEST_SCHEMA_ID`` fixture constant.
        schema_definition: Value sent as ``SchemaDefinition``. ``None`` (the
            default) uses ``TEST_NEW_AVRO_SCHEMA_DEFINITION``.

    Calling with only *client* sends exactly the request the helper sent
    before these parameters were added.
    """
    # Resolve the fallbacks at call time, as the hard-coded version did.
    if schema_id is None:
        schema_id = TEST_SCHEMA_ID
    if schema_definition is None:
        schema_definition = TEST_NEW_AVRO_SCHEMA_DEFINITION
    return client.register_schema_version(
        SchemaId=schema_id, SchemaDefinition=schema_definition
    )