430 lines
14 KiB
Python
430 lines
14 KiB
Python
import boto3
|
|
import pytest
|
|
from botocore.exceptions import ClientError
|
|
|
|
from moto import mock_aws, settings
|
|
from moto.athena.models import QueryResults, athena_backends
|
|
from moto.core import DEFAULT_ACCOUNT_ID
|
|
|
|
|
|
@mock_aws
def test_create_work_group():
    """Create a work group, reject a duplicate name, and find it in the listing."""
    client = boto3.client("athena", region_name="us-east-1")

    # Both create calls use the same configuration; only the description differs.
    configuration = {
        "ResultConfiguration": {
            "OutputLocation": "s3://bucket-name/prefix/",
            "EncryptionConfiguration": {
                "EncryptionOption": "SSE_KMS",
                "KmsKey": "aws:arn:kms:1233456789:us-east-1:key/number-1",
            },
        }
    }

    client.create_work_group(
        Name="athena_workgroup",
        Description="Test work group",
        Configuration=configuration,
    )

    # A second work group with the same name must be refused.
    with pytest.raises(ClientError) as exc:
        client.create_work_group(
            Name="athena_workgroup",
            Description="duplicate",
            Configuration=configuration,
        )
    error = exc.value.response["Error"]
    assert error["Code"] == "InvalidRequestException"
    assert error["Message"] == "WorkGroup already exists"

    # Ignoring the built-in "primary" group, only the new one is listed,
    # and it still carries the description from the first (successful) call.
    listed = client.list_work_groups()["WorkGroups"]
    custom_groups = [group for group in listed if group["Name"] != "primary"]
    assert len(custom_groups) == 1

    created = custom_groups[0]
    assert created["Name"] == "athena_workgroup"
    assert created["Description"] == "Test work group"
    assert created["State"] == "ENABLED"
|
|
|
|
|
|
@mock_aws
def test_get_primary_workgroup():
    """A fresh account exposes exactly one work group, "primary", with no configuration."""
    client = boto3.client("athena", region_name="us-east-1")

    existing_groups = client.list_work_groups()["WorkGroups"]
    assert len(existing_groups) == 1

    response = client.get_work_group(WorkGroup="primary")
    primary = response["WorkGroup"]
    assert primary["Name"] == "primary"
    assert primary["Configuration"] == {}
|
|
|
|
|
|
@mock_aws
def test_create_and_get_workgroup():
    """get_work_group returns the fields supplied by create_basic_workgroup."""
    client = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=client, name="athena_workgroup")

    response = client.get_work_group(WorkGroup="athena_workgroup")
    fetched = response["WorkGroup"]
    # CreationTime is not asserted yet, so remove it before the full comparison.
    fetched.pop("CreationTime")

    expected = {
        "Name": "athena_workgroup",
        "State": "ENABLED",
        "Configuration": {
            "ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/"}
        },
        "Description": "Test work group",
    }
    assert fetched == expected
|
|
|
|
|
|
@mock_aws
def test_start_query_execution():
    """Each started query execution receives its own, distinct id."""
    client = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=client, name="athena_workgroup")

    first = client.start_query_execution(
        QueryString="query1",
        QueryExecutionContext={"Database": "string"},
        ResultConfiguration={"OutputLocation": "string"},
        WorkGroup="athena_workgroup",
    )
    assert "QueryExecutionId" in first

    # The second query is started without an explicit WorkGroup.
    second = client.start_query_execution(
        QueryString="query2",
        QueryExecutionContext={"Database": "string"},
        ResultConfiguration={"OutputLocation": "string"},
    )
    assert "QueryExecutionId" in second

    assert first["QueryExecutionId"] != second["QueryExecutionId"]
|
|
|
|
|
|
@mock_aws
def test_start_query_validate_workgroup():
    """Starting a query in a work group that does not exist is rejected."""
    client = boto3.client("athena", region_name="us-east-1")

    with pytest.raises(ClientError) as raised:
        client.start_query_execution(
            QueryString="query1",
            QueryExecutionContext={"Database": "string"},
            ResultConfiguration={"OutputLocation": "string"},
            WorkGroup="unknown_workgroup",
        )

    error = raised.value.response["Error"]
    assert error["Code"] == "InvalidRequestException"
    assert error["Message"] == "WorkGroup does not exist"
|
|
|
|
|
|
@mock_aws
@pytest.mark.parametrize(
    "location", ["s3://bucket-name/prefix/", "s3://bucket-name/prefix_wo_slash"]
)
def test_get_query_execution(location):
    """get_query_execution reports the details of a started query.

    The test is parametrized over an output location with and without a
    trailing slash.
    """
    client = boto3.client("athena", region_name="us-east-1")
    query = "SELECT stuff"
    database = "database"

    exec_id = client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={"Database": database},
        ResultConfiguration={"OutputLocation": location},
    )["QueryExecutionId"]

    details = client.get_query_execution(QueryExecutionId=exec_id)["QueryExecution"]

    assert details["QueryExecutionId"] == exec_id
    assert details["Query"] == query
    assert details["StatementType"] == "DML"

    # The result file is "<id>.csv" under the location. Exactly one "/"
    # separates them, whether or not the location already ends with one.
    separator = "" if location.endswith("/") else "/"
    expected_output = f"{location}{separator}{exec_id}.csv"
    assert details["ResultConfiguration"]["OutputLocation"] == expected_output

    assert details["QueryExecutionContext"]["Database"] == database
    assert details["Status"]["State"] == "SUCCEEDED"

    zeroed_statistics = {
        "EngineExecutionTimeInMillis": 0,
        "DataScannedInBytes": 0,
        "TotalExecutionTimeInMillis": 0,
        "QueryQueueTimeInMillis": 0,
        "QueryPlanningTimeInMillis": 0,
        "ServiceProcessingTimeInMillis": 0,
    }
    assert details["Statistics"] == zeroed_statistics

    # No WorkGroup was passed when starting the query, and none is echoed back.
    assert "WorkGroup" not in details
|
|
|
|
|
|
@mock_aws
def test_stop_query_execution():
    """A stopped query execution reports the CANCELLED state."""
    client = boto3.client("athena", region_name="us-east-1")
    query = "SELECT stuff"
    location = "s3://bucket-name/prefix/"
    database = "database"

    started = client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={"Database": database},
        ResultConfiguration={"OutputLocation": location},
    )
    exec_id = started["QueryExecutionId"]

    client.stop_query_execution(QueryExecutionId=exec_id)

    execution = client.get_query_execution(QueryExecutionId=exec_id)["QueryExecution"]
    assert execution["QueryExecutionId"] == exec_id
    assert execution["Status"]["State"] == "CANCELLED"
|
|
|
|
|
|
@mock_aws
def test_create_named_query():
    """create_named_query returns the id of the newly created named query."""
    client = boto3.client("athena", region_name="us-east-1")

    created = client.create_named_query(
        Name="query-name", Database="target_db", QueryString="SELECT * FROM table1"
    )

    assert "NamedQueryId" in created
|
|
|
|
|
|
@mock_aws
def test_get_named_query():
    """get_named_query returns every field supplied to create_named_query."""
    client = boto3.client("athena", region_name="us-east-1")
    query_name = "query-name"
    database = "target_db"
    query_string = "SELECT * FROM tbl1"
    description = "description of this query"

    named_query_id = client.create_named_query(
        Name=query_name,
        Database=database,
        QueryString=query_string,
        Description=description,
    )["NamedQueryId"]

    fetched = client.get_named_query(NamedQueryId=named_query_id)["NamedQuery"]
    assert fetched["Name"] == query_name
    assert fetched["Description"] == description
    assert fetched["Database"] == database
    assert fetched["QueryString"] == query_string
    assert fetched["NamedQueryId"] == named_query_id
|
|
|
|
|
|
def create_basic_workgroup(client, name):
    """Create a work group called *name* through *client*.

    The work group is created with the description "Test work group" and an
    S3 output location of "s3://bucket-name/prefix/". Returns None.
    """
    output_location = "s3://bucket-name/prefix/"
    client.create_work_group(
        Name=name,
        Description="Test work group",
        Configuration={"ResultConfiguration": {"OutputLocation": output_location}},
    )
|
|
|
|
|
|
@mock_aws
def test_create_data_catalog():
    """Create a data catalog, reject a duplicate name, and find it in the listing."""
    client = boto3.client("athena", region_name="us-east-1")
    client.create_data_catalog(
        Name="athena_datacatalog",
        Type="GLUE",
        Description="Test data catalog",
        Parameters={"catalog-id": "AWS Test account ID"},
        Tags=[],
    )

    with pytest.raises(ClientError) as exc:
        # Creating a catalog with the same name a second time must fail.
        client.create_data_catalog(
            Name="athena_datacatalog",
            Type="GLUE",
            Description="Test data catalog",
            Parameters={"catalog-id": "AWS Test account ID"},
            Tags=[],
        )
    err = exc.value.response["Error"]
    assert err["Code"] == "InvalidRequestException"
    assert err["Message"] == "DataCatalog already exists"

    # Then check that the data catalog appears in the data catalog list.
    summaries = client.list_data_catalogs()["DataCatalogsSummary"]
    assert len(summaries) == 1
    data_catalog = summaries[0]
    assert data_catalog["CatalogName"] == "athena_datacatalog"
    assert data_catalog["Type"] == "GLUE"
|
|
|
|
|
|
@mock_aws
def test_create_and_get_data_catalog():
    """get_data_catalog returns the name, description, type and parameters supplied at creation."""
    client = boto3.client("athena", region_name="us-east-1")
    catalog_name = "athena_datacatalog"
    parameters = {"catalog-id": "AWS Test account ID"}

    client.create_data_catalog(
        Name=catalog_name,
        Type="GLUE",
        Description="Test data catalog",
        Parameters=parameters,
        Tags=[],
    )

    fetched = client.get_data_catalog(Name=catalog_name)["DataCatalog"]
    assert fetched == {
        "Name": "athena_datacatalog",
        "Description": "Test data catalog",
        "Type": "GLUE",
        "Parameters": {"catalog-id": "AWS Test account ID"},
    }
|
|
|
|
|
|
@mock_aws
def test_get_query_results():
    """Results are empty until they are configured directly on the backend."""
    client = boto3.client("athena", region_name="us-east-1")

    initial = client.get_query_results(QueryExecutionId="test")["ResultSet"]
    assert initial["Rows"] == []
    assert initial["ResultSetMetadata"]["ColumnInfo"] == []

    # The rest of the test manipulates the in-process backend directly, so it
    # only runs when settings.TEST_DECORATOR_MODE is set.
    if not settings.TEST_DECORATOR_MODE:
        return

    backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"]
    rows = [{"Data": [{"VarCharValue": ".."}]}]
    column_info = [
        {
            "CatalogName": "string",
            "SchemaName": "string",
            "TableName": "string",
            "Name": "string",
            "Label": "string",
            "Type": "string",
            "Precision": 123,
            "Scale": 123,
            "Nullable": "NOT_NULL",
            "CaseSensitive": True,
        }
    ]
    # Before `moto-api/static/athena/query_results` existed, writing into
    # backend.query_results was the documented way to configure results.
    # Keep exercising it for backward compatibility.
    backend.query_results["test"] = QueryResults(rows=rows, column_info=column_info)

    configured = client.get_query_results(QueryExecutionId="test")["ResultSet"]
    assert configured["Rows"] == rows
    assert configured["ResultSetMetadata"]["ColumnInfo"] == column_info
|
|
|
|
|
|
@mock_aws
def test_get_query_results_queue():
    """Results appended to backend.query_results_queue are served to the next caller.

    The first call after queuing receives the queued result, whatever its
    execution id. A different id afterwards gets empty results. Repeating the
    first id returns the same queued result again.
    """
    client = boto3.client("athena", region_name="us-east-1")

    initial = client.get_query_results(QueryExecutionId="test")["ResultSet"]
    assert initial["Rows"] == []
    assert initial["ResultSetMetadata"]["ColumnInfo"] == []

    # The rest of the test manipulates the in-process backend directly, so it
    # only runs when settings.TEST_DECORATOR_MODE is set.
    if not settings.TEST_DECORATOR_MODE:
        return

    backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"]
    rows = [{"Data": [{"VarCharValue": ".."}]}]
    column_info = [
        {
            "CatalogName": "string",
            "SchemaName": "string",
            "TableName": "string",
            "Name": "string",
            "Label": "string",
            "Type": "string",
            "Precision": 123,
            "Scale": 123,
            "Nullable": "NOT_NULL",
            "CaseSensitive": True,
        }
    ]
    backend.query_results_queue.append(
        QueryResults(rows=rows, column_info=column_info)
    )

    queued_id = "some-id-not-used-when-results-were-added-to-queue"

    from_queue = client.get_query_results(QueryExecutionId=queued_id)["ResultSet"]
    assert from_queue["Rows"] == rows
    assert from_queue["ResultSetMetadata"]["ColumnInfo"] == column_info

    unrelated = client.get_query_results(QueryExecutionId="other-id")["ResultSet"]
    assert unrelated["Rows"] == []
    assert unrelated["ResultSetMetadata"]["ColumnInfo"] == []

    repeated = client.get_query_results(QueryExecutionId=queued_id)["ResultSet"]
    assert repeated["Rows"] == rows
    assert repeated["ResultSetMetadata"]["ColumnInfo"] == column_info
|
|
|
|
|
|
@mock_aws
def test_list_query_executions():
    """list_query_executions returns the id of the single started query."""
    client = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=client, name="athena_workgroup")

    started = client.start_query_execution(
        QueryString="query1",
        QueryExecutionContext={"Database": "string"},
        ResultConfiguration={"OutputLocation": "string"},
        WorkGroup="athena_workgroup",
    )

    listed_ids = client.list_query_executions()["QueryExecutionIds"]
    assert len(listed_ids) == 1
    assert listed_ids[0] == started["QueryExecutionId"]
|
|
|
|
|
|
@mock_aws
def test_list_named_queries():
    """Named queries are listed only for the work group they were created in."""
    client = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=client, name="athena_workgroup")

    created = client.create_named_query(
        Name="query-name",
        Database="target_db",
        QueryString="SELECT * FROM table1",
        WorkGroup="athena_workgroup",
    )

    in_custom_group = client.list_named_queries(WorkGroup="athena_workgroup")
    assert in_custom_group["NamedQueryIds"][0] == created["NamedQueryId"]

    # Listing without a WorkGroup argument must not include the query above.
    without_group = client.list_named_queries()
    assert len(without_group["NamedQueryIds"]) == 0
|
|
|
|
|
|
@mock_aws
def test_create_prepared_statement():
    """create_prepared_statement succeeds on the first attempt with HTTP 200."""
    client = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=client, name="athena_workgroup")

    response = client.create_prepared_statement(
        StatementName="test-statement",
        WorkGroup="athena_workgroup",
        QueryStatement="SELECT * FROM table1",
    )

    response_metadata = response["ResponseMetadata"]
    assert response_metadata["HTTPStatusCode"] == 200
    assert response_metadata["RetryAttempts"] == 0
|
|
|
|
|
|
@mock_aws
def test_get_prepared_statement():
    """A created prepared statement can be fetched by name and work group."""
    client = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=client, name="athena_workgroup")

    statement_name = "stmt-name"
    client.create_prepared_statement(
        StatementName=statement_name,
        WorkGroup="athena_workgroup",
        QueryStatement="SELECT * FROM table1",
    )

    fetched = client.get_prepared_statement(
        StatementName=statement_name, WorkGroup="athena_workgroup"
    )
    assert "PreparedStatement" in fetched
|