Athena: get_query_execution() should return exact OutputLocation file (#7422)
This commit is contained in:
parent
131b78c82c
commit
1c11fbc0a2
@ -86,7 +86,9 @@ class DataCatalog(TaggableResourceMixin, BaseModel):
|
||||
|
||||
|
||||
class Execution(BaseModel):
|
||||
def __init__(self, query: str, context: str, config: str, workgroup: WorkGroup):
|
||||
def __init__(
|
||||
self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup
|
||||
):
|
||||
self.id = str(mock_random.uuid4())
|
||||
self.query = query
|
||||
self.context = context
|
||||
@ -95,6 +97,11 @@ class Execution(BaseModel):
|
||||
self.start_time = time.time()
|
||||
self.status = "SUCCEEDED"
|
||||
|
||||
if "OutputLocation" in self.config:
|
||||
if not self.config["OutputLocation"].endswith("/"):
|
||||
self.config["OutputLocation"] += "/"
|
||||
self.config["OutputLocation"] += f"{self.id}.csv"
|
||||
|
||||
|
||||
class QueryResults(BaseModel):
|
||||
def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]):
|
||||
@ -213,7 +220,7 @@ class AthenaBackend(BaseBackend):
|
||||
}
|
||||
|
||||
def start_query_execution(
|
||||
self, query: str, context: str, config: str, workgroup: WorkGroup
|
||||
self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup
|
||||
) -> str:
|
||||
execution = Execution(
|
||||
query=query, context=context, config=config, workgroup=workgroup
|
||||
|
@ -123,11 +123,13 @@ def test_start_query_validate_workgroup():
|
||||
|
||||
|
||||
@mock_aws
|
||||
def test_get_query_execution():
|
||||
@pytest.mark.parametrize(
|
||||
"location", ["s3://bucket-name/prefix/", "s3://bucket-name/prefix_wo_slash"]
|
||||
)
|
||||
def test_get_query_execution(location):
|
||||
client = boto3.client("athena", region_name="us-east-1")
|
||||
|
||||
query = "SELECT stuff"
|
||||
location = "s3://bucket-name/prefix/"
|
||||
database = "database"
|
||||
# Start Query
|
||||
exex_id = client.start_query_execution(
|
||||
@ -141,7 +143,11 @@ def test_get_query_execution():
|
||||
assert details["QueryExecutionId"] == exex_id
|
||||
assert details["Query"] == query
|
||||
assert details["StatementType"] == "DML"
|
||||
assert details["ResultConfiguration"]["OutputLocation"] == location
|
||||
result_config = details["ResultConfiguration"]
|
||||
if location.endswith("/"):
|
||||
assert result_config["OutputLocation"] == f"{location}{exex_id}.csv"
|
||||
else:
|
||||
assert result_config["OutputLocation"] == f"{location}/{exex_id}.csv"
|
||||
assert details["QueryExecutionContext"]["Database"] == database
|
||||
assert details["Status"]["State"] == "SUCCEEDED"
|
||||
assert details["Statistics"] == {
|
||||
|
Loading…
Reference in New Issue
Block a user