From ade1001a6928068d4bf3017f23b3ca10302d6066 Mon Sep 17 00:00:00 2001
From: Bert Blommers
Date: Mon, 25 Mar 2024 05:34:06 -0100
Subject: [PATCH] S3: select_object_content() now supports Compressed requests and CSV outputs (#7514)

GZIP input is decoded explicitly as UTF-8, matching the BZIP2 and
uncompressed paths, so results do not depend on the locale encoding.

---
 moto/s3/models.py                |  59 +++++++++++---
 moto/s3/responses.py             |   4 +-
 moto/s3/select_object_content.py |   7 +-
 setup.cfg                        |  18 ++---
 tests/test_s3/test_s3_select.py  | 131 +++++++++++++++++++++++++++++++
 5 files changed, 190 insertions(+), 29 deletions(-)

diff --git a/moto/s3/models.py b/moto/s3/models.py
index bdfa9e6fe..af925ebf3 100644
--- a/moto/s3/models.py
+++ b/moto/s3/models.py
@@ -1,7 +1,9 @@
 import base64
+import bz2
 import codecs
 import copy
 import datetime
+import gzip
 import itertools
 import json
 import os
@@ -12,6 +14,7 @@ import threading
 import urllib.parse
 from bisect import insort
 from importlib import reload
+from io import BytesIO
 from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
 
 from moto.cloudwatch.models import MetricDatum
@@ -2858,6 +2861,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         key_name: str,
         select_query: str,
         input_details: Dict[str, Any],
+        output_details: Dict[str, Any],
     ) -> List[bytes]:
         """
         Highly experimental. Please raise an issue if you find any inconsistencies/bugs.
@@ -2870,24 +2874,53 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         """
         self.get_bucket(bucket_name)
         key = self.get_object(bucket_name, key_name)
-        query_input = key.value.decode("utf-8")  # type: ignore
+        if key is None:
+            raise MissingKey(key=key_name)
+        if input_details.get("CompressionType") == "GZIP":
+            with gzip.open(BytesIO(key.value), "rt", encoding="utf-8") as f:
+                query_input = f.read()
+        elif input_details.get("CompressionType") == "BZIP2":
+            query_input = bz2.decompress(key.value).decode("utf-8")
+        else:
+            query_input = key.value.decode("utf-8")
         if "CSV" in input_details:
             # input is in CSV - we need to convert it to JSON before parsing
-            from py_partiql_parser._internal.csv_converter import (  # noqa # pylint: disable=unused-import
-                csv_to_json,
-            )
+            from py_partiql_parser import csv_to_json
 
-            use_headers = input_details["CSV"].get("FileHeaderInfo", "") == "USE"
+            use_headers = (input_details.get("CSV") or {}).get(
+                "FileHeaderInfo", ""
+            ) == "USE"
             query_input = csv_to_json(query_input, use_headers)
-        query_result = parse_query(query_input, select_query)
-        from py_partiql_parser import SelectEncoder
+        query_result = parse_query(query_input, select_query)  # type: ignore
 
-        return [
-            json.dumps(x, indent=None, separators=(",", ":"), cls=SelectEncoder).encode(
-                "utf-8"
-            )
-            for x in query_result
-        ]
+        record_delimiter = "\n"
+        if "JSON" in output_details:
+            record_delimiter = (output_details.get("JSON") or {}).get(
+                "RecordDelimiter"
+            ) or "\n"
+        elif "CSV" in output_details:
+            record_delimiter = (output_details.get("CSV") or {}).get(
+                "RecordDelimiter"
+            ) or "\n"
+
+        if "CSV" in output_details:
+            field_delim = (output_details.get("CSV") or {}).get("FieldDelimiter") or ","
+
+            from py_partiql_parser import json_to_csv
+
+            query_result = json_to_csv(query_result, field_delim, record_delimiter)
+            return [query_result.encode("utf-8")]  # type: ignore
+
+        else:
+            from py_partiql_parser import SelectEncoder
+
+            return [
+                (
+                    json.dumps(x, indent=None, separators=(",", ":"), cls=SelectEncoder)
+                    + record_delimiter
+                ).encode("utf-8")
+                for x in query_result
+            ]
 
     def restore_object(
         self, bucket_name: str, key_name: str, days: Optional[str], type_: Optional[str]
diff --git a/moto/s3/responses.py b/moto/s3/responses.py
index 0b17c3915..ae6899b0c 100644
--- a/moto/s3/responses.py
+++ b/moto/s3/responses.py
@@ -2291,9 +2291,9 @@ class S3Response(BaseResponse):
             input_details = request["InputSerialization"]
             output_details = request["OutputSerialization"]
             results = self.backend.select_object_content(
-                bucket_name, key_name, select_query, input_details
+                bucket_name, key_name, select_query, input_details, output_details
             )
-            return 200, {}, serialize_select(results, output_details)
+            return 200, {}, serialize_select(results)
 
         else:
             raise NotImplementedError(
diff --git a/moto/s3/select_object_content.py b/moto/s3/select_object_content.py
index e9d9985e8..db2ae2b84 100644
--- a/moto/s3/select_object_content.py
+++ b/moto/s3/select_object_content.py
@@ -49,11 +49,8 @@ def _create_end_message() -> bytes:
     return _create_message(content_type=None, event_type=b"End", payload=b"")
 
 
-def serialize_select(data_list: List[bytes], output_details: Dict[str, Any]) -> bytes:
-    delimiter = (
-        (output_details.get("JSON") or {}).get("RecordDelimiter") or "\n"
-    ).encode("utf-8")
+def serialize_select(data_list: List[bytes]) -> bytes:
     response = b""
     for data in data_list:
-        response += _create_data_message(data + delimiter)
+        response += _create_data_message(data)
     return response + _create_stats_message() + _create_end_message()
diff --git a/setup.cfg b/setup.cfg
index 2a5689a07..a7d287bc5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -54,7 +54,7 @@ all =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
     multipart
@@ -69,7 +69,7 @@ proxy =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
     multipart
@@ -84,7 +84,7 @@ server =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
     flask!=2.2.0,!=2.2.1
@@ -119,7 +119,7 @@ cloudformation =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
 cloudfront =
@@ -141,10 +141,10 @@ dms =
 ds =
 dynamodb =
     docker>=3.0.0
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 dynamodbstreams =
     docker>=3.0.0
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 ebs =
 ec2 =
 ec2instanceconnect =
@@ -208,15 +208,15 @@ resourcegroupstaggingapi =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 route53 =
 route53resolver =
 s3 =
     PyYAML>=5.1
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 s3crc32c =
     PyYAML>=5.1
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     crc32c
 s3control =
 sagemaker =
diff --git a/tests/test_s3/test_s3_select.py b/tests/test_s3/test_s3_select.py
index 8d55ed546..d4a5a2f90 100644
--- a/tests/test_s3/test_s3_select.py
+++ b/tests/test_s3/test_s3_select.py
@@ -1,7 +1,10 @@
+import bz2
+import gzip
 import json
 
 import boto3
 import pytest
+from botocore.exceptions import ClientError
 
 from . import s3_aws_verified
 
@@ -45,6 +48,24 @@ def create_test_files(bucket_name):
         Key="nested.json",
         Body=json.dumps(NESTED_JSON),
     )
+    client.put_object(
+        Bucket=bucket_name,
+        Key="json.gzip",
+        Body=gzip.compress(json.dumps(NESTED_JSON).encode("utf-8")),
+    )
+    client.put_object(
+        Bucket=bucket_name,
+        Key="json.bz2",
+        Body=bz2.compress(json.dumps(NESTED_JSON).encode("utf-8")),
+    )
+    client.put_object(
+        Bucket=bucket_name,
+        Key="csv.gzip",
+        Body=gzip.compress(SIMPLE_CSV.encode("utf-8")),
+    )
+    client.put_object(
+        Bucket=bucket_name, Key="csv.bz2", Body=bz2.compress(SIMPLE_CSV.encode("utf-8"))
+    )
 
 
 @pytest.mark.aws_verified
@@ -226,3 +247,113 @@ def test_nested_json__select_all(bucket_name=None):
     assert records[-1] == ","
 
     assert json.loads(records[:-1]) == NESTED_JSON
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_gzipped_json(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    create_test_files(bucket_name)
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="json.gzip",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"JSON": {"Type": "DOCUMENT"}, "CompressionType": "GZIP"},
+        OutputSerialization={"JSON": {"RecordDelimiter": ","}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b'{"_1":1},'}} in result
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_bzipped_json(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    create_test_files(bucket_name)
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="json.bz2",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"JSON": {"Type": "DOCUMENT"}, "CompressionType": "BZIP2"},
+        OutputSerialization={"JSON": {"RecordDelimiter": ","}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b'{"_1":1},'}} in result
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_bzipped_csv_to_csv(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    create_test_files(bucket_name)
+
+    # Count Records
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {"RecordDelimiter": "_", "FieldDelimiter": ":"}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"4_"}} in result
+
+    # Count Records
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"4\n"}} in result
+
+    # Mirror records
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT * FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"a,b,c\ne,r,f\ny,u,i\nq,w,y\n"}} in result
+
+    # Mirror records, specifying output format
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT * FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {"RecordDelimiter": "\n", "FieldDelimiter": ":"}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"a:b:c\ne:r:f\ny:u:i\nq:w:y\n"}} in result
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_select_unknown_key(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    with pytest.raises(ClientError) as exc:
+        client.select_object_content(
+            Bucket=bucket_name,
+            Key="unknown",
+            Expression="SELECT count(*) FROM S3Object",
+            ExpressionType="SQL",
+            InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+            OutputSerialization={
+                "CSV": {"RecordDelimiter": "\n", "FieldDelimiter": ":"}
+            },
+        )
+    err = exc.value.response["Error"]
+    assert err["Code"] == "NoSuchKey"
+    assert err["Message"] == "The specified key does not exist."
+    assert err["Key"] == "unknown"