This commit is contained in:
Bert Blommers 2020-05-16 15:03:26 +01:00
parent dd20fec9f3
commit ffb521f86b
3 changed files with 78 additions and 68 deletions

View File

@ -50,7 +50,6 @@ class WorkGroup(TaggableResourceMixin, BaseModel):
class Execution(BaseModel): class Execution(BaseModel):
def __init__(self, query, context, config, workgroup): def __init__(self, query, context, config, workgroup):
self.id = str(uuid4()) self.id = str(uuid4())
self.query = query self.query = query
@ -97,11 +96,13 @@ class AthenaBackend(BaseBackend):
"State": wg.state, "State": wg.state,
"Configuration": wg.configuration, "Configuration": wg.configuration,
"Description": wg.description, "Description": wg.description,
"CreationTime": time.time() "CreationTime": time.time(),
} }
def start_query_execution(self, query, context, config, workgroup): def start_query_execution(self, query, context, config, workgroup):
execution = Execution(query=query, context=context, config=config, workgroup=workgroup) execution = Execution(
query=query, context=context, config=config, workgroup=workgroup
)
self.executions[execution.id] = execution self.executions[execution.id] = execution
return execution.id return execution.id

View File

@ -43,32 +43,34 @@ class AthenaResponse(BaseResponse):
workgroup = self._get_param("WorkGroup") workgroup = self._get_param("WorkGroup")
if workgroup and not self.athena_backend.get_work_group(workgroup): if workgroup and not self.athena_backend.get_work_group(workgroup):
return self.error("WorkGroup does not exist", 400) return self.error("WorkGroup does not exist", 400)
id = self.athena_backend.start_query_execution(query=query, context=context, config=config, workgroup=workgroup) id = self.athena_backend.start_query_execution(
query=query, context=context, config=config, workgroup=workgroup
)
return json.dumps({"QueryExecutionId": id}) return json.dumps({"QueryExecutionId": id})
def get_query_execution(self): def get_query_execution(self):
exec_id = self._get_param("QueryExecutionId") exec_id = self._get_param("QueryExecutionId")
execution = self.athena_backend.get_execution(exec_id) execution = self.athena_backend.get_execution(exec_id)
result = { result = {
'QueryExecution': { "QueryExecution": {
'QueryExecutionId': exec_id, "QueryExecutionId": exec_id,
'Query': execution.query, "Query": execution.query,
'StatementType': 'DDL', "StatementType": "DDL",
'ResultConfiguration': execution.config, "ResultConfiguration": execution.config,
'QueryExecutionContext': execution.context, "QueryExecutionContext": execution.context,
'Status': { "Status": {
'State': execution.status, "State": execution.status,
'SubmissionDateTime': execution.start_time "SubmissionDateTime": execution.start_time,
}, },
'Statistics': { "Statistics": {
'EngineExecutionTimeInMillis': 0, "EngineExecutionTimeInMillis": 0,
'DataScannedInBytes': 0, "DataScannedInBytes": 0,
'TotalExecutionTimeInMillis': 0, "TotalExecutionTimeInMillis": 0,
'QueryQueueTimeInMillis': 0, "QueryQueueTimeInMillis": 0,
'QueryPlanningTimeInMillis': 0, "QueryPlanningTimeInMillis": 0,
'ServiceProcessingTimeInMillis': 0 "ServiceProcessingTimeInMillis": 0,
}, },
'WorkGroup': execution.workgroup "WorkGroup": execution.workgroup,
} }
} }
return json.dumps(result) return json.dumps(result)
@ -80,11 +82,6 @@ class AthenaResponse(BaseResponse):
def error(self, msg, status): def error(self, msg, status):
return ( return (
json.dumps( json.dumps({"__type": "InvalidRequestException", "Message": msg,}),
{
"__type": "InvalidRequestException",
"Message": msg,
}
),
dict(status=status), dict(status=status),
) )

View File

@ -64,18 +64,18 @@ def test_create_and_get_workgroup():
create_basic_workgroup(client=client, name="athena_workgroup") create_basic_workgroup(client=client, name="athena_workgroup")
work_group = client.get_work_group(WorkGroup='athena_workgroup')['WorkGroup'] work_group = client.get_work_group(WorkGroup="athena_workgroup")["WorkGroup"]
del work_group["CreationTime"] # Were not testing creationtime atm del work_group["CreationTime"] # Were not testing creationtime atm
work_group.should.equal({ work_group.should.equal(
'Name': 'athena_workgroup', {
'State': 'ENABLED', "Name": "athena_workgroup",
'Configuration': { "State": "ENABLED",
'ResultConfiguration': { "Configuration": {
'OutputLocation': 's3://bucket-name/prefix/' "ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/"}
} },
}, "Description": "Test work group",
'Description': 'Test work group' }
}) )
@mock_athena @mock_athena
@ -83,16 +83,20 @@ def test_start_query_execution():
client = boto3.client("athena", region_name="us-east-1") client = boto3.client("athena", region_name="us-east-1")
create_basic_workgroup(client=client, name="athena_workgroup") create_basic_workgroup(client=client, name="athena_workgroup")
response = client.start_query_execution(QueryString='query1', response = client.start_query_execution(
QueryExecutionContext={'Database': 'string'}, QueryString="query1",
ResultConfiguration={'OutputLocation': 'string'}, QueryExecutionContext={"Database": "string"},
WorkGroup='athena_workgroup') ResultConfiguration={"OutputLocation": "string"},
assert 'QueryExecutionId' in response WorkGroup="athena_workgroup",
)
assert "QueryExecutionId" in response
sec_response = client.start_query_execution(QueryString='query2', sec_response = client.start_query_execution(
QueryExecutionContext={'Database': 'string'}, QueryString="query2",
ResultConfiguration={'OutputLocation': 'string'}) QueryExecutionContext={"Database": "string"},
assert 'QueryExecutionId' in sec_response ResultConfiguration={"OutputLocation": "string"},
)
assert "QueryExecutionId" in sec_response
response["QueryExecutionId"].shouldnt.equal(sec_response["QueryExecutionId"]) response["QueryExecutionId"].shouldnt.equal(sec_response["QueryExecutionId"])
@ -101,10 +105,12 @@ def test_start_query_validate_workgroup():
client = boto3.client("athena", region_name="us-east-1") client = boto3.client("athena", region_name="us-east-1")
with assert_raises(ClientError) as err: with assert_raises(ClientError) as err:
client.start_query_execution(QueryString='query1', client.start_query_execution(
QueryExecutionContext={'Database': 'string'}, QueryString="query1",
ResultConfiguration={'OutputLocation': 'string'}, QueryExecutionContext={"Database": "string"},
WorkGroup='unknown_workgroup') ResultConfiguration={"OutputLocation": "string"},
WorkGroup="unknown_workgroup",
)
err.exception.response["Error"]["Code"].should.equal("InvalidRequestException") err.exception.response["Error"]["Code"].should.equal("InvalidRequestException")
err.exception.response["Error"]["Message"].should.equal("WorkGroup does not exist") err.exception.response["Error"]["Message"].should.equal("WorkGroup does not exist")
@ -117,9 +123,11 @@ def test_get_query_execution():
location = "s3://bucket-name/prefix/" location = "s3://bucket-name/prefix/"
database = "database" database = "database"
# Start Query # Start Query
exex_id = client.start_query_execution(QueryString=query, exex_id = client.start_query_execution(
QueryExecutionContext={'Database': database}, QueryString=query,
ResultConfiguration={'OutputLocation': location})["QueryExecutionId"] QueryExecutionContext={"Database": database},
ResultConfiguration={"OutputLocation": location},
)["QueryExecutionId"]
# #
details = client.get_query_execution(QueryExecutionId=exex_id)["QueryExecution"] details = client.get_query_execution(QueryExecutionId=exex_id)["QueryExecution"]
# #
@ -129,12 +137,16 @@ def test_get_query_execution():
details["ResultConfiguration"]["OutputLocation"].should.equal(location) details["ResultConfiguration"]["OutputLocation"].should.equal(location)
details["QueryExecutionContext"]["Database"].should.equal(database) details["QueryExecutionContext"]["Database"].should.equal(database)
details["Status"]["State"].should.equal("QUEUED") details["Status"]["State"].should.equal("QUEUED")
details["Statistics"].should.equal({'EngineExecutionTimeInMillis': 0, details["Statistics"].should.equal(
'DataScannedInBytes': 0, {
'TotalExecutionTimeInMillis': 0, "EngineExecutionTimeInMillis": 0,
'QueryQueueTimeInMillis': 0, "DataScannedInBytes": 0,
'QueryPlanningTimeInMillis': 0, "TotalExecutionTimeInMillis": 0,
'ServiceProcessingTimeInMillis': 0}) "QueryQueueTimeInMillis": 0,
"QueryPlanningTimeInMillis": 0,
"ServiceProcessingTimeInMillis": 0,
}
)
assert "WorkGroup" not in details assert "WorkGroup" not in details
@ -146,9 +158,11 @@ def test_stop_query_execution():
location = "s3://bucket-name/prefix/" location = "s3://bucket-name/prefix/"
database = "database" database = "database"
# Start Query # Start Query
exex_id = client.start_query_execution(QueryString=query, exex_id = client.start_query_execution(
QueryExecutionContext={'Database': database}, QueryString=query,
ResultConfiguration={'OutputLocation': location})["QueryExecutionId"] QueryExecutionContext={"Database": database},
ResultConfiguration={"OutputLocation": location},
)["QueryExecutionId"]
# Stop Query # Stop Query
client.stop_query_execution(QueryExecutionId=exex_id) client.stop_query_execution(QueryExecutionId=exex_id)
# Verify status # Verify status
@ -163,8 +177,6 @@ def create_basic_workgroup(client, name):
Name=name, Name=name,
Description="Test work group", Description="Test work group",
Configuration={ Configuration={
"ResultConfiguration": { "ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/",}
"OutputLocation": "s3://bucket-name/prefix/", },
}
}
) )