From 07a8d6f0099f8365d0709a2800236aed7e02dcdc Mon Sep 17 00:00:00 2001 From: Greg Hinch Date: Thu, 15 Dec 2022 20:49:48 +0000 Subject: [PATCH] Add Athena: get_query_results, list_query_executions (#5648) --- docs/docs/services/athena.rst | 6 +++--- moto/athena/models.py | 30 ++++++++++++++++++++++++++++-- moto/athena/responses.py | 11 ++++++++++- tests/test_athena/test_athena.py | 30 +++++++++++++++++++++++++++++- 4 files changed, 70 insertions(+), 7 deletions(-) diff --git a/docs/docs/services/athena.rst b/docs/docs/services/athena.rst index 5fe5ccf36..259728923 100644 --- a/docs/docs/services/athena.rst +++ b/docs/docs/services/athena.rst @@ -40,8 +40,8 @@ athena - [ ] get_database - [X] get_named_query - [ ] get_prepared_statement -- [ ] get_query_execution -- [ ] get_query_results +- [X] get_query_execution +- [X] get_query_results - [ ] get_query_runtime_statistics - [ ] get_table_metadata - [X] get_work_group @@ -50,7 +50,7 @@ athena - [ ] list_engine_versions - [ ] list_named_queries - [ ] list_prepared_statements -- [ ] list_query_executions +- [X] list_query_executions - [ ] list_table_metadata - [ ] list_tags_for_resource - [X] list_work_groups diff --git a/moto/athena/models.py b/moto/athena/models.py index 83ee2344b..b8fccbeec 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -90,7 +90,21 @@ class Execution(BaseModel): self.config = config self.workgroup = workgroup self.start_time = time.time() - self.status = "QUEUED" + self.status = "SUCCEEDED" + + +class QueryResults(BaseModel): + def __init__(self, rows: List[Dict[str, Any]], column_info: List[str]): + self.rows = rows + self.column_info = column_info + + def to_dict(self) -> Dict[str, Any]: + return { + "ResultSet": { + "Rows": self.rows, + "ResultSetMetadata": {"ColumnInfo": self.column_info}, + }, + } class NamedQuery(BaseModel): @@ -117,6 +131,7 @@ class AthenaBackend(BaseBackend): self.executions: Dict[str, Execution] = {} self.named_queries: Dict[str, NamedQuery] = {} self.data_catalogs: Dict[str, DataCatalog] = {} + self.query_results: Dict[str, QueryResults] = {} @staticmethod def default_vpc_endpoint_service( @@ -172,9 +187,20 @@ class AthenaBackend(BaseBackend): self.executions[execution.id] = execution return execution.id - def get_execution(self, exec_id: str) -> Execution: + def get_query_execution(self, exec_id: str) -> Execution: return self.executions[exec_id] + def list_query_executions(self) -> Dict[str, Execution]: + return self.executions + + def get_query_results(self, exec_id: str) -> QueryResults: + results = ( + self.query_results[exec_id] + if exec_id in self.query_results + else QueryResults(rows=[], column_info=[]) + ) + return results + def stop_query_execution(self, exec_id: str) -> None: execution = self.executions[exec_id] execution.status = "CANCELLED" diff --git a/moto/athena/responses.py b/moto/athena/responses.py index 09000727b..7f7dfc94a 100644 --- a/moto/athena/responses.py +++ b/moto/athena/responses.py @@ -54,7 +54,7 @@ class AthenaResponse(BaseResponse): def get_query_execution(self) -> str: exec_id = self._get_param("QueryExecutionId") - execution = self.athena_backend.get_execution(exec_id) + execution = self.athena_backend.get_query_execution(exec_id) result = { "QueryExecution": { "QueryExecutionId": exec_id, @@ -79,6 +79,15 @@ class AthenaResponse(BaseResponse): } return json.dumps(result) + def get_query_results(self) -> str: + exec_id = self._get_param("QueryExecutionId") + result = self.athena_backend.get_query_results(exec_id) + return json.dumps(result.to_dict()) + + def list_query_executions(self) -> str: + executions = self.athena_backend.list_query_executions() + return json.dumps({"QueryExecutionIds": [i for i in executions.keys()]}) + def stop_query_execution(self) -> str: exec_id = self._get_param("QueryExecutionId") self.athena_backend.stop_query_execution(exec_id) diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index 4b2650325..fd9c5be62 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -134,7 +134,7 @@ def test_get_query_execution(): details["StatementType"].should.equal("DDL") details["ResultConfiguration"]["OutputLocation"].should.equal(location) details["QueryExecutionContext"]["Database"].should.equal(database) - details["Status"]["State"].should.equal("QUEUED") + details["Status"]["State"].should.equal("SUCCEEDED") details["Statistics"].should.equal( { "EngineExecutionTimeInMillis": 0, @@ -273,3 +273,31 @@ def test_create_and_get_data_catalog(): "Parameters": {"catalog-id": "AWS Test account ID"}, } ) + + +@mock_athena +def test_get_query_results(): + client = boto3.client("athena", region_name="us-east-1") + + result = client.get_query_results(QueryExecutionId="test") + + result["ResultSet"]["Rows"].should.equal([]) + result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([]) + + +@mock_athena +def test_list_query_executions(): + client = boto3.client("athena", region_name="us-east-1") + + create_basic_workgroup(client=client, name="athena_workgroup") + exec_result = client.start_query_execution( + QueryString="query1", + QueryExecutionContext={"Database": "string"}, + ResultConfiguration={"OutputLocation": "string"}, + WorkGroup="athena_workgroup", + ) + exec_id = exec_result["QueryExecutionId"] + + executions = client.list_query_executions() + executions["QueryExecutionIds"].should.have.length_of(1) + executions["QueryExecutionIds"][0].should.equal(exec_id)