From dd20fec9f35ed508ca6f4a2ecacc6c0c95acbedf Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sat, 16 May 2020 15:00:06 +0100 Subject: [PATCH 1/2] Athena - Start/stop executions --- IMPLEMENTATION_COVERAGE.md | 38 +++++----- moto/athena/models.py | 42 ++++++++++- moto/athena/responses.py | 67 +++++++++++++++--- tests/test_athena/test_athena.py | 115 ++++++++++++++++++++++++++++++- 4 files changed, 229 insertions(+), 33 deletions(-) diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index f56385b25..1555da1c8 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -641,7 +641,7 @@ ## athena
-10% implemented +26% implemented - [ ] batch_get_named_query - [ ] batch_get_query_execution @@ -652,13 +652,13 @@ - [ ] get_named_query - [ ] get_query_execution - [ ] get_query_results -- [ ] get_work_group +- [X] get_work_group - [ ] list_named_queries - [ ] list_query_executions - [ ] list_tags_for_resource - [X] list_work_groups -- [ ] start_query_execution -- [ ] stop_query_execution +- [X] start_query_execution +- [X] stop_query_execution - [ ] tag_resource - [ ] untag_resource - [ ] update_work_group @@ -5287,26 +5287,26 @@ ## managedblockchain
-16% implemented +77% implemented -- [ ] create_member +- [X] create_member - [X] create_network - [ ] create_node -- [ ] create_proposal -- [ ] delete_member +- [X] create_proposal +- [X] delete_member - [ ] delete_node -- [ ] get_member +- [X] get_member - [X] get_network - [ ] get_node -- [ ] get_proposal -- [ ] list_invitations -- [ ] list_members +- [X] get_proposal +- [X] list_invitations +- [X] list_members - [X] list_networks - [ ] list_nodes -- [ ] list_proposal_votes -- [ ] list_proposals -- [ ] reject_invitation -- [ ] vote_on_proposal +- [X] list_proposal_votes +- [X] list_proposals +- [X] reject_invitation +- [X] vote_on_proposal
## marketplace-catalog @@ -7392,7 +7392,7 @@ ## ses
-18% implemented +21% implemented - [ ] clone_receipt_rule_set - [X] create_configuration_set @@ -7427,14 +7427,14 @@ - [ ] get_identity_verification_attributes - [X] get_send_quota - [X] get_send_statistics -- [ ] get_template +- [X] get_template - [ ] list_configuration_sets - [ ] list_custom_verification_email_templates - [X] list_identities - [ ] list_identity_policies - [ ] list_receipt_filters - [ ] list_receipt_rule_sets -- [ ] list_templates +- [X] list_templates - [X] list_verified_email_addresses - [ ] put_configuration_set_delivery_options - [ ] put_identity_policy diff --git a/moto/athena/models.py b/moto/athena/models.py index 6aeca0ffa..20d180d74 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -2,10 +2,9 @@ from __future__ import unicode_literals import time from boto3 import Session +from moto.core import BaseBackend, BaseModel, ACCOUNT_ID -from moto.core import BaseBackend, BaseModel - -from moto.core import ACCOUNT_ID +from uuid import uuid4 class TaggableResourceMixin(object): @@ -50,6 +49,18 @@ class WorkGroup(TaggableResourceMixin, BaseModel): self.configuration = configuration +class Execution(BaseModel): + + def __init__(self, query, context, config, workgroup): + self.id = str(uuid4()) + self.query = query + self.context = context + self.config = config + self.workgroup = workgroup + self.start_time = time.time() + self.status = "QUEUED" + + class AthenaBackend(BaseBackend): region_name = None @@ -57,6 +68,7 @@ class AthenaBackend(BaseBackend): if region_name is not None: self.region_name = region_name self.work_groups = {} + self.executions = {} def create_work_group(self, name, configuration, description, tags): if name in self.work_groups: @@ -76,6 +88,30 @@ class AthenaBackend(BaseBackend): for wg in self.work_groups.values() ] + def get_work_group(self, name): + if name not in self.work_groups: + return None + wg = self.work_groups[name] + return { + "Name": wg.name, + "State": wg.state, + "Configuration": wg.configuration, + "Description": wg.description, + "CreationTime": time.time() + } + + def start_query_execution(self, query, context, config, workgroup): + execution = Execution(query=query, context=context, config=config, workgroup=workgroup) + self.executions[execution.id] = execution + return execution.id + + def get_execution(self, exec_id): + return self.executions[exec_id] + + def stop_query_execution(self, exec_id): + execution = self.executions[exec_id] + execution.status = "CANCELLED" + athena_backends = {} for region in Session().get_available_regions("athena"): diff --git a/moto/athena/responses.py b/moto/athena/responses.py index 80cac5d62..c572cea0b 100644 --- a/moto/athena/responses.py +++ b/moto/athena/responses.py @@ -18,15 +18,7 @@ class AthenaResponse(BaseResponse): name, configuration, description, tags ) if not work_group: - return ( - json.dumps( - { - "__type": "InvalidRequestException", - "Message": "WorkGroup already exists", - } - ), - dict(status=400), - ) + return self.error("WorkGroup already exists", 400) return json.dumps( { "CreateWorkGroupResponse": { @@ -39,3 +31,60 @@ class AthenaResponse(BaseResponse): def list_work_groups(self): return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()}) + + def get_work_group(self): + name = self._get_param("WorkGroup") + return json.dumps({"WorkGroup": self.athena_backend.get_work_group(name)}) + + def start_query_execution(self): + query = self._get_param("QueryString") + context = self._get_param("QueryExecutionContext") + config = self._get_param("ResultConfiguration") + workgroup = self._get_param("WorkGroup") + if workgroup and not self.athena_backend.get_work_group(workgroup): + return self.error("WorkGroup does not exist", 400) + id = self.athena_backend.start_query_execution(query=query, context=context, config=config, workgroup=workgroup) + return json.dumps({"QueryExecutionId": id}) + + def get_query_execution(self): + exec_id = self._get_param("QueryExecutionId") + execution = self.athena_backend.get_execution(exec_id) + result = { + 'QueryExecution': { + 'QueryExecutionId': exec_id, + 'Query': execution.query, + 'StatementType': 'DDL', + 'ResultConfiguration': execution.config, + 'QueryExecutionContext': execution.context, + 'Status': { + 'State': execution.status, + 'SubmissionDateTime': execution.start_time + }, + 'Statistics': { + 'EngineExecutionTimeInMillis': 0, + 'DataScannedInBytes': 0, + 'TotalExecutionTimeInMillis': 0, + 'QueryQueueTimeInMillis': 0, + 'QueryPlanningTimeInMillis': 0, + 'ServiceProcessingTimeInMillis': 0 + }, + 'WorkGroup': execution.workgroup + } + } + return json.dumps(result) + + def stop_query_execution(self): + exec_id = self._get_param("QueryExecutionId") + self.athena_backend.stop_query_execution(exec_id) + return json.dumps({}) + + def error(self, msg, status): + return ( + json.dumps( + { + "__type": "InvalidRequestException", + "Message": msg, + } + ), + dict(status=status), + ) diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index d36653910..597361b1d 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -1,8 +1,7 @@ from __future__ import unicode_literals -import datetime - from botocore.exceptions import ClientError +from nose.tools import assert_raises import boto3 import sure # noqa @@ -57,3 +56,115 @@ def test_create_work_group(): work_group["Name"].should.equal("athena_workgroup") work_group["Description"].should.equal("Test work group") work_group["State"].should.equal("ENABLED") + + +@mock_athena +def test_create_and_get_workgroup(): + client = boto3.client("athena", region_name="us-east-1") + + create_basic_workgroup(client=client, name="athena_workgroup") + + work_group = client.get_work_group(WorkGroup='athena_workgroup')['WorkGroup'] + del work_group["CreationTime"] # Were not testing creationtime atm + work_group.should.equal({ + 'Name': 'athena_workgroup', + 'State': 'ENABLED', + 'Configuration': { + 'ResultConfiguration': { + 'OutputLocation': 's3://bucket-name/prefix/' + } + }, + 'Description': 'Test work group' + }) + + +@mock_athena +def test_start_query_execution(): + client = boto3.client("athena", region_name="us-east-1") + + create_basic_workgroup(client=client, name="athena_workgroup") + response = client.start_query_execution(QueryString='query1', + QueryExecutionContext={'Database': 'string'}, + ResultConfiguration={'OutputLocation': 'string'}, + WorkGroup='athena_workgroup') + assert 'QueryExecutionId' in response + + sec_response = client.start_query_execution(QueryString='query2', + QueryExecutionContext={'Database': 'string'}, + ResultConfiguration={'OutputLocation': 'string'}) + assert 'QueryExecutionId' in sec_response + response["QueryExecutionId"].shouldnt.equal(sec_response["QueryExecutionId"]) + + +@mock_athena +def test_start_query_validate_workgroup(): + client = boto3.client("athena", region_name="us-east-1") + + with assert_raises(ClientError) as err: + client.start_query_execution(QueryString='query1', + QueryExecutionContext={'Database': 'string'}, + ResultConfiguration={'OutputLocation': 'string'}, + WorkGroup='unknown_workgroup') + err.exception.response["Error"]["Code"].should.equal("InvalidRequestException") + err.exception.response["Error"]["Message"].should.equal("WorkGroup does not exist") + + +@mock_athena +def test_get_query_execution(): + client = boto3.client("athena", region_name="us-east-1") + + query = "SELECT stuff" + location = "s3://bucket-name/prefix/" + database = "database" + # Start Query + exex_id = client.start_query_execution(QueryString=query, + QueryExecutionContext={'Database': database}, + ResultConfiguration={'OutputLocation': location})["QueryExecutionId"] + # + details = client.get_query_execution(QueryExecutionId=exex_id)["QueryExecution"] + # + details["QueryExecutionId"].should.equal(exex_id) + details["Query"].should.equal(query) + details["StatementType"].should.equal("DDL") + details["ResultConfiguration"]["OutputLocation"].should.equal(location) + details["QueryExecutionContext"]["Database"].should.equal(database) + details["Status"]["State"].should.equal("QUEUED") + details["Statistics"].should.equal({'EngineExecutionTimeInMillis': 0, + 'DataScannedInBytes': 0, + 'TotalExecutionTimeInMillis': 0, + 'QueryQueueTimeInMillis': 0, + 'QueryPlanningTimeInMillis': 0, + 'ServiceProcessingTimeInMillis': 0}) + assert "WorkGroup" not in details + + +@mock_athena +def test_stop_query_execution(): + client = boto3.client("athena", region_name="us-east-1") + + query = "SELECT stuff" + location = "s3://bucket-name/prefix/" + database = "database" + # Start Query + exex_id = client.start_query_execution(QueryString=query, + QueryExecutionContext={'Database': database}, + ResultConfiguration={'OutputLocation': location})["QueryExecutionId"] + # Stop Query + client.stop_query_execution(QueryExecutionId=exex_id) + # Verify status + details = client.get_query_execution(QueryExecutionId=exex_id)["QueryExecution"] + # + details["QueryExecutionId"].should.equal(exex_id) + details["Status"]["State"].should.equal("CANCELLED") + + +def create_basic_workgroup(client, name): + client.create_work_group( + Name=name, + Description="Test work group", + Configuration={ + "ResultConfiguration": { + "OutputLocation": "s3://bucket-name/prefix/", + } + } + ) From ffb521f86b2dc793e0c4a5bc953e1ae7aadb5195 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sat, 16 May 2020 15:03:26 +0100 Subject: [PATCH 2/2] Linting --- moto/athena/models.py | 7 ++- moto/athena/responses.py | 45 +++++++-------- tests/test_athena/test_athena.py | 94 ++++++++++++++++++-------------- 3 files changed, 78 insertions(+), 68 deletions(-) diff --git a/moto/athena/models.py b/moto/athena/models.py index 20d180d74..c39c13817 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -50,7 +50,6 @@ class WorkGroup(TaggableResourceMixin, BaseModel): class Execution(BaseModel): - def __init__(self, query, context, config, workgroup): self.id = str(uuid4()) self.query = query @@ -97,11 +96,13 @@ class AthenaBackend(BaseBackend): "State": wg.state, "Configuration": wg.configuration, "Description": wg.description, - "CreationTime": time.time() + "CreationTime": time.time(), } def start_query_execution(self, query, context, config, workgroup): - execution = Execution(query=query, context=context, config=config, workgroup=workgroup) + execution = Execution( + query=query, context=context, config=config, workgroup=workgroup + ) self.executions[execution.id] = execution return execution.id diff --git a/moto/athena/responses.py b/moto/athena/responses.py index c572cea0b..b52e0beed 100644 --- a/moto/athena/responses.py +++ b/moto/athena/responses.py @@ -43,32 +43,34 @@ class AthenaResponse(BaseResponse): workgroup = self._get_param("WorkGroup") if workgroup and not self.athena_backend.get_work_group(workgroup): return self.error("WorkGroup does not exist", 400) - id = self.athena_backend.start_query_execution(query=query, context=context, config=config, workgroup=workgroup) + id = self.athena_backend.start_query_execution( + query=query, context=context, config=config, workgroup=workgroup + ) return json.dumps({"QueryExecutionId": id}) def get_query_execution(self): exec_id = self._get_param("QueryExecutionId") execution = self.athena_backend.get_execution(exec_id) result = { - 'QueryExecution': { - 'QueryExecutionId': exec_id, - 'Query': execution.query, - 'StatementType': 'DDL', - 'ResultConfiguration': execution.config, - 'QueryExecutionContext': execution.context, - 'Status': { - 'State': execution.status, - 'SubmissionDateTime': execution.start_time + "QueryExecution": { + "QueryExecutionId": exec_id, + "Query": execution.query, + "StatementType": "DDL", + "ResultConfiguration": execution.config, + "QueryExecutionContext": execution.context, + "Status": { + "State": execution.status, + "SubmissionDateTime": execution.start_time, }, - 'Statistics': { - 'EngineExecutionTimeInMillis': 0, - 'DataScannedInBytes': 0, - 'TotalExecutionTimeInMillis': 0, - 'QueryQueueTimeInMillis': 0, - 'QueryPlanningTimeInMillis': 0, - 'ServiceProcessingTimeInMillis': 0 + "Statistics": { + "EngineExecutionTimeInMillis": 0, + "DataScannedInBytes": 0, + "TotalExecutionTimeInMillis": 0, + "QueryQueueTimeInMillis": 0, + "QueryPlanningTimeInMillis": 0, + "ServiceProcessingTimeInMillis": 0, }, - 'WorkGroup': execution.workgroup + "WorkGroup": execution.workgroup, } } return json.dumps(result) @@ -80,11 +82,6 @@ class AthenaResponse(BaseResponse): def error(self, msg, status): return ( - json.dumps( - { - "__type": "InvalidRequestException", - "Message": msg, - } - ), + json.dumps({"__type": "InvalidRequestException", "Message": msg,}), dict(status=status), ) diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index 597361b1d..93ca436aa 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -64,18 +64,18 @@ def test_create_and_get_workgroup(): create_basic_workgroup(client=client, name="athena_workgroup") - work_group = client.get_work_group(WorkGroup='athena_workgroup')['WorkGroup'] - del work_group["CreationTime"] # Were not testing creationtime atm - work_group.should.equal({ - 'Name': 'athena_workgroup', - 'State': 'ENABLED', - 'Configuration': { - 'ResultConfiguration': { - 'OutputLocation': 's3://bucket-name/prefix/' - } - }, - 'Description': 'Test work group' - }) + work_group = client.get_work_group(WorkGroup="athena_workgroup")["WorkGroup"] + del work_group["CreationTime"] # Were not testing creationtime atm + work_group.should.equal( + { + "Name": "athena_workgroup", + "State": "ENABLED", + "Configuration": { + "ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/"} + }, + "Description": "Test work group", + } + ) @mock_athena @@ -83,16 +83,20 @@ def test_start_query_execution(): client = boto3.client("athena", region_name="us-east-1") create_basic_workgroup(client=client, name="athena_workgroup") - response = client.start_query_execution(QueryString='query1', - QueryExecutionContext={'Database': 'string'}, - ResultConfiguration={'OutputLocation': 'string'}, - WorkGroup='athena_workgroup') - assert 'QueryExecutionId' in response + response = client.start_query_execution( + QueryString="query1", + QueryExecutionContext={"Database": "string"}, + ResultConfiguration={"OutputLocation": "string"}, + WorkGroup="athena_workgroup", + ) + assert "QueryExecutionId" in response - sec_response = client.start_query_execution(QueryString='query2', - QueryExecutionContext={'Database': 'string'}, - ResultConfiguration={'OutputLocation': 'string'}) - assert 'QueryExecutionId' in sec_response + sec_response = client.start_query_execution( + QueryString="query2", + QueryExecutionContext={"Database": "string"}, + ResultConfiguration={"OutputLocation": "string"}, + ) + assert "QueryExecutionId" in sec_response response["QueryExecutionId"].shouldnt.equal(sec_response["QueryExecutionId"]) @@ -101,10 +105,12 @@ def test_start_query_validate_workgroup(): client = boto3.client("athena", region_name="us-east-1") with assert_raises(ClientError) as err: - client.start_query_execution(QueryString='query1', - QueryExecutionContext={'Database': 'string'}, - ResultConfiguration={'OutputLocation': 'string'}, - WorkGroup='unknown_workgroup') + client.start_query_execution( + QueryString="query1", + QueryExecutionContext={"Database": "string"}, + ResultConfiguration={"OutputLocation": "string"}, + WorkGroup="unknown_workgroup", + ) err.exception.response["Error"]["Code"].should.equal("InvalidRequestException") err.exception.response["Error"]["Message"].should.equal("WorkGroup does not exist") @@ -117,9 +123,11 @@ def test_get_query_execution(): location = "s3://bucket-name/prefix/" database = "database" # Start Query - exex_id = client.start_query_execution(QueryString=query, - QueryExecutionContext={'Database': database}, - ResultConfiguration={'OutputLocation': location})["QueryExecutionId"] + exex_id = client.start_query_execution( + QueryString=query, + QueryExecutionContext={"Database": database}, + ResultConfiguration={"OutputLocation": location}, + )["QueryExecutionId"] # details = client.get_query_execution(QueryExecutionId=exex_id)["QueryExecution"] # @@ -129,12 +137,16 @@ def test_get_query_execution(): details["ResultConfiguration"]["OutputLocation"].should.equal(location) details["QueryExecutionContext"]["Database"].should.equal(database) details["Status"]["State"].should.equal("QUEUED") - details["Statistics"].should.equal({'EngineExecutionTimeInMillis': 0, - 'DataScannedInBytes': 0, - 'TotalExecutionTimeInMillis': 0, - 'QueryQueueTimeInMillis': 0, - 'QueryPlanningTimeInMillis': 0, - 'ServiceProcessingTimeInMillis': 0}) + details["Statistics"].should.equal( + { + "EngineExecutionTimeInMillis": 0, + "DataScannedInBytes": 0, + "TotalExecutionTimeInMillis": 0, + "QueryQueueTimeInMillis": 0, + "QueryPlanningTimeInMillis": 0, + "ServiceProcessingTimeInMillis": 0, + } + ) assert "WorkGroup" not in details @@ -146,9 +158,11 @@ def test_stop_query_execution(): location = "s3://bucket-name/prefix/" database = "database" # Start Query - exex_id = client.start_query_execution(QueryString=query, - QueryExecutionContext={'Database': database}, - ResultConfiguration={'OutputLocation': location})["QueryExecutionId"] + exex_id = client.start_query_execution( + QueryString=query, + QueryExecutionContext={"Database": database}, + ResultConfiguration={"OutputLocation": location}, + )["QueryExecutionId"] # Stop Query client.stop_query_execution(QueryExecutionId=exex_id) # Verify status @@ -163,8 +177,6 @@ def create_basic_workgroup(client, name): Name=name, Description="Test work group", Configuration={ - "ResultConfiguration": { - "OutputLocation": "s3://bucket-name/prefix/", - } - } + "ResultConfiguration": {"OutputLocation": "s3://bucket-name/prefix/",} + }, )