Implemented Athena create_named_query, get_named_query (#1524) (#3065)

* Implemented Athena create_named_query, get_named_query
This commit is contained in:
ktrueda 2020-06-12 01:27:29 +09:00 committed by GitHub
parent b88f166099
commit 5880d31f7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 0 deletions

View File

@ -60,6 +60,16 @@ class Execution(BaseModel):
self.status = "QUEUED" self.status = "QUEUED"
class NamedQuery(BaseModel):
def __init__(self, name, description, database, query_string, workgroup):
self.id = str(uuid4())
self.name = name
self.description = description
self.database = database
self.query_string = query_string
self.workgroup = workgroup
class AthenaBackend(BaseBackend): class AthenaBackend(BaseBackend):
region_name = None region_name = None
@ -68,6 +78,7 @@ class AthenaBackend(BaseBackend):
self.region_name = region_name self.region_name = region_name
self.work_groups = {} self.work_groups = {}
self.executions = {} self.executions = {}
self.named_queries = {}
def create_work_group(self, name, configuration, description, tags): def create_work_group(self, name, configuration, description, tags):
if name in self.work_groups: if name in self.work_groups:
@ -113,6 +124,20 @@ class AthenaBackend(BaseBackend):
execution = self.executions[exec_id] execution = self.executions[exec_id]
execution.status = "CANCELLED" execution.status = "CANCELLED"
def create_named_query(self, name, description, database, query_string, workgroup):
nq = NamedQuery(
name=name,
description=description,
database=database,
query_string=query_string,
workgroup=workgroup,
)
self.named_queries[nq.id] = nq
return nq.id
def get_named_query(self, query_id):
return self.named_queries[query_id] if query_id in self.named_queries else None
athena_backends = {} athena_backends = {}
for region in Session().get_available_regions("athena"): for region in Session().get_available_regions("athena"):

View File

@ -85,3 +85,32 @@ class AthenaResponse(BaseResponse):
json.dumps({"__type": "InvalidRequestException", "Message": msg,}), json.dumps({"__type": "InvalidRequestException", "Message": msg,}),
dict(status=status), dict(status=status),
) )
def create_named_query(self):
name = self._get_param("Name")
description = self._get_param("Description")
database = self._get_param("Database")
query_string = self._get_param("QueryString")
workgroup = self._get_param("WorkGroup")
if workgroup and not self.athena_backend.get_work_group(workgroup):
return self.error("WorkGroup does not exist", 400)
query_id = self.athena_backend.create_named_query(
name, description, database, query_string, workgroup
)
return json.dumps({"NamedQueryId": query_id})
def get_named_query(self):
query_id = self._get_param("NamedQueryId")
nq = self.athena_backend.get_named_query(query_id)
return json.dumps(
{
"NamedQuery": {
"Name": nq.name,
"Description": nq.description,
"Database": nq.database,
"QueryString": nq.query_string,
"NamedQueryId": nq.id,
"WorkGroup": nq.workgroup,
}
}
)

View File

@ -172,6 +172,44 @@ def test_stop_query_execution():
details["Status"]["State"].should.equal("CANCELLED") details["Status"]["State"].should.equal("CANCELLED")
@mock_athena
def test_create_named_query():
client = boto3.client("athena", region_name="us-east-1")
# craete named query
res = client.create_named_query(
Name="query-name", Database="target_db", QueryString="SELECT * FROM table1",
)
assert "NamedQueryId" in res
@mock_athena
def test_get_named_query():
client = boto3.client("athena", region_name="us-east-1")
query_name = "query-name"
database = "target_db"
query_string = "SELECT * FROM tbl1"
description = "description of this query"
# craete named query
res_create = client.create_named_query(
Name=query_name,
Database=database,
QueryString=query_string,
Description=description,
)
query_id = res_create["NamedQueryId"]
# get named query
res_get = client.get_named_query(NamedQueryId=query_id)["NamedQuery"]
res_get["Name"].should.equal(query_name)
res_get["Description"].should.equal(description)
res_get["Database"].should.equal(database)
res_get["QueryString"].should.equal(query_string)
res_get["NamedQueryId"].should.equal(query_id)
def create_basic_workgroup(client, name): def create_basic_workgroup(client, name):
client.create_work_group( client.create_work_group(
Name=name, Name=name,