Athena fixes and additions (#6055)

This commit is contained in:
Paul Martins 2023-03-12 17:14:33 +00:00 committed by GitHub
parent 18ec0c5467
commit 1257e93ec0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 161 additions and 9 deletions

View File

@ -1,7 +1,8 @@
import time import time
from datetime import datetime
from moto.core import BaseBackend, BackendDict, BaseModel from moto.core import BaseBackend, BackendDict, BaseModel
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
from moto.utilities.paginator import paginate
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -124,7 +125,31 @@ class NamedQuery(BaseModel):
self.workgroup = workgroup self.workgroup = workgroup
class PreparedStatement(BaseModel):
def __init__(
self,
statement_name: str,
workgroup: WorkGroup,
query_statement: str,
description: str,
):
self.statement_name = statement_name
self.workgroup = workgroup
self.query_statement = query_statement
self.description = description
self.last_modified_time = datetime.now()
class AthenaBackend(BaseBackend): class AthenaBackend(BaseBackend):
PAGINATION_MODEL = {
"list_named_queries": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 50,
"unique_attribute": "id",
}
}
def __init__(self, region_name: str, account_id: str): def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id) super().__init__(region_name, account_id)
self.work_groups: Dict[str, WorkGroup] = {} self.work_groups: Dict[str, WorkGroup] = {}
@ -133,6 +158,10 @@ class AthenaBackend(BaseBackend):
self.data_catalogs: Dict[str, DataCatalog] = {} self.data_catalogs: Dict[str, DataCatalog] = {}
self.query_results: Dict[str, QueryResults] = {} self.query_results: Dict[str, QueryResults] = {}
self.query_results_queue: List[QueryResults] = [] self.query_results_queue: List[QueryResults] = []
self.prepared_statements: Dict[str, PreparedStatement] = {}
# Initialise with the primary workgroup
self.create_work_group("primary", "", "", [])
@staticmethod @staticmethod
def default_vpc_endpoint_service( def default_vpc_endpoint_service(
@ -259,14 +288,14 @@ class AthenaBackend(BaseBackend):
description: str, description: str,
database: str, database: str,
query_string: str, query_string: str,
workgroup: WorkGroup, workgroup: str,
) -> str: ) -> str:
nq = NamedQuery( nq = NamedQuery(
name=name, name=name,
description=description, description=description,
database=database, database=database,
query_string=query_string, query_string=query_string,
workgroup=workgroup, workgroup=self.work_groups[workgroup],
) )
self.named_queries[nq.id] = nq self.named_queries[nq.id] = nq
return nq.id return nq.id
@ -307,5 +336,37 @@ class AthenaBackend(BaseBackend):
self.data_catalogs[name] = data_catalog self.data_catalogs[name] = data_catalog
return data_catalog return data_catalog
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore
def list_named_queries(self, work_group: str) -> List[str]: # type: ignore[misc]
named_query_ids = [
q.id for q in self.named_queries.values() if q.workgroup.name == work_group
]
return named_query_ids
def create_prepared_statement(
self,
statement_name: str,
workgroup: WorkGroup,
query_statement: str,
description: str,
) -> None:
ps = PreparedStatement(
statement_name=statement_name,
workgroup=workgroup,
query_statement=query_statement,
description=description,
)
self.prepared_statements[ps.statement_name] = ps
return None
def get_prepared_statement(
self, statement_name: str, work_group: WorkGroup
) -> Optional[PreparedStatement]:
if statement_name in self.prepared_statements:
ps = self.prepared_statements[statement_name]
if ps.workgroup == work_group:
return ps
return None
athena_backends = BackendDict(AthenaBackend, "athena") athena_backends = BackendDict(AthenaBackend, "athena")

View File

@ -104,8 +104,8 @@ class AthenaResponse(BaseResponse):
description = self._get_param("Description") description = self._get_param("Description")
database = self._get_param("Database") database = self._get_param("Database")
query_string = self._get_param("QueryString") query_string = self._get_param("QueryString")
workgroup = self._get_param("WorkGroup") workgroup = self._get_param("WorkGroup") or "primary"
if workgroup and not self.athena_backend.get_work_group(workgroup): if not self.athena_backend.get_work_group(workgroup):
return self.error("WorkGroup does not exist", 400) return self.error("WorkGroup does not exist", 400)
query_id = self.athena_backend.create_named_query( query_id = self.athena_backend.create_named_query(
name, description, database, query_string, workgroup name, description, database, query_string, workgroup
@ -123,7 +123,7 @@ class AthenaResponse(BaseResponse):
"Database": nq.database, # type: ignore[union-attr] "Database": nq.database, # type: ignore[union-attr]
"QueryString": nq.query_string, # type: ignore[union-attr] "QueryString": nq.query_string, # type: ignore[union-attr]
"NamedQueryId": nq.id, # type: ignore[union-attr] "NamedQueryId": nq.id, # type: ignore[union-attr]
"WorkGroup": nq.workgroup, # type: ignore[union-attr] "WorkGroup": nq.workgroup.name, # type: ignore[union-attr]
} }
} }
) )
@ -157,3 +157,46 @@ class AthenaResponse(BaseResponse):
} }
} }
) )
def list_named_queries(self) -> str:
next_token = self._get_param("NextToken")
max_results = self._get_param("MaxResults")
work_group = self._get_param("WorkGroup") or "primary"
named_query_ids, next_token = self.athena_backend.list_named_queries(
next_token=next_token, max_results=max_results, work_group=work_group
)
return json.dumps({"NamedQueryIds": named_query_ids, "NextToken": next_token})
def create_prepared_statement(self) -> Union[str, Tuple[str, Dict[str, int]]]:
statement_name = self._get_param("StatementName")
work_group = self._get_param("WorkGroup")
query_statement = self._get_param("QueryStatement")
description = self._get_param("Description")
if not self.athena_backend.get_work_group(work_group):
return self.error("WorkGroup does not exist", 400)
self.athena_backend.create_prepared_statement(
statement_name=statement_name,
workgroup=work_group,
query_statement=query_statement,
description=description,
)
return json.dumps(dict())
def get_prepared_statement(self) -> str:
statement_name = self._get_param("StatementName")
work_group = self._get_param("WorkGroup")
ps = self.athena_backend.get_prepared_statement(
statement_name=statement_name,
work_group=work_group,
)
return json.dumps(
{
"PreparedStatement": {
"StatementName": ps.statement_name, # type: ignore[union-attr]
"QueryStatement": ps.query_statement, # type: ignore[union-attr]
"WorkGroupName": ps.workgroup, # type: ignore[union-attr]
"Description": ps.description, # type: ignore[union-attr]
# "LastModifiedTime": ps.last_modified_time, # type: ignore[union-attr]
}
}
)

View File

@ -51,8 +51,11 @@ def test_create_work_group():
# Then test the work group appears in the work group list # Then test the work group appears in the work group list
response = client.list_work_groups() response = client.list_work_groups()
response["WorkGroups"].should.have.length_of(1) work_groups = list(
work_group = response["WorkGroups"][0] filter(lambda wg: wg["Name"] != "primary", response["WorkGroups"])
)
work_groups.should.have.length_of(1)
work_group = work_groups[0]
work_group["Name"].should.equal("athena_workgroup") work_group["Name"].should.equal("athena_workgroup")
work_group["Description"].should.equal("Test work group") work_group["Description"].should.equal("Test work group")
work_group["State"].should.equal("ENABLED") work_group["State"].should.equal("ENABLED")
@ -191,7 +194,7 @@ def test_get_named_query():
database = "target_db" database = "target_db"
query_string = "SELECT * FROM tbl1" query_string = "SELECT * FROM tbl1"
description = "description of this query" description = "description of this query"
# craete named query # create named query
res_create = client.create_named_query( res_create = client.create_named_query(
Name=query_name, Name=query_name,
Database=database, Database=database,
@ -375,3 +378,48 @@ def test_list_query_executions():
executions = client.list_query_executions() executions = client.list_query_executions()
executions["QueryExecutionIds"].should.have.length_of(1) executions["QueryExecutionIds"].should.have.length_of(1)
executions["QueryExecutionIds"][0].should.equal(exec_id) executions["QueryExecutionIds"][0].should.equal(exec_id)
@mock_athena
def test_list_named_queries():
client = boto3.client("athena", region_name="us-east-1")
create_basic_workgroup(client=client, name="athena_workgroup")
query_id = client.create_named_query(
Name="query-name",
Database="target_db",
QueryString="SELECT * FROM table1",
WorkGroup="athena_workgroup",
)
list_athena_wg = client.list_named_queries(WorkGroup="athena_workgroup")
assert list_athena_wg["NamedQueryIds"][0] == query_id["NamedQueryId"]
list_primary_wg = client.list_named_queries()
assert len(list_primary_wg["NamedQueryIds"]) == 0
@mock_athena
def test_create_prepared_statement():
client = boto3.client("athena", region_name="us-east-1")
create_basic_workgroup(client=client, name="athena_workgroup")
res = client.create_prepared_statement(
StatementName="test-statement",
WorkGroup="athena_workgroup",
QueryStatement="SELECT * FROM table1",
)
metadata = res["ResponseMetadata"]
assert metadata["HTTPStatusCode"] == 200
assert metadata["RetryAttempts"] == 0
@mock_athena
def test_get_prepared_statement():
client = boto3.client("athena", region_name="us-east-1")
create_basic_workgroup(client=client, name="athena_workgroup")
client.create_prepared_statement(
StatementName="stmt-name",
WorkGroup="athena_workgroup",
QueryStatement="SELECT * FROM table1",
)
resp = client.get_prepared_statement(
StatementName="stmt-name", WorkGroup="athena_workgroup"
)
assert "PreparedStatement" in resp