Athena: get_query_execution() should return the exact OutputLocation file (#7422)

This commit is contained in:
Bert Blommers 2024-03-04 21:17:17 +00:00 committed by GitHub
parent 131b78c82c
commit 1c11fbc0a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 5 deletions

View File

@ -86,7 +86,9 @@ class DataCatalog(TaggableResourceMixin, BaseModel):
class Execution(BaseModel): class Execution(BaseModel):
def __init__(self, query: str, context: str, config: str, workgroup: WorkGroup): def __init__(
self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup
):
self.id = str(mock_random.uuid4()) self.id = str(mock_random.uuid4())
self.query = query self.query = query
self.context = context self.context = context
@ -95,6 +97,11 @@ class Execution(BaseModel):
self.start_time = time.time() self.start_time = time.time()
self.status = "SUCCEEDED" self.status = "SUCCEEDED"
if "OutputLocation" in self.config:
if not self.config["OutputLocation"].endswith("/"):
self.config["OutputLocation"] += "/"
self.config["OutputLocation"] += f"{self.id}.csv"
class QueryResults(BaseModel): class QueryResults(BaseModel):
def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]): def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]):
@ -213,7 +220,7 @@ class AthenaBackend(BaseBackend):
} }
def start_query_execution( def start_query_execution(
self, query: str, context: str, config: str, workgroup: WorkGroup self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup
) -> str: ) -> str:
execution = Execution( execution = Execution(
query=query, context=context, config=config, workgroup=workgroup query=query, context=context, config=config, workgroup=workgroup

View File

@ -123,11 +123,13 @@ def test_start_query_validate_workgroup():
@mock_aws @mock_aws
def test_get_query_execution(): @pytest.mark.parametrize(
"location", ["s3://bucket-name/prefix/", "s3://bucket-name/prefix_wo_slash"]
)
def test_get_query_execution(location):
client = boto3.client("athena", region_name="us-east-1") client = boto3.client("athena", region_name="us-east-1")
query = "SELECT stuff" query = "SELECT stuff"
location = "s3://bucket-name/prefix/"
database = "database" database = "database"
# Start Query # Start Query
exex_id = client.start_query_execution( exex_id = client.start_query_execution(
@ -141,7 +143,11 @@ def test_get_query_execution():
assert details["QueryExecutionId"] == exex_id assert details["QueryExecutionId"] == exex_id
assert details["Query"] == query assert details["Query"] == query
assert details["StatementType"] == "DML" assert details["StatementType"] == "DML"
assert details["ResultConfiguration"]["OutputLocation"] == location result_config = details["ResultConfiguration"]
if location.endswith("/"):
assert result_config["OutputLocation"] == f"{location}{exex_id}.csv"
else:
assert result_config["OutputLocation"] == f"{location}/{exex_id}.csv"
assert details["QueryExecutionContext"]["Database"] == database assert details["QueryExecutionContext"]["Database"] == database
assert details["Status"]["State"] == "SUCCEEDED" assert details["Status"]["State"] == "SUCCEEDED"
assert details["Statistics"] == { assert details["Statistics"] == {