Add Athena: get_query_results, list_query_executions (#5648)
This commit is contained in:
parent
15d3cdb794
commit
07a8d6f009
@ -40,8 +40,8 @@ athena
|
||||
- [ ] get_database
|
||||
- [X] get_named_query
|
||||
- [ ] get_prepared_statement
|
||||
- [ ] get_query_execution
|
||||
- [ ] get_query_results
|
||||
- [X] get_query_execution
|
||||
- [X] get_query_results
|
||||
- [ ] get_query_runtime_statistics
|
||||
- [ ] get_table_metadata
|
||||
- [X] get_work_group
|
||||
@ -50,7 +50,7 @@ athena
|
||||
- [ ] list_engine_versions
|
||||
- [ ] list_named_queries
|
||||
- [ ] list_prepared_statements
|
||||
- [ ] list_query_executions
|
||||
- [X] list_query_executions
|
||||
- [ ] list_table_metadata
|
||||
- [ ] list_tags_for_resource
|
||||
- [X] list_work_groups
|
||||
|
@ -90,7 +90,21 @@ class Execution(BaseModel):
|
||||
self.config = config
|
||||
self.workgroup = workgroup
|
||||
self.start_time = time.time()
|
||||
self.status = "QUEUED"
|
||||
self.status = "SUCCEEDED"
|
||||
|
||||
|
||||
class QueryResults(BaseModel):
|
||||
def __init__(self, rows: List[Dict[str, Any]], column_info: List[str]):
|
||||
self.rows = rows
|
||||
self.column_info = column_info
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"ResultSet": {
|
||||
"Rows": self.rows,
|
||||
"ResultSetMetadata": {"ColumnInfo": self.column_info},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class NamedQuery(BaseModel):
|
||||
@ -117,6 +131,7 @@ class AthenaBackend(BaseBackend):
|
||||
self.executions: Dict[str, Execution] = {}
|
||||
self.named_queries: Dict[str, NamedQuery] = {}
|
||||
self.data_catalogs: Dict[str, DataCatalog] = {}
|
||||
self.query_results: Dict[str, QueryResults] = {}
|
||||
|
||||
@staticmethod
|
||||
def default_vpc_endpoint_service(
|
||||
@ -172,9 +187,20 @@ class AthenaBackend(BaseBackend):
|
||||
self.executions[execution.id] = execution
|
||||
return execution.id
|
||||
|
||||
def get_execution(self, exec_id: str) -> Execution:
|
||||
def get_query_execution(self, exec_id: str) -> Execution:
|
||||
return self.executions[exec_id]
|
||||
|
||||
def list_query_executions(self) -> Dict[str, Execution]:
|
||||
return self.executions
|
||||
|
||||
def get_query_results(self, exec_id: str) -> QueryResults:
|
||||
results = (
|
||||
self.query_results[exec_id]
|
||||
if exec_id in self.query_results
|
||||
else QueryResults(rows=[], column_info=[])
|
||||
)
|
||||
return results
|
||||
|
||||
def stop_query_execution(self, exec_id: str) -> None:
|
||||
execution = self.executions[exec_id]
|
||||
execution.status = "CANCELLED"
|
||||
|
@ -54,7 +54,7 @@ class AthenaResponse(BaseResponse):
|
||||
|
||||
def get_query_execution(self) -> str:
|
||||
exec_id = self._get_param("QueryExecutionId")
|
||||
execution = self.athena_backend.get_execution(exec_id)
|
||||
execution = self.athena_backend.get_query_execution(exec_id)
|
||||
result = {
|
||||
"QueryExecution": {
|
||||
"QueryExecutionId": exec_id,
|
||||
@ -79,6 +79,15 @@ class AthenaResponse(BaseResponse):
|
||||
}
|
||||
return json.dumps(result)
|
||||
|
||||
def get_query_results(self) -> str:
|
||||
exec_id = self._get_param("QueryExecutionId")
|
||||
result = self.athena_backend.get_query_results(exec_id)
|
||||
return json.dumps(result.to_dict())
|
||||
|
||||
def list_query_executions(self) -> str:
|
||||
executions = self.athena_backend.list_query_executions()
|
||||
return json.dumps({"QueryExecutionIds": [i for i in executions.keys()]})
|
||||
|
||||
def stop_query_execution(self) -> str:
|
||||
exec_id = self._get_param("QueryExecutionId")
|
||||
self.athena_backend.stop_query_execution(exec_id)
|
||||
|
@ -134,7 +134,7 @@ def test_get_query_execution():
|
||||
details["StatementType"].should.equal("DDL")
|
||||
details["ResultConfiguration"]["OutputLocation"].should.equal(location)
|
||||
details["QueryExecutionContext"]["Database"].should.equal(database)
|
||||
details["Status"]["State"].should.equal("QUEUED")
|
||||
details["Status"]["State"].should.equal("SUCCEEDED")
|
||||
details["Statistics"].should.equal(
|
||||
{
|
||||
"EngineExecutionTimeInMillis": 0,
|
||||
@ -273,3 +273,31 @@ def test_create_and_get_data_catalog():
|
||||
"Parameters": {"catalog-id": "AWS Test account ID"},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@mock_athena
|
||||
def test_get_query_results():
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
|
||||
result = client.get_query_results(QueryExecutionId="test")
|
||||
|
||||
result["ResultSet"]["Rows"].should.equal([])
|
||||
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([])
|
||||
|
||||
|
||||
@mock_athena
|
||||
def test_list_query_executions():
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
|
||||
create_basic_workgroup(client=client, name="athena_workgroup")
|
||||
exec_result = client.start_query_execution(
|
||||
QueryString="query1",
|
||||
QueryExecutionContext={"Database": "string"},
|
||||
ResultConfiguration={"OutputLocation": "string"},
|
||||
WorkGroup="athena_workgroup",
|
||||
)
|
||||
exec_id = exec_result["QueryExecutionId"]
|
||||
|
||||
executions = client.list_query_executions()
|
||||
executions["QueryExecutionIds"].should.have.length_of(1)
|
||||
executions["QueryExecutionIds"][0].should.equal(exec_id)
|
||||
|
Loading…
Reference in New Issue
Block a user