From 5dd155de4b5ba046185726b5602eb013e8ca3e02 Mon Sep 17 00:00:00 2001 From: Tuukka Ikkala <10980802+ikkala@users.noreply.github.com> Date: Fri, 24 Feb 2023 23:25:54 +0200 Subject: [PATCH] Athena - Make query results configurable (#5928) --- docs/docs/services/athena.rst | 58 ++++--- moto/athena/models.py | 61 +++++--- moto/moto_api/_internal/models.py | 15 +- moto/moto_api/_internal/responses.py | 26 ++++ moto/moto_api/_internal/urls.py | 1 + tests/test_athena/test_athena.py | 48 ++++++ tests/test_athena/test_athena_server_api.py | 161 ++++++++++++++++++++ 7 files changed, 327 insertions(+), 43 deletions(-) create mode 100644 tests/test_athena/test_athena_server_api.py diff --git a/docs/docs/services/athena.rst b/docs/docs/services/athena.rst index 7e99f7980..5d8849150 100644 --- a/docs/docs/services/athena.rst +++ b/docs/docs/services/athena.rst @@ -51,33 +51,49 @@ athena - [X] get_query_execution - [X] get_query_results - Queries are not executed, so this call will always return 0 rows by default. + Queries are not executed by Moto, so this call will always return 0 rows by default. - When using decorators, you can use the internal API to manually set results: + You can use a dedicated API to configure this. Moto has a queue that can be filled with the expected results. + + A request to `get_query_results` will take the first result from that queue, and assign it to the provided QueryExecutionId. Subsequent requests using the same QueryExecutionId will return the same result. Other requests using a different QueryExecutionId will take the next result from the queue, or return an empty result if the queue is empty. + + Configure this queue by making an HTTP request to `/moto-api/static/athena/query-results`. An example invocation looks like this: .. 
sourcecode:: python - from moto.athena.models import athena_backends, QueryResults - from moto.core import DEFAULT_ACCOUNT_ID + expected_results = { + "account_id": "123456789012", # This is the default - can be omitted + "region": "us-east-1", # This is the default - can be omitted + "results": [ + { + "rows": [{"Data": [{"VarCharValue": "1"}]}], + "column_info": [{ + "CatalogName": "string", + "SchemaName": "string", + "TableName": "string", + "Name": "string", + "Label": "string", + "Type": "string", + "Precision": 123, + "Scale": 123, + "Nullable": "NOT_NULL", + "CaseSensitive": True, + }], + }, + # other results as required + ], + } + resp = requests.post( + "http://motoapi.amazonaws.com:5000/moto-api/static/athena/query-results", + json=expected_results, + ) + resp.status_code.should.equal(201) - backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"] - rows = [{'Data': [{'VarCharValue': '..'}]}] - column_info = [{ - 'CatalogName': 'string', - 'SchemaName': 'string', - 'TableName': 'string', - 'Name': 'string', - 'Label': 'string', - 'Type': 'string', - 'Precision': 123, - 'Scale': 123, - 'Nullable': 'NOT_NULL', - 'CaseSensitive': True - }] - results = QueryResults(rows=rows, column_info=column_info) - backend.query_results["test"] = results + client = boto3.client("athena", region_name="us-east-1") + details = client.get_query_results(QueryExecutionId="any_id")["ResultSet"] + + .. 
note:: The exact QueryExecutionId is not relevant here, but will likely be whatever value is returned by start_query_execution - result = client.get_query_results(QueryExecutionId="test") - [ ] get_query_runtime_statistics diff --git a/moto/athena/models.py b/moto/athena/models.py index bfd2d9205..975bddb3a 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -132,6 +132,7 @@ class AthenaBackend(BaseBackend): self.named_queries: Dict[str, NamedQuery] = {} self.data_catalogs: Dict[str, DataCatalog] = {} self.query_results: Dict[str, QueryResults] = {} + self.query_results_queue: List[QueryResults] = [] @staticmethod def default_vpc_endpoint_service( @@ -195,34 +196,52 @@ class AthenaBackend(BaseBackend): def get_query_results(self, exec_id: str) -> QueryResults: """ - Queries are not executed, so this call will always return 0 rows by default. + Queries are not executed by Moto, so this call will always return 0 rows by default. - When using decorators, you can use the internal API to manually set results: + You can use a dedicated API to configure this. Moto has a queue that can be filled with the expected results. + + A request to `get_query_results` will take the first result from that queue, and assign it to the provided QueryExecutionId. Subsequent requests using the same QueryExecutionId will return the same result. Other requests using a different QueryExecutionId will take the next result from the queue, or return an empty result if the queue is empty. + + Configure this queue by making an HTTP request to `/moto-api/static/athena/query-results`. An example invocation looks like this: .. 
sourcecode:: python - from moto.athena.models import athena_backends, QueryResults - from moto.core import DEFAULT_ACCOUNT_ID + expected_results = { + "account_id": "123456789012", # This is the default - can be omitted + "region": "us-east-1", # This is the default - can be omitted + "results": [ + { + "rows": [{"Data": [{"VarCharValue": "1"}]}], + "column_info": [{ + "CatalogName": "string", + "SchemaName": "string", + "TableName": "string", + "Name": "string", + "Label": "string", + "Type": "string", + "Precision": 123, + "Scale": 123, + "Nullable": "NOT_NULL", + "CaseSensitive": True, + }], + }, + # other results as required + ], + } + resp = requests.post( + "http://motoapi.amazonaws.com:5000/moto-api/static/athena/query-results", + json=expected_results, + ) + resp.status_code.should.equal(201) - backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"] - rows = [{'Data': [{'VarCharValue': '..'}]}] - column_info = [{ - 'CatalogName': 'string', - 'SchemaName': 'string', - 'TableName': 'string', - 'Name': 'string', - 'Label': 'string', - 'Type': 'string', - 'Precision': 123, - 'Scale': 123, - 'Nullable': 'NOT_NULL', - 'CaseSensitive': True - }] - results = QueryResults(rows=rows, column_info=column_info) - backend.query_results["test"] = results + client = boto3.client("athena", region_name="us-east-1") + details = client.get_query_results(QueryExecutionId="any_id")["ResultSet"] + + .. 
note:: The exact QueryExecutionId is not relevant here, but will likely be whatever value is returned by start_query_execution - result = client.get_query_results(QueryExecutionId="test") """ + if exec_id not in self.query_results and self.query_results_queue: + self.query_results[exec_id] = self.query_results_queue.pop(0) results = ( self.query_results[exec_id] if exec_id in self.query_results diff --git a/moto/moto_api/_internal/models.py b/moto/moto_api/_internal/models.py index 3738b3fc7..373c00938 100644 --- a/moto/moto_api/_internal/models.py +++ b/moto/moto_api/_internal/models.py @@ -1,6 +1,6 @@ from moto.core import BaseBackend, DEFAULT_ACCOUNT_ID from moto.core.model_instances import reset_model_data -from typing import Any, Dict +from typing import Any, Dict, List class MotoAPIBackend(BaseBackend): @@ -33,5 +33,18 @@ class MotoAPIBackend(BaseBackend): state_manager.unset_transition(model_name) + def set_athena_result( + self, + rows: List[Dict[str, Any]], + column_info: List[Dict[str, str]], + account_id: str, + region: str, + ) -> None: + from moto.athena.models import athena_backends, QueryResults + + backend = athena_backends[account_id][region] + results = QueryResults(rows=rows, column_info=column_info) + backend.query_results_queue.append(results) + moto_api_backend = MotoAPIBackend(region_name="global", account_id=DEFAULT_ACCOUNT_ID) diff --git a/moto/moto_api/_internal/responses.py b/moto/moto_api/_internal/responses.py index ae52c6e47..03a0ac684 100644 --- a/moto/moto_api/_internal/responses.py +++ b/moto/moto_api/_internal/responses.py @@ -1,6 +1,7 @@ import json from moto import settings +from moto.core import DEFAULT_ACCOUNT_ID from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import ActionAuthenticatorMixin, BaseResponse from typing import Any, Dict, List @@ -141,3 +142,28 @@ class MotoAPIResponse(BaseResponse): a = self._get_param("a") mock_random.seed(int(a)) return 200, {}, "" + + def set_athena_result( + self, + 
request: Any, + full_url: str, # pylint: disable=unused-argument + headers: Any, + ) -> TYPE_RESPONSE: + from .models import moto_api_backend + + request_body_size = int(headers["Content-Length"]) + body = request.environ["wsgi.input"].read(request_body_size).decode("utf-8") + body = json.loads(body) + account_id = body.get("account_id", DEFAULT_ACCOUNT_ID) + region = body.get("region", "us-east-1") + + for result in body.get("results", []): + rows = result["rows"] + column_info = result.get("column_info", []) + moto_api_backend.set_athena_result( + rows=rows, + column_info=column_info, + account_id=account_id, + region=region, + ) + return 201, {}, "" diff --git a/moto/moto_api/_internal/urls.py b/moto/moto_api/_internal/urls.py index da5ea6cef..0d1dd46a2 100644 --- a/moto/moto_api/_internal/urls.py +++ b/moto/moto_api/_internal/urls.py @@ -12,6 +12,7 @@ url_paths = { "{0}/moto-api/reset": response_instance.reset_response, "{0}/moto-api/reset-auth": response_instance.reset_auth_response, "{0}/moto-api/seed": response_instance.seed, + "{0}/moto-api/static/athena/query-results": response_instance.set_athena_result, "{0}/moto-api/state-manager/get-transition": response_instance.get_transition, "{0}/moto-api/state-manager/set-transition": response_instance.set_transition, "{0}/moto-api/state-manager/unset-transition": response_instance.unset_transition, diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index 24ce64c96..57d4b1680 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -303,6 +303,8 @@ def test_get_query_results(): "CaseSensitive": True, } ] + # This was the documented way to configure query results, before `moto-api/static/athena/query_results` was implemented + # We should try to keep this for backward compatibility results = QueryResults(rows=rows, column_info=column_info) backend.query_results["test"] = results @@ -311,6 +313,52 @@ def test_get_query_results(): 
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal(column_info) +@mock_athena +def test_get_query_results_queue(): + client = boto3.client("athena", region_name="us-east-1") + + result = client.get_query_results(QueryExecutionId="test") + + result["ResultSet"]["Rows"].should.equal([]) + result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([]) + + if not settings.TEST_SERVER_MODE: + backend = athena_backends[DEFAULT_ACCOUNT_ID]["us-east-1"] + rows = [{"Data": [{"VarCharValue": ".."}]}] + column_info = [ + { + "CatalogName": "string", + "SchemaName": "string", + "TableName": "string", + "Name": "string", + "Label": "string", + "Type": "string", + "Precision": 123, + "Scale": 123, + "Nullable": "NOT_NULL", + "CaseSensitive": True, + } + ] + results = QueryResults(rows=rows, column_info=column_info) + backend.query_results_queue.append(results) + + result = client.get_query_results( + QueryExecutionId="some-id-not-used-when-results-were-added-to-queue" + ) + result["ResultSet"]["Rows"].should.equal(rows) + result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal(column_info) + + result = client.get_query_results(QueryExecutionId="other-id") + result["ResultSet"]["Rows"].should.equal([]) + result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([]) + + result = client.get_query_results( + QueryExecutionId="some-id-not-used-when-results-were-added-to-queue" + ) + result["ResultSet"]["Rows"].should.equal(rows) + result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal(column_info) + + @mock_athena def test_list_query_executions(): client = boto3.client("athena", region_name="us-east-1") diff --git a/tests/test_athena/test_athena_server_api.py b/tests/test_athena/test_athena_server_api.py new file mode 100644 index 000000000..2ebed930c --- /dev/null +++ b/tests/test_athena/test_athena_server_api.py @@ -0,0 +1,161 @@ +import boto3 +import requests + +from moto import mock_athena, mock_sts, settings + + 
+DEFAULT_COLUMN_INFO = [ + { + "CatalogName": "string", + "SchemaName": "string", + "TableName": "string", + "Name": "string", + "Label": "string", + "Type": "string", + "Precision": 123, + "Scale": 123, + "Nullable": "NOT_NULL", + "CaseSensitive": True, + } +] + + +@mock_athena +def test_set_athena_result(): + base_url = ( + "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com" + ) + + athena_result = { + "results": [ + { + "rows": [ + {"Data": [{"VarCharValue": "1"}]}, + ], + "column_info": DEFAULT_COLUMN_INFO, + } + ] + } + resp = requests.post( + f"http://{base_url}/moto-api/static/athena/query-results", + json=athena_result, + ) + resp.status_code.should.equal(201) + + client = boto3.client("athena", region_name="us-east-1") + details = client.get_query_results(QueryExecutionId="anyid")["ResultSet"] + details["Rows"].should.equal(athena_result["results"][0]["rows"]) + details["ResultSetMetadata"]["ColumnInfo"].should.equal(DEFAULT_COLUMN_INFO) + + # Operation should be idempotent + details = client.get_query_results(QueryExecutionId="anyid")["ResultSet"] + details["Rows"].should.equal(athena_result["results"][0]["rows"]) + + # Different ID should still return different (default) results though + details = client.get_query_results(QueryExecutionId="otherid")["ResultSet"] + details["Rows"].should.equal([]) + + +@mock_athena +def test_set_multiple_athena_result(): + base_url = ( + "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com" + ) + + athena_result = { + "results": [ + {"rows": [{"Data": [{"VarCharValue": "1"}]}]}, + {"rows": [{"Data": [{"VarCharValue": "2"}]}]}, + {"rows": [{"Data": [{"VarCharValue": "3"}]}]}, + ] + } + resp = requests.post( + f"http://{base_url}/moto-api/static/athena/query-results", + json=athena_result, + ) + resp.status_code.should.equal(201) + + client = boto3.client("athena", region_name="us-east-1") + details = client.get_query_results(QueryExecutionId="first_id")["ResultSet"] + 
details["Rows"].should.equal([{"Data": [{"VarCharValue": "1"}]}]) + + # The same ID should return the same data + details = client.get_query_results(QueryExecutionId="first_id")["ResultSet"] + details["Rows"].should.equal([{"Data": [{"VarCharValue": "1"}]}]) + + # The next ID should return different data + details = client.get_query_results(QueryExecutionId="second_id")["ResultSet"] + details["Rows"].should.equal([{"Data": [{"VarCharValue": "2"}]}]) + + # The last ID should return even different data + details = client.get_query_results(QueryExecutionId="third_id")["ResultSet"] + details["Rows"].should.equal([{"Data": [{"VarCharValue": "3"}]}]) + + # Any other calls should return the default data + details = client.get_query_results(QueryExecutionId="other_id")["ResultSet"] + details["Rows"].should.equal([]) + + +@mock_athena +@mock_sts +def test_set_athena_result_with_custom_region_account(): + base_url = ( + "localhost:5000" if settings.TEST_SERVER_MODE else "motoapi.amazonaws.com" + ) + + athena_result = { + "account_id": "222233334444", + "region": "eu-west-1", + "results": [ + { + "rows": [ + {"Data": [{"VarCharValue": "1"}]}, + ], + "column_info": DEFAULT_COLUMN_INFO, + } + ], + } + resp = requests.post( + f"http://{base_url}/moto-api/static/athena/query-results", + json=athena_result, + ) + resp.status_code.should.equal(201) + + sts = boto3.client("sts", "us-east-1") + cross_account_creds = sts.assume_role( + RoleArn="arn:aws:iam::222233334444:role/role-in-another-account", + RoleSessionName="test-session-name", + ExternalId="test-external-id", + )["Credentials"] + + athena_in_other_account = boto3.client( + "athena", + aws_access_key_id=cross_account_creds["AccessKeyId"], + aws_secret_access_key=cross_account_creds["SecretAccessKey"], + aws_session_token=cross_account_creds["SessionToken"], + region_name="eu-west-1", + ) + + details = athena_in_other_account.get_query_results(QueryExecutionId="anyid")[ + "ResultSet" + ] + 
details["Rows"].should.equal(athena_result["results"][0]["rows"]) + details["ResultSetMetadata"]["ColumnInfo"].should.equal(DEFAULT_COLUMN_INFO) + + # query results from other regions do not match + athena_in_diff_region = boto3.client( + "athena", + aws_access_key_id=cross_account_creds["AccessKeyId"], + aws_secret_access_key=cross_account_creds["SecretAccessKey"], + aws_session_token=cross_account_creds["SessionToken"], + region_name="eu-west-2", + ) + details = athena_in_diff_region.get_query_results(QueryExecutionId="anyid")[ + "ResultSet" + ] + details["Rows"].should.equal([]) + + # query results from default account does not match + client = boto3.client("athena", region_name="eu-west-1") + details = client.get_query_results(QueryExecutionId="anyid")["ResultSet"] + details["Rows"].should.equal([])