Athena fixes and additions (#6055)
This commit is contained in:
parent
18ec0c5467
commit
1257e93ec0
@ -1,7 +1,8 @@
|
||||
import time
|
||||
|
||||
from datetime import datetime
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel
|
||||
from moto.moto_api._internal import mock_random
|
||||
from moto.utilities.paginator import paginate
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@ -124,7 +125,31 @@ class NamedQuery(BaseModel):
|
||||
self.workgroup = workgroup
|
||||
|
||||
|
||||
class PreparedStatement(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
statement_name: str,
|
||||
workgroup: WorkGroup,
|
||||
query_statement: str,
|
||||
description: str,
|
||||
):
|
||||
self.statement_name = statement_name
|
||||
self.workgroup = workgroup
|
||||
self.query_statement = query_statement
|
||||
self.description = description
|
||||
self.last_modified_time = datetime.now()
|
||||
|
||||
|
||||
class AthenaBackend(BaseBackend):
|
||||
PAGINATION_MODEL = {
|
||||
"list_named_queries": {
|
||||
"input_token": "next_token",
|
||||
"limit_key": "max_results",
|
||||
"limit_default": 50,
|
||||
"unique_attribute": "id",
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, region_name: str, account_id: str):
|
||||
super().__init__(region_name, account_id)
|
||||
self.work_groups: Dict[str, WorkGroup] = {}
|
||||
@ -133,6 +158,10 @@ class AthenaBackend(BaseBackend):
|
||||
self.data_catalogs: Dict[str, DataCatalog] = {}
|
||||
self.query_results: Dict[str, QueryResults] = {}
|
||||
self.query_results_queue: List[QueryResults] = []
|
||||
self.prepared_statements: Dict[str, PreparedStatement] = {}
|
||||
|
||||
# Initialise with the primary workgroup
|
||||
self.create_work_group("primary", "", "", [])
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(
|
||||
@ -259,14 +288,14 @@ class AthenaBackend(BaseBackend):
|
||||
description: str,
|
||||
database: str,
|
||||
query_string: str,
|
||||
workgroup: WorkGroup,
|
||||
workgroup: str,
|
||||
) -> str:
|
||||
nq = NamedQuery(
|
||||
name=name,
|
||||
description=description,
|
||||
database=database,
|
||||
query_string=query_string,
|
||||
workgroup=workgroup,
|
||||
workgroup=self.work_groups[workgroup],
|
||||
)
|
||||
self.named_queries[nq.id] = nq
|
||||
return nq.id
|
||||
@ -307,5 +336,37 @@ class AthenaBackend(BaseBackend):
|
||||
self.data_catalogs[name] = data_catalog
|
||||
return data_catalog
|
||||
|
||||
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore
|
||||
def list_named_queries(self, work_group: str) -> List[str]: # type: ignore[misc]
|
||||
named_query_ids = [
|
||||
q.id for q in self.named_queries.values() if q.workgroup.name == work_group
|
||||
]
|
||||
return named_query_ids
|
||||
|
||||
def create_prepared_statement(
|
||||
self,
|
||||
statement_name: str,
|
||||
workgroup: WorkGroup,
|
||||
query_statement: str,
|
||||
description: str,
|
||||
) -> None:
|
||||
ps = PreparedStatement(
|
||||
statement_name=statement_name,
|
||||
workgroup=workgroup,
|
||||
query_statement=query_statement,
|
||||
description=description,
|
||||
)
|
||||
self.prepared_statements[ps.statement_name] = ps
|
||||
return None
|
||||
|
||||
def get_prepared_statement(
|
||||
self, statement_name: str, work_group: WorkGroup
|
||||
) -> Optional[PreparedStatement]:
|
||||
if statement_name in self.prepared_statements:
|
||||
ps = self.prepared_statements[statement_name]
|
||||
if ps.workgroup == work_group:
|
||||
return ps
|
||||
return None
|
||||
|
||||
|
||||
athena_backends = BackendDict(AthenaBackend, "athena")
|
||||
|
@ -104,8 +104,8 @@ class AthenaResponse(BaseResponse):
|
||||
description = self._get_param("Description")
|
||||
database = self._get_param("Database")
|
||||
query_string = self._get_param("QueryString")
|
||||
workgroup = self._get_param("WorkGroup")
|
||||
if workgroup and not self.athena_backend.get_work_group(workgroup):
|
||||
workgroup = self._get_param("WorkGroup") or "primary"
|
||||
if not self.athena_backend.get_work_group(workgroup):
|
||||
return self.error("WorkGroup does not exist", 400)
|
||||
query_id = self.athena_backend.create_named_query(
|
||||
name, description, database, query_string, workgroup
|
||||
@ -123,7 +123,7 @@ class AthenaResponse(BaseResponse):
|
||||
"Database": nq.database, # type: ignore[union-attr]
|
||||
"QueryString": nq.query_string, # type: ignore[union-attr]
|
||||
"NamedQueryId": nq.id, # type: ignore[union-attr]
|
||||
"WorkGroup": nq.workgroup, # type: ignore[union-attr]
|
||||
"WorkGroup": nq.workgroup.name, # type: ignore[union-attr]
|
||||
}
|
||||
}
|
||||
)
|
||||
@ -157,3 +157,46 @@ class AthenaResponse(BaseResponse):
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def list_named_queries(self) -> str:
|
||||
next_token = self._get_param("NextToken")
|
||||
max_results = self._get_param("MaxResults")
|
||||
work_group = self._get_param("WorkGroup") or "primary"
|
||||
named_query_ids, next_token = self.athena_backend.list_named_queries(
|
||||
next_token=next_token, max_results=max_results, work_group=work_group
|
||||
)
|
||||
return json.dumps({"NamedQueryIds": named_query_ids, "NextToken": next_token})
|
||||
|
||||
def create_prepared_statement(self) -> Union[str, Tuple[str, Dict[str, int]]]:
|
||||
statement_name = self._get_param("StatementName")
|
||||
work_group = self._get_param("WorkGroup")
|
||||
query_statement = self._get_param("QueryStatement")
|
||||
description = self._get_param("Description")
|
||||
if not self.athena_backend.get_work_group(work_group):
|
||||
return self.error("WorkGroup does not exist", 400)
|
||||
self.athena_backend.create_prepared_statement(
|
||||
statement_name=statement_name,
|
||||
workgroup=work_group,
|
||||
query_statement=query_statement,
|
||||
description=description,
|
||||
)
|
||||
return json.dumps(dict())
|
||||
|
||||
def get_prepared_statement(self) -> str:
|
||||
statement_name = self._get_param("StatementName")
|
||||
work_group = self._get_param("WorkGroup")
|
||||
ps = self.athena_backend.get_prepared_statement(
|
||||
statement_name=statement_name,
|
||||
work_group=work_group,
|
||||
)
|
||||
return json.dumps(
|
||||
{
|
||||
"PreparedStatement": {
|
||||
"StatementName": ps.statement_name, # type: ignore[union-attr]
|
||||
"QueryStatement": ps.query_statement, # type: ignore[union-attr]
|
||||
"WorkGroupName": ps.workgroup, # type: ignore[union-attr]
|
||||
"Description": ps.description, # type: ignore[union-attr]
|
||||
# "LastModifiedTime": ps.last_modified_time, # type: ignore[union-attr]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
@ -51,8 +51,11 @@ def test_create_work_group():
|
||||
# Then test the work group appears in the work group list
|
||||
response = client.list_work_groups()
|
||||
|
||||
response["WorkGroups"].should.have.length_of(1)
|
||||
work_group = response["WorkGroups"][0]
|
||||
work_groups = list(
|
||||
filter(lambda wg: wg["Name"] != "primary", response["WorkGroups"])
|
||||
)
|
||||
work_groups.should.have.length_of(1)
|
||||
work_group = work_groups[0]
|
||||
work_group["Name"].should.equal("athena_workgroup")
|
||||
work_group["Description"].should.equal("Test work group")
|
||||
work_group["State"].should.equal("ENABLED")
|
||||
@ -191,7 +194,7 @@ def test_get_named_query():
|
||||
database = "target_db"
|
||||
query_string = "SELECT * FROM tbl1"
|
||||
description = "description of this query"
|
||||
# craete named query
|
||||
# create named query
|
||||
res_create = client.create_named_query(
|
||||
Name=query_name,
|
||||
Database=database,
|
||||
@ -375,3 +378,48 @@ def test_list_query_executions():
|
||||
executions = client.list_query_executions()
|
||||
executions["QueryExecutionIds"].should.have.length_of(1)
|
||||
executions["QueryExecutionIds"][0].should.equal(exec_id)
|
||||
|
||||
|
||||
@mock_athena
|
||||
def test_list_named_queries():
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
create_basic_workgroup(client=client, name="athena_workgroup")
|
||||
query_id = client.create_named_query(
|
||||
Name="query-name",
|
||||
Database="target_db",
|
||||
QueryString="SELECT * FROM table1",
|
||||
WorkGroup="athena_workgroup",
|
||||
)
|
||||
list_athena_wg = client.list_named_queries(WorkGroup="athena_workgroup")
|
||||
assert list_athena_wg["NamedQueryIds"][0] == query_id["NamedQueryId"]
|
||||
list_primary_wg = client.list_named_queries()
|
||||
assert len(list_primary_wg["NamedQueryIds"]) == 0
|
||||
|
||||
|
||||
@mock_athena
|
||||
def test_create_prepared_statement():
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
create_basic_workgroup(client=client, name="athena_workgroup")
|
||||
res = client.create_prepared_statement(
|
||||
StatementName="test-statement",
|
||||
WorkGroup="athena_workgroup",
|
||||
QueryStatement="SELECT * FROM table1",
|
||||
)
|
||||
metadata = res["ResponseMetadata"]
|
||||
assert metadata["HTTPStatusCode"] == 200
|
||||
assert metadata["RetryAttempts"] == 0
|
||||
|
||||
|
||||
@mock_athena
|
||||
def test_get_prepared_statement():
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
create_basic_workgroup(client=client, name="athena_workgroup")
|
||||
client.create_prepared_statement(
|
||||
StatementName="stmt-name",
|
||||
WorkGroup="athena_workgroup",
|
||||
QueryStatement="SELECT * FROM table1",
|
||||
)
|
||||
resp = client.get_prepared_statement(
|
||||
StatementName="stmt-name", WorkGroup="athena_workgroup"
|
||||
)
|
||||
assert "PreparedStatement" in resp
|
||||
|
Loading…
Reference in New Issue
Block a user