Athena: get_query_execution() should return exact OutputLocation file (#7422)
This commit is contained in:
		
							parent
							
								
									131b78c82c
								
							
						
					
					
						commit
						1c11fbc0a2
					
				| @ -86,7 +86,9 @@ class DataCatalog(TaggableResourceMixin, BaseModel): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Execution(BaseModel): | class Execution(BaseModel): | ||||||
|     def __init__(self, query: str, context: str, config: str, workgroup: WorkGroup): |     def __init__( | ||||||
|  |         self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup | ||||||
|  |     ): | ||||||
|         self.id = str(mock_random.uuid4()) |         self.id = str(mock_random.uuid4()) | ||||||
|         self.query = query |         self.query = query | ||||||
|         self.context = context |         self.context = context | ||||||
| @ -95,6 +97,11 @@ class Execution(BaseModel): | |||||||
|         self.start_time = time.time() |         self.start_time = time.time() | ||||||
|         self.status = "SUCCEEDED" |         self.status = "SUCCEEDED" | ||||||
| 
 | 
 | ||||||
|  |         if "OutputLocation" in self.config: | ||||||
|  |             if not self.config["OutputLocation"].endswith("/"): | ||||||
|  |                 self.config["OutputLocation"] += "/" | ||||||
|  |             self.config["OutputLocation"] += f"{self.id}.csv" | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class QueryResults(BaseModel): | class QueryResults(BaseModel): | ||||||
|     def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]): |     def __init__(self, rows: List[Dict[str, Any]], column_info: List[Dict[str, str]]): | ||||||
| @ -213,7 +220,7 @@ class AthenaBackend(BaseBackend): | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     def start_query_execution( |     def start_query_execution( | ||||||
|         self, query: str, context: str, config: str, workgroup: WorkGroup |         self, query: str, context: str, config: Dict[str, Any], workgroup: WorkGroup | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         execution = Execution( |         execution = Execution( | ||||||
|             query=query, context=context, config=config, workgroup=workgroup |             query=query, context=context, config=config, workgroup=workgroup | ||||||
|  | |||||||
| @ -123,11 +123,13 @@ def test_start_query_validate_workgroup(): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @mock_aws | @mock_aws | ||||||
| def test_get_query_execution(): | @pytest.mark.parametrize( | ||||||
|  |     "location", ["s3://bucket-name/prefix/", "s3://bucket-name/prefix_wo_slash"] | ||||||
|  | ) | ||||||
|  | def test_get_query_execution(location): | ||||||
|     client = boto3.client("athena", region_name="us-east-1") |     client = boto3.client("athena", region_name="us-east-1") | ||||||
| 
 | 
 | ||||||
|     query = "SELECT stuff" |     query = "SELECT stuff" | ||||||
|     location = "s3://bucket-name/prefix/" |  | ||||||
|     database = "database" |     database = "database" | ||||||
|     # Start Query |     # Start Query | ||||||
|     exex_id = client.start_query_execution( |     exex_id = client.start_query_execution( | ||||||
| @ -141,7 +143,11 @@ def test_get_query_execution(): | |||||||
|     assert details["QueryExecutionId"] == exex_id |     assert details["QueryExecutionId"] == exex_id | ||||||
|     assert details["Query"] == query |     assert details["Query"] == query | ||||||
|     assert details["StatementType"] == "DML" |     assert details["StatementType"] == "DML" | ||||||
|     assert details["ResultConfiguration"]["OutputLocation"] == location |     result_config = details["ResultConfiguration"] | ||||||
|  |     if location.endswith("/"): | ||||||
|  |         assert result_config["OutputLocation"] == f"{location}{exex_id}.csv" | ||||||
|  |     else: | ||||||
|  |         assert result_config["OutputLocation"] == f"{location}/{exex_id}.csv" | ||||||
|     assert details["QueryExecutionContext"]["Database"] == database |     assert details["QueryExecutionContext"]["Database"] == database | ||||||
|     assert details["Status"]["State"] == "SUCCEEDED" |     assert details["Status"]["State"] == "SUCCEEDED" | ||||||
|     assert details["Statistics"] == { |     assert details["Statistics"] == { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user