* Implemented Athena create_named_query, get_named_query
This commit is contained in:
parent
b88f166099
commit
5880d31f7e
@ -60,6 +60,16 @@ class Execution(BaseModel):
|
||||
self.status = "QUEUED"
|
||||
|
||||
|
||||
class NamedQuery(BaseModel):
|
||||
def __init__(self, name, description, database, query_string, workgroup):
|
||||
self.id = str(uuid4())
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.database = database
|
||||
self.query_string = query_string
|
||||
self.workgroup = workgroup
|
||||
|
||||
|
||||
class AthenaBackend(BaseBackend):
|
||||
region_name = None
|
||||
|
||||
@ -68,6 +78,7 @@ class AthenaBackend(BaseBackend):
|
||||
self.region_name = region_name
|
||||
self.work_groups = {}
|
||||
self.executions = {}
|
||||
self.named_queries = {}
|
||||
|
||||
def create_work_group(self, name, configuration, description, tags):
|
||||
if name in self.work_groups:
|
||||
@ -113,6 +124,20 @@ class AthenaBackend(BaseBackend):
|
||||
execution = self.executions[exec_id]
|
||||
execution.status = "CANCELLED"
|
||||
|
||||
def create_named_query(self, name, description, database, query_string, workgroup):
|
||||
nq = NamedQuery(
|
||||
name=name,
|
||||
description=description,
|
||||
database=database,
|
||||
query_string=query_string,
|
||||
workgroup=workgroup,
|
||||
)
|
||||
self.named_queries[nq.id] = nq
|
||||
return nq.id
|
||||
|
||||
def get_named_query(self, query_id):
|
||||
return self.named_queries[query_id] if query_id in self.named_queries else None
|
||||
|
||||
|
||||
athena_backends = {}
|
||||
for region in Session().get_available_regions("athena"):
|
||||
|
@ -85,3 +85,32 @@ class AthenaResponse(BaseResponse):
|
||||
json.dumps({"__type": "InvalidRequestException", "Message": msg,}),
|
||||
dict(status=status),
|
||||
)
|
||||
|
||||
def create_named_query(self):
|
||||
name = self._get_param("Name")
|
||||
description = self._get_param("Description")
|
||||
database = self._get_param("Database")
|
||||
query_string = self._get_param("QueryString")
|
||||
workgroup = self._get_param("WorkGroup")
|
||||
if workgroup and not self.athena_backend.get_work_group(workgroup):
|
||||
return self.error("WorkGroup does not exist", 400)
|
||||
query_id = self.athena_backend.create_named_query(
|
||||
name, description, database, query_string, workgroup
|
||||
)
|
||||
return json.dumps({"NamedQueryId": query_id})
|
||||
|
||||
def get_named_query(self):
|
||||
query_id = self._get_param("NamedQueryId")
|
||||
nq = self.athena_backend.get_named_query(query_id)
|
||||
return json.dumps(
|
||||
{
|
||||
"NamedQuery": {
|
||||
"Name": nq.name,
|
||||
"Description": nq.description,
|
||||
"Database": nq.database,
|
||||
"QueryString": nq.query_string,
|
||||
"NamedQueryId": nq.id,
|
||||
"WorkGroup": nq.workgroup,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
@ -172,6 +172,44 @@ def test_stop_query_execution():
|
||||
details["Status"]["State"].should.equal("CANCELLED")
|
||||
|
||||
|
||||
@mock_athena
|
||||
def test_create_named_query():
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
|
||||
# craete named query
|
||||
res = client.create_named_query(
|
||||
Name="query-name", Database="target_db", QueryString="SELECT * FROM table1",
|
||||
)
|
||||
|
||||
assert "NamedQueryId" in res
|
||||
|
||||
|
||||
@mock_athena
|
||||
def test_get_named_query():
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
query_name = "query-name"
|
||||
database = "target_db"
|
||||
query_string = "SELECT * FROM tbl1"
|
||||
description = "description of this query"
|
||||
|
||||
# craete named query
|
||||
res_create = client.create_named_query(
|
||||
Name=query_name,
|
||||
Database=database,
|
||||
QueryString=query_string,
|
||||
Description=description,
|
||||
)
|
||||
query_id = res_create["NamedQueryId"]
|
||||
|
||||
# get named query
|
||||
res_get = client.get_named_query(NamedQueryId=query_id)["NamedQuery"]
|
||||
res_get["Name"].should.equal(query_name)
|
||||
res_get["Description"].should.equal(description)
|
||||
res_get["Database"].should.equal(database)
|
||||
res_get["QueryString"].should.equal(query_string)
|
||||
res_get["NamedQueryId"].should.equal(query_id)
|
||||
|
||||
|
||||
def create_basic_workgroup(client, name):
|
||||
client.create_work_group(
|
||||
Name=name,
|
||||
|
Loading…
Reference in New Issue
Block a user