Add Athena: get_query_results, list_query_executions (#5648)

This commit is contained in:
Greg Hinch 2022-12-15 20:49:48 +00:00 committed by GitHub
parent 15d3cdb794
commit 07a8d6f009
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 7 deletions

View File

@ -40,8 +40,8 @@ athena
- [ ] get_database
- [X] get_named_query
- [ ] get_prepared_statement
- [ ] get_query_execution
- [ ] get_query_results
- [X] get_query_execution
- [X] get_query_results
- [ ] get_query_runtime_statistics
- [ ] get_table_metadata
- [X] get_work_group
@ -50,7 +50,7 @@ athena
- [ ] list_engine_versions
- [ ] list_named_queries
- [ ] list_prepared_statements
- [ ] list_query_executions
- [X] list_query_executions
- [ ] list_table_metadata
- [ ] list_tags_for_resource
- [X] list_work_groups

View File

@ -90,7 +90,21 @@ class Execution(BaseModel):
self.config = config
self.workgroup = workgroup
self.start_time = time.time()
self.status = "QUEUED"
self.status = "SUCCEEDED"
class QueryResults(BaseModel):
def __init__(self, rows: List[Dict[str, Any]], column_info: List[str]):
self.rows = rows
self.column_info = column_info
def to_dict(self) -> Dict[str, Any]:
return {
"ResultSet": {
"Rows": self.rows,
"ResultSetMetadata": {"ColumnInfo": self.column_info},
},
}
class NamedQuery(BaseModel):
@ -117,6 +131,7 @@ class AthenaBackend(BaseBackend):
self.executions: Dict[str, Execution] = {}
self.named_queries: Dict[str, NamedQuery] = {}
self.data_catalogs: Dict[str, DataCatalog] = {}
self.query_results: Dict[str, QueryResults] = {}
@staticmethod
def default_vpc_endpoint_service(
@ -172,9 +187,20 @@ class AthenaBackend(BaseBackend):
self.executions[execution.id] = execution
return execution.id
def get_execution(self, exec_id: str) -> Execution:
def get_query_execution(self, exec_id: str) -> Execution:
return self.executions[exec_id]
def list_query_executions(self) -> Dict[str, Execution]:
return self.executions
def get_query_results(self, exec_id: str) -> QueryResults:
results = (
self.query_results[exec_id]
if exec_id in self.query_results
else QueryResults(rows=[], column_info=[])
)
return results
def stop_query_execution(self, exec_id: str) -> None:
execution = self.executions[exec_id]
execution.status = "CANCELLED"

View File

@ -54,7 +54,7 @@ class AthenaResponse(BaseResponse):
def get_query_execution(self) -> str:
exec_id = self._get_param("QueryExecutionId")
execution = self.athena_backend.get_execution(exec_id)
execution = self.athena_backend.get_query_execution(exec_id)
result = {
"QueryExecution": {
"QueryExecutionId": exec_id,
@ -79,6 +79,15 @@ class AthenaResponse(BaseResponse):
}
return json.dumps(result)
def get_query_results(self) -> str:
exec_id = self._get_param("QueryExecutionId")
result = self.athena_backend.get_query_results(exec_id)
return json.dumps(result.to_dict())
def list_query_executions(self) -> str:
executions = self.athena_backend.list_query_executions()
return json.dumps({"QueryExecutionIds": [i for i in executions.keys()]})
def stop_query_execution(self) -> str:
exec_id = self._get_param("QueryExecutionId")
self.athena_backend.stop_query_execution(exec_id)

View File

@ -134,7 +134,7 @@ def test_get_query_execution():
details["StatementType"].should.equal("DDL")
details["ResultConfiguration"]["OutputLocation"].should.equal(location)
details["QueryExecutionContext"]["Database"].should.equal(database)
details["Status"]["State"].should.equal("QUEUED")
details["Status"]["State"].should.equal("SUCCEEDED")
details["Statistics"].should.equal(
{
"EngineExecutionTimeInMillis": 0,
@ -273,3 +273,31 @@ def test_create_and_get_data_catalog():
"Parameters": {"catalog-id": "AWS Test account ID"},
}
)
@mock_athena
def test_get_query_results():
client = boto3.client("athena", region_name="us-east-1")
result = client.get_query_results(QueryExecutionId="test")
result["ResultSet"]["Rows"].should.equal([])
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([])
@mock_athena
def test_list_query_executions():
client = boto3.client("athena", region_name="us-east-1")
create_basic_workgroup(client=client, name="athena_workgroup")
exec_result = client.start_query_execution(
QueryString="query1",
QueryExecutionContext={"Database": "string"},
ResultConfiguration={"OutputLocation": "string"},
WorkGroup="athena_workgroup",
)
exec_id = exec_result["QueryExecutionId"]
executions = client.list_query_executions()
executions["QueryExecutionIds"].should.have.length_of(1)
executions["QueryExecutionIds"][0].should.equal(exec_id)