Techdebt: Replace sure with regular asserts in Athena tests (#6382)

This commit is contained in:
Bert Blommers 2023-06-09 10:11:05 +00:00 committed by GitHub
parent 56895eae17
commit 3697019f01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 94 deletions

View File

@ -1,7 +1,6 @@
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
import pytest import pytest
import boto3 import boto3
import sure # noqa # pylint: disable=unused-import
from moto import mock_athena, settings from moto import mock_athena, settings
from moto.athena.models import athena_backends, QueryResults from moto.athena.models import athena_backends, QueryResults
@ -26,7 +25,7 @@ def test_create_work_group():
}, },
) )
try: with pytest.raises(ClientError) as exc:
# The second time should throw an error # The second time should throw an error
client.create_work_group( client.create_work_group(
Name="athena_workgroup", Name="athena_workgroup",
@ -41,11 +40,9 @@ def test_create_work_group():
} }
}, },
) )
except ClientError as err: err = exc.value.response["Error"]
err.response["Error"]["Code"].should.equal("InvalidRequestException") assert err["Code"] == "InvalidRequestException"
err.response["Error"]["Message"].should.equal("WorkGroup already exists") assert err["Message"] == "WorkGroup already exists"
else:
raise RuntimeError("Should have raised ResourceNotFoundException")
# Then test the work group appears in the work group list # Then test the work group appears in the work group list
response = client.list_work_groups() response = client.list_work_groups()
@ -53,11 +50,11 @@ def test_create_work_group():
work_groups = list( work_groups = list(
filter(lambda wg: wg["Name"] != "primary", response["WorkGroups"]) filter(lambda wg: wg["Name"] != "primary", response["WorkGroups"])
) )
work_groups.should.have.length_of(1) assert len(work_groups) == 1
work_group = work_groups[0] work_group = work_groups[0]
work_group["Name"].should.equal("athena_workgroup") assert work_group["Name"] == "athena_workgroup"
work_group["Description"].should.equal("Test work group") assert work_group["Description"] == "Test work group"
work_group["State"].should.equal("ENABLED") assert work_group["State"] == "ENABLED"
@mock_athena @mock_athena
@ -78,16 +75,14 @@ def test_create_and_get_workgroup():
work_group = client.get_work_group(WorkGroup="athena_workgroup")["WorkGroup"] work_group = client.get_work_group(WorkGroup="athena_workgroup")["WorkGroup"]
del work_group["CreationTime"] # Were not testing creationtime atm del work_group["CreationTime"] # Were not testing creationtime atm
work_group.should.equal( assert work_group == {
{ "Name": "athena_workgroup",
"Name": "athena_workgroup", "State": "ENABLED",
"State": "ENABLED", "Configuration": {
"Configuration": { "ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/"}
"ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/"} },
}, "Description": "Test work group",
"Description": "Test work group", }
}
)
@mock_athena @mock_athena
@ -109,7 +104,7 @@ def test_start_query_execution():
ResultConfiguration={"OutputLocation": "string"}, ResultConfiguration={"OutputLocation": "string"},
) )
assert "QueryExecutionId" in sec_response assert "QueryExecutionId" in sec_response
response["QueryExecutionId"].shouldnt.equal(sec_response["QueryExecutionId"]) assert response["QueryExecutionId"] != sec_response["QueryExecutionId"]
@mock_athena @mock_athena
@ -123,8 +118,8 @@ def test_start_query_validate_workgroup():
ResultConfiguration={"OutputLocation": "string"}, ResultConfiguration={"OutputLocation": "string"},
WorkGroup="unknown_workgroup", WorkGroup="unknown_workgroup",
) )
err.value.response["Error"]["Code"].should.equal("InvalidRequestException") assert err.value.response["Error"]["Code"] == "InvalidRequestException"
err.value.response["Error"]["Message"].should.equal("WorkGroup does not exist") assert err.value.response["Error"]["Message"] == "WorkGroup does not exist"
@mock_athena @mock_athena
@ -143,22 +138,20 @@ def test_get_query_execution():
# #
details = client.get_query_execution(QueryExecutionId=exex_id)["QueryExecution"] details = client.get_query_execution(QueryExecutionId=exex_id)["QueryExecution"]
# #
details["QueryExecutionId"].should.equal(exex_id) assert details["QueryExecutionId"] == exex_id
details["Query"].should.equal(query) assert details["Query"] == query
details["StatementType"].should.equal("DDL") assert details["StatementType"] == "DDL"
details["ResultConfiguration"]["OutputLocation"].should.equal(location) assert details["ResultConfiguration"]["OutputLocation"] == location
details["QueryExecutionContext"]["Database"].should.equal(database) assert details["QueryExecutionContext"]["Database"] == database
details["Status"]["State"].should.equal("SUCCEEDED") assert details["Status"]["State"] == "SUCCEEDED"
details["Statistics"].should.equal( assert details["Statistics"] == {
{ "EngineExecutionTimeInMillis": 0,
"EngineExecutionTimeInMillis": 0, "DataScannedInBytes": 0,
"DataScannedInBytes": 0, "TotalExecutionTimeInMillis": 0,
"TotalExecutionTimeInMillis": 0, "QueryQueueTimeInMillis": 0,
"QueryQueueTimeInMillis": 0, "QueryPlanningTimeInMillis": 0,
"QueryPlanningTimeInMillis": 0, "ServiceProcessingTimeInMillis": 0,
"ServiceProcessingTimeInMillis": 0, }
}
)
assert "WorkGroup" not in details assert "WorkGroup" not in details
@ -180,8 +173,8 @@ def test_stop_query_execution():
# Verify status # Verify status
details = client.get_query_execution(QueryExecutionId=exex_id)["QueryExecution"] details = client.get_query_execution(QueryExecutionId=exex_id)["QueryExecution"]
# #
details["QueryExecutionId"].should.equal(exex_id) assert details["QueryExecutionId"] == exex_id
details["Status"]["State"].should.equal("CANCELLED") assert details["Status"]["State"] == "CANCELLED"
@mock_athena @mock_athena
@ -214,11 +207,11 @@ def test_get_named_query():
# get named query # get named query
res_get = client.get_named_query(NamedQueryId=query_id)["NamedQuery"] res_get = client.get_named_query(NamedQueryId=query_id)["NamedQuery"]
res_get["Name"].should.equal(query_name) assert res_get["Name"] == query_name
res_get["Description"].should.equal(description) assert res_get["Description"] == description
res_get["Database"].should.equal(database) assert res_get["Database"] == database
res_get["QueryString"].should.equal(query_string) assert res_get["QueryString"] == query_string
res_get["NamedQueryId"].should.equal(query_id) assert res_get["NamedQueryId"] == query_id
def create_basic_workgroup(client, name): def create_basic_workgroup(client, name):
@ -242,7 +235,7 @@ def test_create_data_catalog():
Tags=[], Tags=[],
) )
try: with pytest.raises(ClientError) as exc:
# The second time should throw an error # The second time should throw an error
response = client.create_data_catalog( response = client.create_data_catalog(
Name="athena_datacatalog", Name="athena_datacatalog",
@ -251,19 +244,17 @@ def test_create_data_catalog():
Parameters={"catalog-id": "AWS Test account ID"}, Parameters={"catalog-id": "AWS Test account ID"},
Tags=[], Tags=[],
) )
except ClientError as err: err = exc.value.response["Error"]
err.response["Error"]["Code"].should.equal("InvalidRequestException") assert err["Code"] == "InvalidRequestException"
err.response["Error"]["Message"].should.equal("DataCatalog already exists") assert err["Message"] == "DataCatalog already exists"
else:
raise RuntimeError("Should have raised ResourceNotFoundException")
# Then test the work group appears in the work group list # Then test the work group appears in the work group list
response = client.list_data_catalogs() response = client.list_data_catalogs()
response["DataCatalogsSummary"].should.have.length_of(1) assert len(response["DataCatalogsSummary"]) == 1
data_catalog = response["DataCatalogsSummary"][0] data_catalog = response["DataCatalogsSummary"][0]
data_catalog["CatalogName"].should.equal("athena_datacatalog") assert data_catalog["CatalogName"] == "athena_datacatalog"
data_catalog["Type"].should.equal("GLUE") assert data_catalog["Type"] == "GLUE"
@mock_athena @mock_athena
@ -279,14 +270,12 @@ def test_create_and_get_data_catalog():
) )
data_catalog = client.get_data_catalog(Name="athena_datacatalog") data_catalog = client.get_data_catalog(Name="athena_datacatalog")
data_catalog["DataCatalog"].should.equal( assert data_catalog["DataCatalog"] == {
{ "Name": "athena_datacatalog",
"Name": "athena_datacatalog", "Description": "Test data catalog",
"Description": "Test data catalog", "Type": "GLUE",
"Type": "GLUE", "Parameters": {"catalog-id": "AWS Test account ID"},
"Parameters": {"catalog-id": "AWS Test account ID"}, }
}
)
@mock_athena @mock_athena
@ -295,8 +284,8 @@ def test_get_query_results():
result = client.get_query_results(QueryExecutionId="test") result = client.get_query_results(QueryExecutionId="test")
result["ResultSet"]["Rows"].should.equal([]) assert result["ResultSet"]["Rows"] == []
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([]) assert result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == []
if not settings.TEST_SERVER_MODE: if not settings.TEST_SERVER_MODE:
backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"] backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"]
@ -321,8 +310,8 @@ def test_get_query_results():
backend.query_results["test"] = results backend.query_results["test"] = results
result = client.get_query_results(QueryExecutionId="test") result = client.get_query_results(QueryExecutionId="test")
result["ResultSet"]["Rows"].should.equal(rows) assert result["ResultSet"]["Rows"] == rows
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal(column_info) assert result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == column_info
@mock_athena @mock_athena
@ -331,8 +320,8 @@ def test_get_query_results_queue():
result = client.get_query_results(QueryExecutionId="test") result = client.get_query_results(QueryExecutionId="test")
result["ResultSet"]["Rows"].should.equal([]) assert result["ResultSet"]["Rows"] == []
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([]) assert result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == []
if not settings.TEST_SERVER_MODE: if not settings.TEST_SERVER_MODE:
backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"] backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"]
@ -357,18 +346,18 @@ def test_get_query_results_queue():
result = client.get_query_results( result = client.get_query_results(
QueryExecutionId="some-id-not-used-when-results-were-added-to-queue" QueryExecutionId="some-id-not-used-when-results-were-added-to-queue"
) )
result["ResultSet"]["Rows"].should.equal(rows) assert result["ResultSet"]["Rows"] == rows
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal(column_info) assert result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == column_info
result = client.get_query_results(QueryExecutionId="other-id") result = client.get_query_results(QueryExecutionId="other-id")
result["ResultSet"]["Rows"].should.equal([]) assert result["ResultSet"]["Rows"] == []
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([]) assert result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == []
result = client.get_query_results( result = client.get_query_results(
QueryExecutionId="some-id-not-used-when-results-were-added-to-queue" QueryExecutionId="some-id-not-used-when-results-were-added-to-queue"
) )
result["ResultSet"]["Rows"].should.equal(rows) assert result["ResultSet"]["Rows"] == rows
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal(column_info) assert result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"] == column_info
@mock_athena @mock_athena
@ -385,8 +374,8 @@ def test_list_query_executions():
exec_id = exec_result["QueryExecutionId"] exec_id = exec_result["QueryExecutionId"]
executions = client.list_query_executions() executions = client.list_query_executions()
executions["QueryExecutionIds"].should.have.length_of(1) assert len(executions["QueryExecutionIds"]) == 1
executions["QueryExecutionIds"][0].should.equal(exec_id) assert executions["QueryExecutionIds"][0] == exec_id
@mock_athena @mock_athena

View File

@ -40,20 +40,20 @@ def test_set_athena_result():
f"http://{base_url}/moto-api/static/athena/query-results", f"http://{base_url}/moto-api/static/athena/query-results",
json=athena_result, json=athena_result,
) )
resp.status_code.should.equal(201) assert resp.status_code == 201
client = boto3.client("athena", region_name="us-east-1") client = boto3.client("athena", region_name="us-east-1")
details = client.get_query_results(QueryExecutionId="anyid")["ResultSet"] details = client.get_query_results(QueryExecutionId="anyid")["ResultSet"]
details["Rows"].should.equal(athena_result["results"][0]["rows"]) assert details["Rows"] == athena_result["results"][0]["rows"]
details["ResultSetMetadata"]["ColumnInfo"].should.equal(DEFAULT_COLUMN_INFO) assert details["ResultSetMetadata"]["ColumnInfo"] == DEFAULT_COLUMN_INFO
# Operation should be idempotent # Operation should be idempotent
details = client.get_query_results(QueryExecutionId="anyid")["ResultSet"] details = client.get_query_results(QueryExecutionId="anyid")["ResultSet"]
details["Rows"].should.equal(athena_result["results"][0]["rows"]) assert details["Rows"] == athena_result["results"][0]["rows"]
# Different ID should still return different (default) results though # Different ID should still return different (default) results though
details = client.get_query_results(QueryExecutionId="otherid")["ResultSet"] details = client.get_query_results(QueryExecutionId="otherid")["ResultSet"]
details["Rows"].should.equal([]) assert details["Rows"] == []
@mock_athena @mock_athena
@ -73,27 +73,27 @@ def test_set_multiple_athena_result():
f"http://{base_url}/moto-api/static/athena/query-results", f"http://{base_url}/moto-api/static/athena/query-results",
json=athena_result, json=athena_result,
) )
resp.status_code.should.equal(201) assert resp.status_code == 201
client = boto3.client("athena", region_name="us-east-1") client = boto3.client("athena", region_name="us-east-1")
details = client.get_query_results(QueryExecutionId="first_id")["ResultSet"] details = client.get_query_results(QueryExecutionId="first_id")["ResultSet"]
details["Rows"].should.equal([{"Data": [{"VarCharValue": "1"}]}]) assert details["Rows"] == [{"Data": [{"VarCharValue": "1"}]}]
# The same ID should return the same data # The same ID should return the same data
details = client.get_query_results(QueryExecutionId="first_id")["ResultSet"] details = client.get_query_results(QueryExecutionId="first_id")["ResultSet"]
details["Rows"].should.equal([{"Data": [{"VarCharValue": "1"}]}]) assert details["Rows"] == [{"Data": [{"VarCharValue": "1"}]}]
# The next ID should return different data # The next ID should return different data
details = client.get_query_results(QueryExecutionId="second_id")["ResultSet"] details = client.get_query_results(QueryExecutionId="second_id")["ResultSet"]
details["Rows"].should.equal([{"Data": [{"VarCharValue": "2"}]}]) assert details["Rows"] == [{"Data": [{"VarCharValue": "2"}]}]
# The last ID should return even different data # The last ID should return even different data
details = client.get_query_results(QueryExecutionId="third_id")["ResultSet"] details = client.get_query_results(QueryExecutionId="third_id")["ResultSet"]
details["Rows"].should.equal([{"Data": [{"VarCharValue": "3"}]}]) assert details["Rows"] == [{"Data": [{"VarCharValue": "3"}]}]
# Any other calls should return the default data # Any other calls should return the default data
details = client.get_query_results(QueryExecutionId="other_id")["ResultSet"] details = client.get_query_results(QueryExecutionId="other_id")["ResultSet"]
details["Rows"].should.equal([]) assert details["Rows"] == []
@mock_athena @mock_athena
@ -119,7 +119,7 @@ def test_set_athena_result_with_custom_region_account():
f"http://{base_url}/moto-api/static/athena/query-results", f"http://{base_url}/moto-api/static/athena/query-results",
json=athena_result, json=athena_result,
) )
resp.status_code.should.equal(201) assert resp.status_code == 201
sts = boto3.client("sts", "us-east-1") sts = boto3.client("sts", "us-east-1")
cross_account_creds = sts.assume_role( cross_account_creds = sts.assume_role(
@ -139,8 +139,8 @@ def test_set_athena_result_with_custom_region_account():
details = athena_in_other_account.get_query_results(QueryExecutionId="anyid")[ details = athena_in_other_account.get_query_results(QueryExecutionId="anyid")[
"ResultSet" "ResultSet"
] ]
details["Rows"].should.equal(athena_result["results"][0]["rows"]) assert details["Rows"] == athena_result["results"][0]["rows"]
details["ResultSetMetadata"]["ColumnInfo"].should.equal(DEFAULT_COLUMN_INFO) assert details["ResultSetMetadata"]["ColumnInfo"] == DEFAULT_COLUMN_INFO
# query results from other regions do not match # query results from other regions do not match
athena_in_diff_region = boto3.client( athena_in_diff_region = boto3.client(
@ -153,9 +153,9 @@ def test_set_athena_result_with_custom_region_account():
details = athena_in_diff_region.get_query_results(QueryExecutionId="anyid")[ details = athena_in_diff_region.get_query_results(QueryExecutionId="anyid")[
"ResultSet" "ResultSet"
] ]
details["Rows"].should.equal([]) assert details["Rows"] == []
# query results from default account does not match # query results from default account does not match
client = boto3.client("athena", region_name="eu-west-1") client = boto3.client("athena", region_name="eu-west-1")
details = client.get_query_results(QueryExecutionId="anyid")["ResultSet"] details = client.get_query_results(QueryExecutionId="anyid")["ResultSet"]
details["Rows"].should.equal([]) assert details["Rows"] == []