diff --git a/moto/athena/models.py b/moto/athena/models.py index bdba60cc0..a7d5e4a4b 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -86,7 +86,9 @@ class DataCatalog(TaggableResourceMixin, BaseModel): class Execution(BaseModel): - def __init__(self, query: str, context: str, config: str, workgroup: WorkGroup): + def __init__( + self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup + ): self.id = str(mock_random.uuid4()) self.query = query self.context = context @@ -95,6 +97,11 @@ class Execution(BaseModel): self.start_time = time.time() self.status = "SUCCEEDED" + if "OutputLocation" in self.config: + if not self.config["OutputLocation"].endswith("/"): + self.config["OutputLocation"] += "/" + self.config["OutputLocation"] += f"{self.id}.csv" + class QueryResults(BaseModel): def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]): @@ -213,7 +220,7 @@ class AthenaBackend(BaseBackend): } def start_query_execution( - self, query: str, context: str, config: str, workgroup: WorkGroup + self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup ) -> str: execution = Execution( query=query, context=context, config=config, workgroup=workgroup diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index 0ac70daf1..0b6dad04a 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -123,11 +123,13 @@ def test_start_query_validate_workgroup(): @mock_aws -def test_get_query_execution(): +@pytest.mark.parametrize( + "location", ["s3://bucket-name/prefix/", "s3://bucket-name/prefix_wo_slash"] +) +def test_get_query_execution(location): client = boto3.client("athena", region_name="us-east-1") query = "SELECT stuff" - location = "s3://bucket-name/prefix/" database = "database" # Start Query exex_id = client.start_query_execution( @@ -141,7 +143,11 @@ def test_get_query_execution(): assert details["QueryExecutionId"] == exex_id assert details["Query"] == query assert details["StatementType"] == "DML" - assert details["ResultConfiguration"]["OutputLocation"] == location + result_config = details["ResultConfiguration"] + if location.endswith("/"): + assert result_config["OutputLocation"] == f"{location}{exex_id}.csv" + else: + assert result_config["OutputLocation"] == f"{location}/{exex_id}.csv" assert details["QueryExecutionContext"]["Database"] == database assert details["Status"]["State"] == "SUCCEEDED" assert details["Statistics"] == {