Add Athena: get_query_results, list_query_executions (#5648)
This commit is contained in:
parent
15d3cdb794
commit
07a8d6f009
@ -40,8 +40,8 @@ athena
|
|||||||
- [ ] get_database
|
- [ ] get_database
|
||||||
- [X] get_named_query
|
- [X] get_named_query
|
||||||
- [ ] get_prepared_statement
|
- [ ] get_prepared_statement
|
||||||
- [ ] get_query_execution
|
- [X] get_query_execution
|
||||||
- [ ] get_query_results
|
- [X] get_query_results
|
||||||
- [ ] get_query_runtime_statistics
|
- [ ] get_query_runtime_statistics
|
||||||
- [ ] get_table_metadata
|
- [ ] get_table_metadata
|
||||||
- [X] get_work_group
|
- [X] get_work_group
|
||||||
@ -50,7 +50,7 @@ athena
|
|||||||
- [ ] list_engine_versions
|
- [ ] list_engine_versions
|
||||||
- [ ] list_named_queries
|
- [ ] list_named_queries
|
||||||
- [ ] list_prepared_statements
|
- [ ] list_prepared_statements
|
||||||
- [ ] list_query_executions
|
- [X] list_query_executions
|
||||||
- [ ] list_table_metadata
|
- [ ] list_table_metadata
|
||||||
- [ ] list_tags_for_resource
|
- [ ] list_tags_for_resource
|
||||||
- [X] list_work_groups
|
- [X] list_work_groups
|
||||||
|
@ -90,7 +90,21 @@ class Execution(BaseModel):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.workgroup = workgroup
|
self.workgroup = workgroup
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.status = "QUEUED"
|
self.status = "SUCCEEDED"
|
||||||
|
|
||||||
|
|
||||||
|
class QueryResults(BaseModel):
|
||||||
|
def __init__(self, rows: List[Dict[str, Any]], column_info: List[str]):
|
||||||
|
self.rows = rows
|
||||||
|
self.column_info = column_info
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"ResultSet": {
|
||||||
|
"Rows": self.rows,
|
||||||
|
"ResultSetMetadata": {"ColumnInfo": self.column_info},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class NamedQuery(BaseModel):
|
class NamedQuery(BaseModel):
|
||||||
@ -117,6 +131,7 @@ class AthenaBackend(BaseBackend):
|
|||||||
self.executions: Dict[str, Execution] = {}
|
self.executions: Dict[str, Execution] = {}
|
||||||
self.named_queries: Dict[str, NamedQuery] = {}
|
self.named_queries: Dict[str, NamedQuery] = {}
|
||||||
self.data_catalogs: Dict[str, DataCatalog] = {}
|
self.data_catalogs: Dict[str, DataCatalog] = {}
|
||||||
|
self.query_results: Dict[str, QueryResults] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_vpc_endpoint_service(
|
def default_vpc_endpoint_service(
|
||||||
@ -172,9 +187,20 @@ class AthenaBackend(BaseBackend):
|
|||||||
self.executions[execution.id] = execution
|
self.executions[execution.id] = execution
|
||||||
return execution.id
|
return execution.id
|
||||||
|
|
||||||
def get_execution(self, exec_id: str) -> Execution:
|
def get_query_execution(self, exec_id: str) -> Execution:
|
||||||
return self.executions[exec_id]
|
return self.executions[exec_id]
|
||||||
|
|
||||||
|
def list_query_executions(self) -> Dict[str, Execution]:
|
||||||
|
return self.executions
|
||||||
|
|
||||||
|
def get_query_results(self, exec_id: str) -> QueryResults:
|
||||||
|
results = (
|
||||||
|
self.query_results[exec_id]
|
||||||
|
if exec_id in self.query_results
|
||||||
|
else QueryResults(rows=[], column_info=[])
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
def stop_query_execution(self, exec_id: str) -> None:
|
def stop_query_execution(self, exec_id: str) -> None:
|
||||||
execution = self.executions[exec_id]
|
execution = self.executions[exec_id]
|
||||||
execution.status = "CANCELLED"
|
execution.status = "CANCELLED"
|
||||||
|
@ -54,7 +54,7 @@ class AthenaResponse(BaseResponse):
|
|||||||
|
|
||||||
def get_query_execution(self) -> str:
|
def get_query_execution(self) -> str:
|
||||||
exec_id = self._get_param("QueryExecutionId")
|
exec_id = self._get_param("QueryExecutionId")
|
||||||
execution = self.athena_backend.get_execution(exec_id)
|
execution = self.athena_backend.get_query_execution(exec_id)
|
||||||
result = {
|
result = {
|
||||||
"QueryExecution": {
|
"QueryExecution": {
|
||||||
"QueryExecutionId": exec_id,
|
"QueryExecutionId": exec_id,
|
||||||
@ -79,6 +79,15 @@ class AthenaResponse(BaseResponse):
|
|||||||
}
|
}
|
||||||
return json.dumps(result)
|
return json.dumps(result)
|
||||||
|
|
||||||
|
def get_query_results(self) -> str:
|
||||||
|
exec_id = self._get_param("QueryExecutionId")
|
||||||
|
result = self.athena_backend.get_query_results(exec_id)
|
||||||
|
return json.dumps(result.to_dict())
|
||||||
|
|
||||||
|
def list_query_executions(self) -> str:
|
||||||
|
executions = self.athena_backend.list_query_executions()
|
||||||
|
return json.dumps({"QueryExecutionIds": [i for i in executions.keys()]})
|
||||||
|
|
||||||
def stop_query_execution(self) -> str:
|
def stop_query_execution(self) -> str:
|
||||||
exec_id = self._get_param("QueryExecutionId")
|
exec_id = self._get_param("QueryExecutionId")
|
||||||
self.athena_backend.stop_query_execution(exec_id)
|
self.athena_backend.stop_query_execution(exec_id)
|
||||||
|
@ -134,7 +134,7 @@ def test_get_query_execution():
|
|||||||
details["StatementType"].should.equal("DDL")
|
details["StatementType"].should.equal("DDL")
|
||||||
details["ResultConfiguration"]["OutputLocation"].should.equal(location)
|
details["ResultConfiguration"]["OutputLocation"].should.equal(location)
|
||||||
details["QueryExecutionContext"]["Database"].should.equal(database)
|
details["QueryExecutionContext"]["Database"].should.equal(database)
|
||||||
details["Status"]["State"].should.equal("QUEUED")
|
details["Status"]["State"].should.equal("SUCCEEDED")
|
||||||
details["Statistics"].should.equal(
|
details["Statistics"].should.equal(
|
||||||
{
|
{
|
||||||
"EngineExecutionTimeInMillis": 0,
|
"EngineExecutionTimeInMillis": 0,
|
||||||
@ -273,3 +273,31 @@ def test_create_and_get_data_catalog():
|
|||||||
"Parameters": {"catalog-id": "AWS Test account ID"},
|
"Parameters": {"catalog-id": "AWS Test account ID"},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@mock_athena
|
||||||
|
def test_get_query_results():
|
||||||
|
client = boto3.client("athena", region_name="us-east-1")
|
||||||
|
|
||||||
|
result = client.get_query_results(QueryExecutionId="test")
|
||||||
|
|
||||||
|
result["ResultSet"]["Rows"].should.equal([])
|
||||||
|
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([])
|
||||||
|
|
||||||
|
|
||||||
|
@mock_athena
|
||||||
|
def test_list_query_executions():
|
||||||
|
client = boto3.client("athena", region_name="us-east-1")
|
||||||
|
|
||||||
|
create_basic_workgroup(client=client, name="athena_workgroup")
|
||||||
|
exec_result = client.start_query_execution(
|
||||||
|
QueryString="query1",
|
||||||
|
QueryExecutionContext={"Database": "string"},
|
||||||
|
ResultConfiguration={"OutputLocation": "string"},
|
||||||
|
WorkGroup="athena_workgroup",
|
||||||
|
)
|
||||||
|
exec_id = exec_result["QueryExecutionId"]
|
||||||
|
|
||||||
|
executions = client.list_query_executions()
|
||||||
|
executions["QueryExecutionIds"].should.have.length_of(1)
|
||||||
|
executions["QueryExecutionIds"][0].should.equal(exec_id)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user