Add Athena: get_query_results, list_query_executions (#5648)

This commit is contained in:
Greg Hinch 2022-12-15 20:49:48 +00:00 committed by GitHub
parent 15d3cdb794
commit 07a8d6f009
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 7 deletions

View File

@ -40,8 +40,8 @@ athena
- [ ] get_database - [ ] get_database
- [X] get_named_query - [X] get_named_query
- [ ] get_prepared_statement - [ ] get_prepared_statement
- [ ] get_query_execution - [X] get_query_execution
- [ ] get_query_results - [X] get_query_results
- [ ] get_query_runtime_statistics - [ ] get_query_runtime_statistics
- [ ] get_table_metadata - [ ] get_table_metadata
- [X] get_work_group - [X] get_work_group
@ -50,7 +50,7 @@ athena
- [ ] list_engine_versions - [ ] list_engine_versions
- [ ] list_named_queries - [ ] list_named_queries
- [ ] list_prepared_statements - [ ] list_prepared_statements
- [ ] list_query_executions - [X] list_query_executions
- [ ] list_table_metadata - [ ] list_table_metadata
- [ ] list_tags_for_resource - [ ] list_tags_for_resource
- [X] list_work_groups - [X] list_work_groups

View File

@ -90,7 +90,21 @@ class Execution(BaseModel):
self.config = config self.config = config
self.workgroup = workgroup self.workgroup = workgroup
self.start_time = time.time() self.start_time = time.time()
self.status = "QUEUED" self.status = "SUCCEEDED"
class QueryResults(BaseModel):
def __init__(self, rows: List[Dict[str, Any]], column_info: List[str]):
self.rows = rows
self.column_info = column_info
def to_dict(self) -> Dict[str, Any]:
return {
"ResultSet": {
"Rows": self.rows,
"ResultSetMetadata": {"ColumnInfo": self.column_info},
},
}
class NamedQuery(BaseModel): class NamedQuery(BaseModel):
@ -117,6 +131,7 @@ class AthenaBackend(BaseBackend):
self.executions: Dict[str, Execution] = {} self.executions: Dict[str, Execution] = {}
self.named_queries: Dict[str, NamedQuery] = {} self.named_queries: Dict[str, NamedQuery] = {}
self.data_catalogs: Dict[str, DataCatalog] = {} self.data_catalogs: Dict[str, DataCatalog] = {}
self.query_results: Dict[str, QueryResults] = {}
@staticmethod @staticmethod
def default_vpc_endpoint_service( def default_vpc_endpoint_service(
@ -172,9 +187,20 @@ class AthenaBackend(BaseBackend):
self.executions[execution.id] = execution self.executions[execution.id] = execution
return execution.id return execution.id
def get_execution(self, exec_id: str) -> Execution: def get_query_execution(self, exec_id: str) -> Execution:
return self.executions[exec_id] return self.executions[exec_id]
def list_query_executions(self) -> Dict[str, Execution]:
return self.executions
def get_query_results(self, exec_id: str) -> QueryResults:
results = (
self.query_results[exec_id]
if exec_id in self.query_results
else QueryResults(rows=[], column_info=[])
)
return results
def stop_query_execution(self, exec_id: str) -> None: def stop_query_execution(self, exec_id: str) -> None:
execution = self.executions[exec_id] execution = self.executions[exec_id]
execution.status = "CANCELLED" execution.status = "CANCELLED"

View File

@ -54,7 +54,7 @@ class AthenaResponse(BaseResponse):
def get_query_execution(self) -> str: def get_query_execution(self) -> str:
exec_id = self._get_param("QueryExecutionId") exec_id = self._get_param("QueryExecutionId")
execution = self.athena_backend.get_execution(exec_id) execution = self.athena_backend.get_query_execution(exec_id)
result = { result = {
"QueryExecution": { "QueryExecution": {
"QueryExecutionId": exec_id, "QueryExecutionId": exec_id,
@ -79,6 +79,15 @@ class AthenaResponse(BaseResponse):
} }
return json.dumps(result) return json.dumps(result)
def get_query_results(self) -> str:
exec_id = self._get_param("QueryExecutionId")
result = self.athena_backend.get_query_results(exec_id)
return json.dumps(result.to_dict())
def list_query_executions(self) -> str:
executions = self.athena_backend.list_query_executions()
return json.dumps({"QueryExecutionIds": [i for i in executions.keys()]})
def stop_query_execution(self) -> str: def stop_query_execution(self) -> str:
exec_id = self._get_param("QueryExecutionId") exec_id = self._get_param("QueryExecutionId")
self.athena_backend.stop_query_execution(exec_id) self.athena_backend.stop_query_execution(exec_id)

View File

@ -134,7 +134,7 @@ def test_get_query_execution():
details["StatementType"].should.equal("DDL") details["StatementType"].should.equal("DDL")
details["ResultConfiguration"]["OutputLocation"].should.equal(location) details["ResultConfiguration"]["OutputLocation"].should.equal(location)
details["QueryExecutionContext"]["Database"].should.equal(database) details["QueryExecutionContext"]["Database"].should.equal(database)
details["Status"]["State"].should.equal("QUEUED") details["Status"]["State"].should.equal("SUCCEEDED")
details["Statistics"].should.equal( details["Statistics"].should.equal(
{ {
"EngineExecutionTimeInMillis": 0, "EngineExecutionTimeInMillis": 0,
@ -273,3 +273,31 @@ def test_create_and_get_data_catalog():
"Parameters": {"catalog-id": "AWS Test account ID"}, "Parameters": {"catalog-id": "AWS Test account ID"},
} }
) )
@mock_athena
def test_get_query_results():
client = boto3.client("athena", region_name="us-east-1")
result = client.get_query_results(QueryExecutionId="test")
result["ResultSet"]["Rows"].should.equal([])
result["ResultSet"]["ResultSetMetadata"]["ColumnInfo"].should.equal([])
@mock_athena
def test_list_query_executions():
client = boto3.client("athena", region_name="us-east-1")
create_basic_workgroup(client=client, name="athena_workgroup")
exec_result = client.start_query_execution(
QueryString="query1",
QueryExecutionContext={"Database": "string"},
ResultConfiguration={"OutputLocation": "string"},
WorkGroup="athena_workgroup",
)
exec_id = exec_result["QueryExecutionId"]
executions = client.list_query_executions()
executions["QueryExecutionIds"].should.have.length_of(1)
executions["QueryExecutionIds"][0].should.equal(exec_id)