From 5880d31f7e746388a64ff42b7f90077a0d666a82 Mon Sep 17 00:00:00 2001 From: ktrueda Date: Fri, 12 Jun 2020 01:27:29 +0900 Subject: [PATCH] Implemented Athena create_named_query, get_named_query (#1524) (#3065) * Implemented Athena create_named_query, get_named_query --- moto/athena/models.py | 25 +++++++++++++++++++++ moto/athena/responses.py | 29 ++++++++++++++++++++++++ tests/test_athena/test_athena.py | 38 ++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+) diff --git a/moto/athena/models.py b/moto/athena/models.py index c39c13817..24ad73ab9 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -60,6 +60,16 @@ class Execution(BaseModel): self.status = "QUEUED" +class NamedQuery(BaseModel): + def __init__(self, name, description, database, query_string, workgroup): + self.id = str(uuid4()) + self.name = name + self.description = description + self.database = database + self.query_string = query_string + self.workgroup = workgroup + + class AthenaBackend(BaseBackend): region_name = None @@ -68,6 +78,7 @@ class AthenaBackend(BaseBackend): self.region_name = region_name self.work_groups = {} self.executions = {} + self.named_queries = {} def create_work_group(self, name, configuration, description, tags): if name in self.work_groups: @@ -113,6 +124,20 @@ class AthenaBackend(BaseBackend): execution = self.executions[exec_id] execution.status = "CANCELLED" + def create_named_query(self, name, description, database, query_string, workgroup): + nq = NamedQuery( + name=name, + description=description, + database=database, + query_string=query_string, + workgroup=workgroup, + ) + self.named_queries[nq.id] = nq + return nq.id + + def get_named_query(self, query_id): + return self.named_queries[query_id] if query_id in self.named_queries else None + athena_backends = {} for region in Session().get_available_regions("athena"): diff --git a/moto/athena/responses.py b/moto/athena/responses.py index b52e0beed..b5e6d6a95 100644 --- a/moto/athena/responses.py +++ b/moto/athena/responses.py @@ -85,3 +85,32 @@ class AthenaResponse(BaseResponse): json.dumps({"__type": "InvalidRequestException", "Message": msg,}), dict(status=status), ) + + def create_named_query(self): + name = self._get_param("Name") + description = self._get_param("Description") + database = self._get_param("Database") + query_string = self._get_param("QueryString") + workgroup = self._get_param("WorkGroup") + if workgroup and not self.athena_backend.get_work_group(workgroup): + return self.error("WorkGroup does not exist", 400) + query_id = self.athena_backend.create_named_query( + name, description, database, query_string, workgroup + ) + return json.dumps({"NamedQueryId": query_id}) + + def get_named_query(self): + query_id = self._get_param("NamedQueryId") + nq = self.athena_backend.get_named_query(query_id) + return json.dumps( + { + "NamedQuery": { + "Name": nq.name, + "Description": nq.description, + "Database": nq.database, + "QueryString": nq.query_string, + "NamedQueryId": nq.id, + "WorkGroup": nq.workgroup, + } + } + ) diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index 93ca436aa..805a653e3 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -172,6 +172,44 @@ def test_stop_query_execution(): details["Status"]["State"].should.equal("CANCELLED") +@mock_athena +def test_create_named_query(): + client = boto3.client("athena", region_name="us-east-1") + + # craete named query + res = client.create_named_query( + Name="query-name", Database="target_db", QueryString="SELECT * FROM table1", + ) + + assert "NamedQueryId" in res + + +@mock_athena +def test_get_named_query(): + client = boto3.client("athena", region_name="us-east-1") + query_name = "query-name" + database = "target_db" + query_string = "SELECT * FROM tbl1" + description = "description of this query" + + # craete named query + res_create = client.create_named_query( + Name=query_name, + Database=database, + QueryString=query_string, + Description=description, + ) + query_id = res_create["NamedQueryId"] + + # get named query + res_get = client.get_named_query(NamedQueryId=query_id)["NamedQuery"] + res_get["Name"].should.equal(query_name) + res_get["Description"].should.equal(description) + res_get["Database"].should.equal(database) + res_get["QueryString"].should.equal(query_string) + res_get["NamedQueryId"].should.equal(query_id) + + def create_basic_workgroup(client, name): client.create_work_group( Name=name,