Athena: get_query_execution() should return exact OutputLocation file (#7422)
This commit is contained in:
		
							parent
							
								
									131b78c82c
								
							
						
					
					
						commit
						1c11fbc0a2
					
				| @ -86,7 +86,9 @@ class DataCatalog(TaggableResourceMixin, BaseModel): | ||||
| 
 | ||||
| 
 | ||||
| class Execution(BaseModel): | ||||
|     def __init__(self, query: str, context: str, config: str, workgroup: WorkGroup): | ||||
|     def __init__( | ||||
|         self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup | ||||
|     ): | ||||
|         self.id = str(mock_random.uuid4()) | ||||
|         self.query = query | ||||
|         self.context = context | ||||
| @ -95,6 +97,11 @@ class Execution(BaseModel): | ||||
|         self.start_time = time.time() | ||||
|         self.status = "SUCCEEDED" | ||||
| 
 | ||||
|         if "OutputLocation" in self.config: | ||||
|             if not self.config["OutputLocation"].endswith("/"): | ||||
|                 self.config["OutputLocation"] += "/" | ||||
|             self.config["OutputLocation"] += f"{self.id}.csv" | ||||
| 
 | ||||
| 
 | ||||
| class QueryResults(BaseModel): | ||||
|     def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]): | ||||
| @ -213,7 +220,7 @@ class AthenaBackend(BaseBackend): | ||||
|         } | ||||
| 
 | ||||
|     def start_query_execution( | ||||
|         self, query: str, context: str, config: str, workgroup: WorkGroup | ||||
|         self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup | ||||
|     ) -> str: | ||||
|         execution = Execution( | ||||
|             query=query, context=context, config=config, workgroup=workgroup | ||||
|  | ||||
| @ -123,11 +123,13 @@ def test_start_query_validate_workgroup(): | ||||
| 
 | ||||
| 
 | ||||
| @mock_aws | ||||
| def test_get_query_execution(): | ||||
| @pytest.mark.parametrize( | ||||
|     "location", ["s3://bucket-name/prefix/", "s3://bucket-name/prefix_wo_slash"] | ||||
| ) | ||||
| def test_get_query_execution(location): | ||||
|     client = boto3.client("athena", region_name="us-east-1") | ||||
| 
 | ||||
|     query = "SELECT stuff" | ||||
|     location = "s3://bucket-name/prefix/" | ||||
|     database = "database" | ||||
|     # Start Query | ||||
|     exex_id = client.start_query_execution( | ||||
| @ -141,7 +143,11 @@ def test_get_query_execution(): | ||||
|     assert details["QueryExecutionId"] == exex_id | ||||
|     assert details["Query"] == query | ||||
|     assert details["StatementType"] == "DML" | ||||
|     assert details["ResultConfiguration"]["OutputLocation"] == location | ||||
|     result_config = details["ResultConfiguration"] | ||||
|     if location.endswith("/"): | ||||
|         assert result_config["OutputLocation"] == f"{location}{exex_id}.csv" | ||||
|     else: | ||||
|         assert result_config["OutputLocation"] == f"{location}/{exex_id}.csv" | ||||
|     assert details["QueryExecutionContext"]["Database"] == database | ||||
|     assert details["Status"]["State"] == "SUCCEEDED" | ||||
|     assert details["Statistics"] == { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user