Athena fixes and additions (#6055)
This commit is contained in:
parent
18ec0c5467
commit
1257e93ec0
@ -1,7 +1,8 @@
|
|||||||
import time
|
import time
|
||||||
|
from datetime import datetime
|
||||||
from moto.core import BaseBackend, BackendDict, BaseModel
|
from moto.core import BaseBackend, BackendDict, BaseModel
|
||||||
from moto.moto_api._internal import mock_random
|
from moto.moto_api._internal import mock_random
|
||||||
|
from moto.utilities.paginator import paginate
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
@ -124,7 +125,31 @@ class NamedQuery(BaseModel):
|
|||||||
self.workgroup = workgroup
|
self.workgroup = workgroup
|
||||||
|
|
||||||
|
|
||||||
|
class PreparedStatement(BaseModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
statement_name: str,
|
||||||
|
workgroup: WorkGroup,
|
||||||
|
query_statement: str,
|
||||||
|
description: str,
|
||||||
|
):
|
||||||
|
self.statement_name = statement_name
|
||||||
|
self.workgroup = workgroup
|
||||||
|
self.query_statement = query_statement
|
||||||
|
self.description = description
|
||||||
|
self.last_modified_time = datetime.now()
|
||||||
|
|
||||||
|
|
||||||
class AthenaBackend(BaseBackend):
|
class AthenaBackend(BaseBackend):
|
||||||
|
PAGINATION_MODEL = {
|
||||||
|
"list_named_queries": {
|
||||||
|
"input_token": "next_token",
|
||||||
|
"limit_key": "max_results",
|
||||||
|
"limit_default": 50,
|
||||||
|
"unique_attribute": "id",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, region_name: str, account_id: str):
|
def __init__(self, region_name: str, account_id: str):
|
||||||
super().__init__(region_name, account_id)
|
super().__init__(region_name, account_id)
|
||||||
self.work_groups: Dict[str, WorkGroup] = {}
|
self.work_groups: Dict[str, WorkGroup] = {}
|
||||||
@ -133,6 +158,10 @@ class AthenaBackend(BaseBackend):
|
|||||||
self.data_catalogs: Dict[str, DataCatalog] = {}
|
self.data_catalogs: Dict[str, DataCatalog] = {}
|
||||||
self.query_results: Dict[str, QueryResults] = {}
|
self.query_results: Dict[str, QueryResults] = {}
|
||||||
self.query_results_queue: List[QueryResults] = []
|
self.query_results_queue: List[QueryResults] = []
|
||||||
|
self.prepared_statements: Dict[str, PreparedStatement] = {}
|
||||||
|
|
||||||
|
# Initialise with the primary workgroup
|
||||||
|
self.create_work_group("primary", "", "", [])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_vpc_endpoint_service(
|
def default_vpc_endpoint_service(
|
||||||
@ -259,14 +288,14 @@ class AthenaBackend(BaseBackend):
|
|||||||
description: str,
|
description: str,
|
||||||
database: str,
|
database: str,
|
||||||
query_string: str,
|
query_string: str,
|
||||||
workgroup: WorkGroup,
|
workgroup: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
nq = NamedQuery(
|
nq = NamedQuery(
|
||||||
name=name,
|
name=name,
|
||||||
description=description,
|
description=description,
|
||||||
database=database,
|
database=database,
|
||||||
query_string=query_string,
|
query_string=query_string,
|
||||||
workgroup=workgroup,
|
workgroup=self.work_groups[workgroup],
|
||||||
)
|
)
|
||||||
self.named_queries[nq.id] = nq
|
self.named_queries[nq.id] = nq
|
||||||
return nq.id
|
return nq.id
|
||||||
@ -307,5 +336,37 @@ class AthenaBackend(BaseBackend):
|
|||||||
self.data_catalogs[name] = data_catalog
|
self.data_catalogs[name] = data_catalog
|
||||||
return data_catalog
|
return data_catalog
|
||||||
|
|
||||||
|
@paginate(pagination_model=PAGINATION_MODEL) # type: ignore
|
||||||
|
def list_named_queries(self, work_group: str) -> List[str]: # type: ignore[misc]
|
||||||
|
named_query_ids = [
|
||||||
|
q.id for q in self.named_queries.values() if q.workgroup.name == work_group
|
||||||
|
]
|
||||||
|
return named_query_ids
|
||||||
|
|
||||||
|
def create_prepared_statement(
|
||||||
|
self,
|
||||||
|
statement_name: str,
|
||||||
|
workgroup: WorkGroup,
|
||||||
|
query_statement: str,
|
||||||
|
description: str,
|
||||||
|
) -> None:
|
||||||
|
ps = PreparedStatement(
|
||||||
|
statement_name=statement_name,
|
||||||
|
workgroup=workgroup,
|
||||||
|
query_statement=query_statement,
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
|
self.prepared_statements[ps.statement_name] = ps
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_prepared_statement(
|
||||||
|
self, statement_name: str, work_group: WorkGroup
|
||||||
|
) -> Optional[PreparedStatement]:
|
||||||
|
if statement_name in self.prepared_statements:
|
||||||
|
ps = self.prepared_statements[statement_name]
|
||||||
|
if ps.workgroup == work_group:
|
||||||
|
return ps
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
athena_backends = BackendDict(AthenaBackend, "athena")
|
athena_backends = BackendDict(AthenaBackend, "athena")
|
||||||
|
@ -104,8 +104,8 @@ class AthenaResponse(BaseResponse):
|
|||||||
description = self._get_param("Description")
|
description = self._get_param("Description")
|
||||||
database = self._get_param("Database")
|
database = self._get_param("Database")
|
||||||
query_string = self._get_param("QueryString")
|
query_string = self._get_param("QueryString")
|
||||||
workgroup = self._get_param("WorkGroup")
|
workgroup = self._get_param("WorkGroup") or "primary"
|
||||||
if workgroup and not self.athena_backend.get_work_group(workgroup):
|
if not self.athena_backend.get_work_group(workgroup):
|
||||||
return self.error("WorkGroup does not exist", 400)
|
return self.error("WorkGroup does not exist", 400)
|
||||||
query_id = self.athena_backend.create_named_query(
|
query_id = self.athena_backend.create_named_query(
|
||||||
name, description, database, query_string, workgroup
|
name, description, database, query_string, workgroup
|
||||||
@ -123,7 +123,7 @@ class AthenaResponse(BaseResponse):
|
|||||||
"Database": nq.database, # type: ignore[union-attr]
|
"Database": nq.database, # type: ignore[union-attr]
|
||||||
"QueryString": nq.query_string, # type: ignore[union-attr]
|
"QueryString": nq.query_string, # type: ignore[union-attr]
|
||||||
"NamedQueryId": nq.id, # type: ignore[union-attr]
|
"NamedQueryId": nq.id, # type: ignore[union-attr]
|
||||||
"WorkGroup": nq.workgroup, # type: ignore[union-attr]
|
"WorkGroup": nq.workgroup.name, # type: ignore[union-attr]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -157,3 +157,46 @@ class AthenaResponse(BaseResponse):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def list_named_queries(self) -> str:
|
||||||
|
next_token = self._get_param("NextToken")
|
||||||
|
max_results = self._get_param("MaxResults")
|
||||||
|
work_group = self._get_param("WorkGroup") or "primary"
|
||||||
|
named_query_ids, next_token = self.athena_backend.list_named_queries(
|
||||||
|
next_token=next_token, max_results=max_results, work_group=work_group
|
||||||
|
)
|
||||||
|
return json.dumps({"NamedQueryIds": named_query_ids, "NextToken": next_token})
|
||||||
|
|
||||||
|
def create_prepared_statement(self) -> Union[str, Tuple[str, Dict[str, int]]]:
|
||||||
|
statement_name = self._get_param("StatementName")
|
||||||
|
work_group = self._get_param("WorkGroup")
|
||||||
|
query_statement = self._get_param("QueryStatement")
|
||||||
|
description = self._get_param("Description")
|
||||||
|
if not self.athena_backend.get_work_group(work_group):
|
||||||
|
return self.error("WorkGroup does not exist", 400)
|
||||||
|
self.athena_backend.create_prepared_statement(
|
||||||
|
statement_name=statement_name,
|
||||||
|
workgroup=work_group,
|
||||||
|
query_statement=query_statement,
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
|
return json.dumps(dict())
|
||||||
|
|
||||||
|
def get_prepared_statement(self) -> str:
|
||||||
|
statement_name = self._get_param("StatementName")
|
||||||
|
work_group = self._get_param("WorkGroup")
|
||||||
|
ps = self.athena_backend.get_prepared_statement(
|
||||||
|
statement_name=statement_name,
|
||||||
|
work_group=work_group,
|
||||||
|
)
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"PreparedStatement": {
|
||||||
|
"StatementName": ps.statement_name, # type: ignore[union-attr]
|
||||||
|
"QueryStatement": ps.query_statement, # type: ignore[union-attr]
|
||||||
|
"WorkGroupName": ps.workgroup, # type: ignore[union-attr]
|
||||||
|
"Description": ps.description, # type: ignore[union-attr]
|
||||||
|
# "LastModifiedTime": ps.last_modified_time, # type: ignore[union-attr]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@ -51,8 +51,11 @@ def test_create_work_group():
|
|||||||
# Then test the work group appears in the work group list
|
# Then test the work group appears in the work group list
|
||||||
response = client.list_work_groups()
|
response = client.list_work_groups()
|
||||||
|
|
||||||
response["WorkGroups"].should.have.length_of(1)
|
work_groups = list(
|
||||||
work_group = response["WorkGroups"][0]
|
filter(lambda wg: wg["Name"] != "primary", response["WorkGroups"])
|
||||||
|
)
|
||||||
|
work_groups.should.have.length_of(1)
|
||||||
|
work_group = work_groups[0]
|
||||||
work_group["Name"].should.equal("athena_workgroup")
|
work_group["Name"].should.equal("athena_workgroup")
|
||||||
work_group["Description"].should.equal("Test work group")
|
work_group["Description"].should.equal("Test work group")
|
||||||
work_group["State"].should.equal("ENABLED")
|
work_group["State"].should.equal("ENABLED")
|
||||||
@ -191,7 +194,7 @@ def test_get_named_query():
|
|||||||
database = "target_db"
|
database = "target_db"
|
||||||
query_string = "SELECT * FROM tbl1"
|
query_string = "SELECT * FROM tbl1"
|
||||||
description = "description of this query"
|
description = "description of this query"
|
||||||
# craete named query
|
# create named query
|
||||||
res_create = client.create_named_query(
|
res_create = client.create_named_query(
|
||||||
Name=query_name,
|
Name=query_name,
|
||||||
Database=database,
|
Database=database,
|
||||||
@ -375,3 +378,48 @@ def test_list_query_executions():
|
|||||||
executions = client.list_query_executions()
|
executions = client.list_query_executions()
|
||||||
executions["QueryExecutionIds"].should.have.length_of(1)
|
executions["QueryExecutionIds"].should.have.length_of(1)
|
||||||
executions["QueryExecutionIds"][0].should.equal(exec_id)
|
executions["QueryExecutionIds"][0].should.equal(exec_id)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_athena
|
||||||
|
def test_list_named_queries():
|
||||||
|
client = boto3.client("athena", region_name="us-east-1")
|
||||||
|
create_basic_workgroup(client=client, name="athena_workgroup")
|
||||||
|
query_id = client.create_named_query(
|
||||||
|
Name="query-name",
|
||||||
|
Database="target_db",
|
||||||
|
QueryString="SELECT * FROM table1",
|
||||||
|
WorkGroup="athena_workgroup",
|
||||||
|
)
|
||||||
|
list_athena_wg = client.list_named_queries(WorkGroup="athena_workgroup")
|
||||||
|
assert list_athena_wg["NamedQueryIds"][0] == query_id["NamedQueryId"]
|
||||||
|
list_primary_wg = client.list_named_queries()
|
||||||
|
assert len(list_primary_wg["NamedQueryIds"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@mock_athena
|
||||||
|
def test_create_prepared_statement():
|
||||||
|
client = boto3.client("athena", region_name="us-east-1")
|
||||||
|
create_basic_workgroup(client=client, name="athena_workgroup")
|
||||||
|
res = client.create_prepared_statement(
|
||||||
|
StatementName="test-statement",
|
||||||
|
WorkGroup="athena_workgroup",
|
||||||
|
QueryStatement="SELECT * FROM table1",
|
||||||
|
)
|
||||||
|
metadata = res["ResponseMetadata"]
|
||||||
|
assert metadata["HTTPStatusCode"] == 200
|
||||||
|
assert metadata["RetryAttempts"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@mock_athena
|
||||||
|
def test_get_prepared_statement():
|
||||||
|
client = boto3.client("athena", region_name="us-east-1")
|
||||||
|
create_basic_workgroup(client=client, name="athena_workgroup")
|
||||||
|
client.create_prepared_statement(
|
||||||
|
StatementName="stmt-name",
|
||||||
|
WorkGroup="athena_workgroup",
|
||||||
|
QueryStatement="SELECT * FROM table1",
|
||||||
|
)
|
||||||
|
resp = client.get_prepared_statement(
|
||||||
|
StatementName="stmt-name", WorkGroup="athena_workgroup"
|
||||||
|
)
|
||||||
|
assert "PreparedStatement" in resp
|
||||||
|
Loading…
Reference in New Issue
Block a user