Athena: get_query_execution() should return exact OutputLocation file (#7422)
This commit is contained in:
parent
131b78c82c
commit
1c11fbc0a2
@ -86,7 +86,9 @@ class DataCatalog(TaggableResourceMixin, BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Execution(BaseModel):
|
class Execution(BaseModel):
|
||||||
def __init__(self, query: str, context: str, config: str, workgroup: WorkGroup):
|
def __init__(
|
||||||
|
self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup
|
||||||
|
):
|
||||||
self.id = str(mock_random.uuid4())
|
self.id = str(mock_random.uuid4())
|
||||||
self.query = query
|
self.query = query
|
||||||
self.context = context
|
self.context = context
|
||||||
@ -95,6 +97,11 @@ class Execution(BaseModel):
|
|||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.status = "SUCCEEDED"
|
self.status = "SUCCEEDED"
|
||||||
|
|
||||||
|
if "OutputLocation" in self.config:
|
||||||
|
if not self.config["OutputLocation"].endswith("/"):
|
||||||
|
self.config["OutputLocation"] += "/"
|
||||||
|
self.config["OutputLocation"] += f"{self.id}.csv"
|
||||||
|
|
||||||
|
|
||||||
class QueryResults(BaseModel):
|
class QueryResults(BaseModel):
|
||||||
def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]):
|
def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]):
|
||||||
@ -213,7 +220,7 @@ class AthenaBackend(BaseBackend):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def start_query_execution(
|
def start_query_execution(
|
||||||
self, query: str, context: str, config: str, workgroup: WorkGroup
|
self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup
|
||||||
) -> str:
|
) -> str:
|
||||||
execution = Execution(
|
execution = Execution(
|
||||||
query=query, context=context, config=config, workgroup=workgroup
|
query=query, context=context, config=config, workgroup=workgroup
|
||||||
|
@ -123,11 +123,13 @@ def test_start_query_validate_workgroup():
|
|||||||
|
|
||||||
|
|
||||||
@mock_aws
|
@mock_aws
|
||||||
def test_get_query_execution():
|
@pytest.mark.parametrize(
|
||||||
|
"location", ["s3://bucket-name/prefix/", "s3://bucket-name/prefix_wo_slash"]
|
||||||
|
)
|
||||||
|
def test_get_query_execution(location):
|
||||||
client = boto3.client("athena", region_name="us-east-1")
|
client = boto3.client("athena", region_name="us-east-1")
|
||||||
|
|
||||||
query = "SELECT stuff"
|
query = "SELECT stuff"
|
||||||
location = "s3://bucket-name/prefix/"
|
|
||||||
database = "database"
|
database = "database"
|
||||||
# Start Query
|
# Start Query
|
||||||
exex_id = client.start_query_execution(
|
exex_id = client.start_query_execution(
|
||||||
@ -141,7 +143,11 @@ def test_get_query_execution():
|
|||||||
assert details["QueryExecutionId"] == exex_id
|
assert details["QueryExecutionId"] == exex_id
|
||||||
assert details["Query"] == query
|
assert details["Query"] == query
|
||||||
assert details["StatementType"] == "DML"
|
assert details["StatementType"] == "DML"
|
||||||
assert details["ResultConfiguration"]["OutputLocation"] == location
|
result_config = details["ResultConfiguration"]
|
||||||
|
if location.endswith("/"):
|
||||||
|
assert result_config["OutputLocation"] == f"{location}{exex_id}.csv"
|
||||||
|
else:
|
||||||
|
assert result_config["OutputLocation"] == f"{location}/{exex_id}.csv"
|
||||||
assert details["QueryExecutionContext"]["Database"] == database
|
assert details["QueryExecutionContext"]["Database"] == database
|
||||||
assert details["Status"]["State"] == "SUCCEEDED"
|
assert details["Status"]["State"] == "SUCCEEDED"
|
||||||
assert details["Statistics"] == {
|
assert details["Statistics"] == {
|
||||||
|
Loading…
Reference in New Issue
Block a user