# moto/tests/test_athena/test_athena.py
#
# 435 lines
# 15 KiB
# Python

from botocore.exceptions import ClientError
import pytest
import boto3
import sure # noqa # pylint: disable=unused-import
from moto import mock_athena, settings
from moto.athena.models import athena_backends, QueryResults
from moto.core import DEFAULT_ACCOUNT_ID
@mock_athena
def test_create_work_group():
    """Creating a work group twice fails, and the group is listed exactly once.

    The first ``create_work_group`` call must succeed. Repeating it with the
    same name must raise a ``ClientError`` whose code is
    ``InvalidRequestException``. Afterwards the work group list, ignoring the
    ``primary`` group, must contain only the group created here.
    """
    client = boto3.client("athena", region_name="us-east-1")
    configuration = {
        "ResultConfiguration": {
            "OutputLocation": "s3://bucket-name/prefix/",
            "EncryptionConfiguration": {
                "EncryptionOption": "SSE_KMS",
                "KmsKey": "aws:arn:kms:1233456789:us-east-1:key/number-1",
            },
        }
    }
    client.create_work_group(
        Name="athena_workgroup",
        Description="Test work group",
        Configuration=configuration,
    )

    # The second time should throw an error
    with pytest.raises(ClientError) as exc:
        client.create_work_group(
            Name="athena_workgroup",
            Description="duplicate",
            Configuration=configuration,
        )
    error = exc.value.response["Error"]
    assert error["Code"] == "InvalidRequestException"
    assert error["Message"] == "WorkGroup already exists"

    # Then test the work group appears in the work group list
    response = client.list_work_groups()
    work_groups = [wg for wg in response["WorkGroups"] if wg["Name"] != "primary"]
    assert len(work_groups) == 1
    work_group = work_groups[0]
    assert work_group["Name"] == "athena_workgroup"
    # The description of the first (successful) call must be kept
    assert work_group["Description"] == "Test work group"
    assert work_group["State"] == "ENABLED"
@mock_athena
def test_get_primary_workgroup():
    """A fresh client sees exactly one work group, ``primary``, with an empty configuration."""
    athena = boto3.client("athena", region_name="us-east-1")

    all_groups = athena.list_work_groups()["WorkGroups"]
    assert len(all_groups) == 1

    primary_group = athena.get_work_group(WorkGroup="primary")["WorkGroup"]
    assert primary_group["Name"] == "primary"
    assert primary_group["Configuration"] == {}
@mock_athena
def test_create_and_get_workgroup():
    """A work group created through the helper is returned intact by ``get_work_group``."""
    athena = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=athena, name="athena_workgroup")

    fetched = athena.get_work_group(WorkGroup="athena_workgroup")["WorkGroup"]
    # We're not testing the creation time at the moment
    fetched.pop("CreationTime")

    expected = {
        "Name": "athena_workgroup",
        "State": "ENABLED",
        "Configuration": {
            "ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/"}
        },
        "Description": "Test work group",
    }
    assert fetched == expected
@mock_athena
def test_start_query_execution():
    """Each started query gets its own execution id, with or without an explicit work group."""
    athena = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=athena, name="athena_workgroup")

    context = {"Database": "string"}
    result_config = {"OutputLocation": "string"}

    # First query targets the work group created above
    first = athena.start_query_execution(
        QueryString="query1",
        QueryExecutionContext=context,
        ResultConfiguration=result_config,
        WorkGroup="athena_workgroup",
    )
    assert "QueryExecutionId" in first

    # Second query does not name a work group
    second = athena.start_query_execution(
        QueryString="query2",
        QueryExecutionContext=context,
        ResultConfiguration=result_config,
    )
    assert "QueryExecutionId" in second

    assert first["QueryExecutionId"] != second["QueryExecutionId"]
@mock_athena
def test_start_query_validate_workgroup():
    """Starting a query against an unknown work group raises ``InvalidRequestException``."""
    athena = boto3.client("athena", region_name="us-east-1")

    with pytest.raises(ClientError) as exc_info:
        athena.start_query_execution(
            QueryString="query1",
            QueryExecutionContext={"Database": "string"},
            ResultConfiguration={"OutputLocation": "string"},
            WorkGroup="unknown_workgroup",
        )

    error = exc_info.value.response["Error"]
    assert error["Code"] == "InvalidRequestException"
    assert error["Message"] == "WorkGroup does not exist"
@mock_athena
def test_get_query_execution():
    """``get_query_execution`` echoes the submitted query and reports a succeeded run.

    The query is started without a work group, so the returned details must
    not contain a ``WorkGroup`` key. All statistics are expected to be zero.
    """
    athena = boto3.client("athena", region_name="us-east-1")
    sql = "SELECT stuff"
    output_location = "s3://bucket-name/prefix/"
    db_name = "database"

    # Start the query
    started = athena.start_query_execution(
        QueryString=sql,
        QueryExecutionContext={"Database": db_name},
        ResultConfiguration={"OutputLocation": output_location},
    )
    execution_id = started["QueryExecutionId"]

    # Fetch its details
    execution = athena.get_query_execution(QueryExecutionId=execution_id)
    details = execution["QueryExecution"]

    # Verify them
    assert details["QueryExecutionId"] == execution_id
    assert details["Query"] == sql
    assert details["StatementType"] == "DDL"
    assert details["ResultConfiguration"]["OutputLocation"] == output_location
    assert details["QueryExecutionContext"]["Database"] == db_name
    assert details["Status"]["State"] == "SUCCEEDED"

    expected_statistics = {
        "EngineExecutionTimeInMillis": 0,
        "DataScannedInBytes": 0,
        "TotalExecutionTimeInMillis": 0,
        "QueryQueueTimeInMillis": 0,
        "QueryPlanningTimeInMillis": 0,
        "ServiceProcessingTimeInMillis": 0,
    }
    assert details["Statistics"] == expected_statistics
    assert "WorkGroup" not in details
@mock_athena
def test_stop_query_execution():
    """Stopping a started query leaves it in the ``CANCELLED`` state."""
    athena = boto3.client("athena", region_name="us-east-1")
    sql = "SELECT stuff"
    output_location = "s3://bucket-name/prefix/"
    db_name = "database"

    # Start the query
    started = athena.start_query_execution(
        QueryString=sql,
        QueryExecutionContext={"Database": db_name},
        ResultConfiguration={"OutputLocation": output_location},
    )
    execution_id = started["QueryExecutionId"]

    # Stop the query
    athena.stop_query_execution(QueryExecutionId=execution_id)

    # Verify the status
    execution = athena.get_query_execution(QueryExecutionId=execution_id)
    details = execution["QueryExecution"]
    assert details["QueryExecutionId"] == execution_id
    assert details["Status"]["State"] == "CANCELLED"
@mock_athena
def test_create_named_query():
    """``create_named_query`` responds with a ``NamedQueryId``."""
    athena = boto3.client("athena", region_name="us-east-1")

    # Create the named query
    created = athena.create_named_query(
        Name="query-name",
        Database="target_db",
        QueryString="SELECT * FROM table1",
    )

    assert "NamedQueryId" in created
@mock_athena
def test_get_named_query():
    """A named query can be read back with every attribute it was created with."""
    athena = boto3.client("athena", region_name="us-east-1")
    name = "query-name"
    db_name = "target_db"
    sql = "SELECT * FROM tbl1"
    summary = "description of this query"

    # Create the named query and remember its id
    created = athena.create_named_query(
        Name=name,
        Database=db_name,
        QueryString=sql,
        Description=summary,
    )
    named_query_id = created["NamedQueryId"]

    # Read it back
    fetched = athena.get_named_query(NamedQueryId=named_query_id)["NamedQuery"]
    assert fetched["Name"] == name
    assert fetched["Description"] == summary
    assert fetched["Database"] == db_name
    assert fetched["QueryString"] == sql
    assert fetched["NamedQueryId"] == named_query_id
def create_basic_workgroup(client, name):
    """Create a work group called ``name`` on ``client``.

    The group always gets the description "Test work group" and a result
    configuration that points at a fixed S3 output location. Nothing is
    returned.
    """
    result_configuration = {"OutputLocation": "s3://bucket-name/prefix/"}
    client.create_work_group(
        Name=name,
        Description="Test work group",
        Configuration={"ResultConfiguration": result_configuration},
    )
@mock_athena
def test_create_data_catalog():
    """Creating a data catalog twice fails, and the catalog is listed exactly once.

    The first ``create_data_catalog`` call must succeed. Repeating it with the
    same name must raise a ``ClientError`` whose code is
    ``InvalidRequestException``. Afterwards ``list_data_catalogs`` must report
    a single summary entry for the catalog created here.
    """
    client = boto3.client("athena", region_name="us-east-1")
    catalog_args = {
        "Name": "athena_datacatalog",
        "Type": "GLUE",
        "Description": "Test data catalog",
        "Parameters": {"catalog-id": "AWS Test account ID"},
        "Tags": [],
    }
    client.create_data_catalog(**catalog_args)

    # The second time should throw an error
    with pytest.raises(ClientError) as exc:
        client.create_data_catalog(**catalog_args)
    error = exc.value.response["Error"]
    assert error["Code"] == "InvalidRequestException"
    assert error["Message"] == "DataCatalog already exists"

    # Then test the data catalog appears in the data catalog list
    response = client.list_data_catalogs()
    summaries = response["DataCatalogsSummary"]
    assert len(summaries) == 1
    data_catalog = summaries[0]
    assert data_catalog["CatalogName"] == "athena_datacatalog"
    assert data_catalog["Type"] == "GLUE"
@mock_athena
def test_create_and_get_data_catalog():
    """A created data catalog is returned by ``get_data_catalog`` with the same attributes."""
    athena = boto3.client("athena", region_name="us-east-1")
    athena.create_data_catalog(
        Name="athena_datacatalog",
        Type="GLUE",
        Description="Test data catalog",
        Parameters={"catalog-id": "AWS Test account ID"},
        Tags=[],
    )

    fetched = athena.get_data_catalog(Name="athena_datacatalog")

    expected = {
        "Name": "athena_datacatalog",
        "Description": "Test data catalog",
        "Type": "GLUE",
        "Parameters": {"catalog-id": "AWS Test account ID"},
    }
    assert fetched["DataCatalog"] == expected
@mock_athena
def test_get_query_results():
    """Query results are empty by default and can be preset on the backend.

    The second half writes into the in-process backend directly, so it is
    skipped when the tests run in server mode.
    """
    athena = boto3.client("athena", region_name="us-east-1")

    initial = athena.get_query_results(QueryExecutionId="test")
    assert initial["ResultSet"]["Rows"] == []
    assert initial["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == []

    if settings.TEST_SERVER_MODE:
        return

    backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"]
    rows = [{"Data": [{"VarCharValue": ".."}]}]
    column_info = [
        {
            "CatalogName": "string",
            "SchemaName": "string",
            "TableName": "string",
            "Name": "string",
            "Label": "string",
            "Type": "string",
            "Precision": 123,
            "Scale": 123,
            "Nullable": "NOT_NULL",
            "CaseSensitive": True,
        }
    ]
    # This was the documented way to configure query results, before
    # `moto-api/static/athena/query_results` was implemented.
    # We should try to keep this for backward compatibility.
    backend.query_results["test"] = QueryResults(rows=rows, column_info=column_info)

    configured = athena.get_query_results(QueryExecutionId="test")
    assert configured["ResultSet"]["Rows"] == rows
    assert configured["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == column_info
@mock_athena
def test_get_query_results_queue():
    """Results appended to the backend queue are served to the next execution id asked for.

    After one set of results is queued, the first id requested receives it, a
    different id then gets an empty result set, and asking for the first id
    again returns the same queued results. The queue is manipulated on the
    in-process backend, so this part is skipped in server mode.
    """
    athena = boto3.client("athena", region_name="us-east-1")

    initial = athena.get_query_results(QueryExecutionId="test")
    assert initial["ResultSet"]["Rows"] == []
    assert initial["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == []

    if settings.TEST_SERVER_MODE:
        return

    backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"]
    rows = [{"Data": [{"VarCharValue": ".."}]}]
    column_info = [
        {
            "CatalogName": "string",
            "SchemaName": "string",
            "TableName": "string",
            "Name": "string",
            "Label": "string",
            "Type": "string",
            "Precision": 123,
            "Scale": 123,
            "Nullable": "NOT_NULL",
            "CaseSensitive": True,
        }
    ]
    backend.query_results_queue.append(
        QueryResults(rows=rows, column_info=column_info)
    )
    queued_id = "some-id-not-used-when-results-were-added-to-queue"

    # The first request receives the queued results
    first = athena.get_query_results(QueryExecutionId=queued_id)
    assert first["ResultSet"]["Rows"] == rows
    assert first["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == column_info

    # A different id gets nothing
    other = athena.get_query_results(QueryExecutionId="other-id")
    assert other["ResultSet"]["Rows"] == []
    assert other["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == []

    # The original id still returns the same results
    again = athena.get_query_results(QueryExecutionId=queued_id)
    assert again["ResultSet"]["Rows"] == rows
    assert again["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == column_info
@mock_athena
def test_list_query_executions():
    """``list_query_executions`` returns the id of the single query that was started."""
    athena = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=athena, name="athena_workgroup")

    started = athena.start_query_execution(
        QueryString="query1",
        QueryExecutionContext={"Database": "string"},
        ResultConfiguration={"OutputLocation": "string"},
        WorkGroup="athena_workgroup",
    )
    expected_id = started["QueryExecutionId"]

    listed_ids = athena.list_query_executions()["QueryExecutionIds"]
    assert len(listed_ids) == 1
    assert listed_ids[0] == expected_id
@mock_athena
def test_list_named_queries():
    """Named queries are listed for the work group they were created in.

    Listing with the custom work group returns the new query; listing without
    a work group returns no ids.
    """
    athena = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=athena, name="athena_workgroup")

    created = athena.create_named_query(
        Name="query-name",
        Database="target_db",
        QueryString="SELECT * FROM table1",
        WorkGroup="athena_workgroup",
    )

    in_custom_group = athena.list_named_queries(WorkGroup="athena_workgroup")
    assert in_custom_group["NamedQueryIds"][0] == created["NamedQueryId"]

    without_group = athena.list_named_queries()
    assert len(without_group["NamedQueryIds"]) == 0
@mock_athena
def test_create_prepared_statement():
    """``create_prepared_statement`` answers with HTTP 200 and no retries."""
    athena = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=athena, name="athena_workgroup")

    response = athena.create_prepared_statement(
        StatementName="test-statement",
        WorkGroup="athena_workgroup",
        QueryStatement="SELECT * FROM table1",
    )

    response_metadata = response["ResponseMetadata"]
    assert response_metadata["HTTPStatusCode"] == 200
    assert response_metadata["RetryAttempts"] == 0
@mock_athena
def test_get_prepared_statement():
    """A prepared statement that was created can be fetched by name and work group."""
    athena = boto3.client("athena", region_name="us-east-1")
    create_basic_workgroup(client=athena, name="athena_workgroup")
    athena.create_prepared_statement(
        StatementName="stmt-name",
        WorkGroup="athena_workgroup",
        QueryStatement="SELECT * FROM table1",
    )

    fetched = athena.get_prepared_statement(
        StatementName="stmt-name",
        WorkGroup="athena_workgroup",
    )

    assert "PreparedStatement" in fetched