From 1257e93ec0826d85e439eb15a7f87c93bfd4803d Mon Sep 17 00:00:00 2001 From: Paul Martins Date: Sun, 12 Mar 2023 17:14:33 +0000 Subject: [PATCH] Athena fixes and additions (#6055) --- moto/athena/models.py | 67 ++++++++++++++++++++++++++++++-- moto/athena/responses.py | 49 +++++++++++++++++++++-- tests/test_athena/test_athena.py | 54 +++++++++++++++++++++++-- 3 files changed, 161 insertions(+), 9 deletions(-) diff --git a/moto/athena/models.py b/moto/athena/models.py index a7ee323c1..1b1936cc7 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -1,7 +1,8 @@ import time - +from datetime import datetime from moto.core import BaseBackend, BackendDict, BaseModel from moto.moto_api._internal import mock_random +from moto.utilities.paginator import paginate from typing import Any, Dict, List, Optional @@ -124,7 +125,31 @@ class NamedQuery(BaseModel): self.workgroup = workgroup +class PreparedStatement(BaseModel): + def __init__( + self, + statement_name: str, + workgroup: WorkGroup, + query_statement: str, + description: str, + ): + self.statement_name = statement_name + self.workgroup = workgroup + self.query_statement = query_statement + self.description = description + self.last_modified_time = datetime.now() + + class AthenaBackend(BaseBackend): + PAGINATION_MODEL = { + "list_named_queries": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 50, + "unique_attribute": "id", + } + } + def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) self.work_groups: Dict[str, WorkGroup] = {} @@ -133,6 +158,10 @@ class AthenaBackend(BaseBackend): self.data_catalogs: Dict[str, DataCatalog] = {} self.query_results: Dict[str, QueryResults] = {} self.query_results_queue: List[QueryResults] = [] + self.prepared_statements: Dict[str, PreparedStatement] = {} + + # Initialise with the primary workgroup + self.create_work_group("primary", "", "", []) @staticmethod def default_vpc_endpoint_service( @@ -259,14 +288,14 @@ class AthenaBackend(BaseBackend): description: str, database: str, query_string: str, - workgroup: WorkGroup, + workgroup: str, ) -> str: nq = NamedQuery( name=name, description=description, database=database, query_string=query_string, - workgroup=workgroup, + workgroup=self.work_groups[workgroup], ) self.named_queries[nq.id] = nq return nq.id @@ -307,5 +336,37 @@ class AthenaBackend(BaseBackend): self.data_catalogs[name] = data_catalog return data_catalog + @paginate(pagination_model=PAGINATION_MODEL) # type: ignore + def list_named_queries(self, work_group: str) -> List[str]: # type: ignore[misc] + named_query_ids = [ + q.id for q in self.named_queries.values() if q.workgroup.name == work_group + ] + return named_query_ids + + def create_prepared_statement( + self, + statement_name: str, + workgroup: WorkGroup, + query_statement: str, + description: str, + ) -> None: + ps = PreparedStatement( + statement_name=statement_name, + workgroup=workgroup, + query_statement=query_statement, + description=description, + ) + self.prepared_statements[ps.statement_name] = ps + return None + + def get_prepared_statement( + self, statement_name: str, work_group: WorkGroup + ) -> Optional[PreparedStatement]: + if statement_name in self.prepared_statements: + ps = self.prepared_statements[statement_name] + if ps.workgroup == work_group: + return ps + return None + athena_backends = BackendDict(AthenaBackend, "athena") diff --git a/moto/athena/responses.py b/moto/athena/responses.py index 7f7dfc94a..e8b21e188 100644 --- a/moto/athena/responses.py +++ b/moto/athena/responses.py @@ -104,8 +104,8 @@ class AthenaResponse(BaseResponse): description = self._get_param("Description") database = self._get_param("Database") query_string = self._get_param("QueryString") - workgroup = self._get_param("WorkGroup") - if workgroup and not self.athena_backend.get_work_group(workgroup): + workgroup = self._get_param("WorkGroup") or "primary" + if not self.athena_backend.get_work_group(workgroup): return self.error("WorkGroup does not exist", 400) query_id = self.athena_backend.create_named_query( name, description, database, query_string, workgroup @@ -123,7 +123,7 @@ class AthenaResponse(BaseResponse): "Database": nq.database, # type: ignore[union-attr] "QueryString": nq.query_string, # type: ignore[union-attr] "NamedQueryId": nq.id, # type: ignore[union-attr] - "WorkGroup": nq.workgroup, # type: ignore[union-attr] + "WorkGroup": nq.workgroup.name, # type: ignore[union-attr] } } ) @@ -157,3 +157,46 @@ class AthenaResponse(BaseResponse): } } ) + + def list_named_queries(self) -> str: + next_token = self._get_param("NextToken") + max_results = self._get_param("MaxResults") + work_group = self._get_param("WorkGroup") or "primary" + named_query_ids, next_token = self.athena_backend.list_named_queries( + next_token=next_token, max_results=max_results, work_group=work_group + ) + return json.dumps({"NamedQueryIds": named_query_ids, "NextToken": next_token}) + + def create_prepared_statement(self) -> Union[str, Tuple[str, Dict[str, int]]]: + statement_name = self._get_param("StatementName") + work_group = self._get_param("WorkGroup") + query_statement = self._get_param("QueryStatement") + description = self._get_param("Description") + if not self.athena_backend.get_work_group(work_group): + return self.error("WorkGroup does not exist", 400) + self.athena_backend.create_prepared_statement( + statement_name=statement_name, + workgroup=work_group, + query_statement=query_statement, + description=description, + ) + return json.dumps(dict()) + + def get_prepared_statement(self) -> str: + statement_name = self._get_param("StatementName") + work_group = self._get_param("WorkGroup") + ps = self.athena_backend.get_prepared_statement( + statement_name=statement_name, + work_group=work_group, + ) + return json.dumps( + { + "PreparedStatement": { + "StatementName": ps.statement_name, # type: ignore[union-attr] + "QueryStatement": ps.query_statement, # type: ignore[union-attr] + "WorkGroupName": ps.workgroup, # type: ignore[union-attr] + "Description": ps.description, # type: ignore[union-attr] + # "LastModifiedTime": ps.last_modified_time, # type: ignore[union-attr] + } + } + ) diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index 57d4b1680..1389511ef 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -51,8 +51,11 @@ def test_create_work_group(): # Then test the work group appears in the work group list response = client.list_work_groups() - response["WorkGroups"].should.have.length_of(1) - work_group = response["WorkGroups"][0] + work_groups = list( + filter(lambda wg: wg["Name"] != "primary", response["WorkGroups"]) + ) + work_groups.should.have.length_of(1) + work_group = work_groups[0] work_group["Name"].should.equal("athena_workgroup") work_group["Description"].should.equal("Test work group") work_group["State"].should.equal("ENABLED") @@ -191,7 +194,7 @@ def test_get_named_query(): database = "target_db" query_string = "SELECT * FROM tbl1" description = "description of this query" - # craete named query + # create named query res_create = client.create_named_query( Name=query_name, Database=database, @@ -375,3 +378,48 @@ def test_list_query_executions(): executions = client.list_query_executions() executions["QueryExecutionIds"].should.have.length_of(1) executions["QueryExecutionIds"][0].should.equal(exec_id) + + +@mock_athena +def test_list_named_queries(): + client = boto3.client("athena", region_name="us-east-1") + create_basic_workgroup(client=client, name="athena_workgroup") + query_id = client.create_named_query( + Name="query-name", + Database="target_db", + QueryString="SELECT * FROM table1", + WorkGroup="athena_workgroup", + ) + list_athena_wg = client.list_named_queries(WorkGroup="athena_workgroup") + assert list_athena_wg["NamedQueryIds"][0] == query_id["NamedQueryId"] + list_primary_wg = client.list_named_queries() + assert len(list_primary_wg["NamedQueryIds"]) == 0 + + +@mock_athena +def test_create_prepared_statement(): + client = boto3.client("athena", region_name="us-east-1") + create_basic_workgroup(client=client, name="athena_workgroup") + res = client.create_prepared_statement( + StatementName="test-statement", + WorkGroup="athena_workgroup", + QueryStatement="SELECT * FROM table1", + ) + metadata = res["ResponseMetadata"] + assert metadata["HTTPStatusCode"] == 200 + assert metadata["RetryAttempts"] == 0 + + +@mock_athena +def test_get_prepared_statement(): + client = boto3.client("athena", region_name="us-east-1") + create_basic_workgroup(client=client, name="athena_workgroup") + client.create_prepared_statement( + StatementName="stmt-name", + WorkGroup="athena_workgroup", + QueryStatement="SELECT * FROM table1", + ) + resp = client.get_prepared_statement( + StatementName="stmt-name", WorkGroup="athena_workgroup" + ) + assert "PreparedStatement" in resp