S3: select_object_content() now supports Compressed requests and CSV outputs (#7514)

This commit is contained in:
Bert Blommers 2024-03-25 05:34:06 -01:00 committed by GitHub
parent f14749b6b5
commit ade1001a69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 190 additions and 29 deletions

View File

@ -1,7 +1,9 @@
import base64 import base64
import bz2
import codecs import codecs
import copy import copy
import datetime import datetime
import gzip
import itertools import itertools
import json import json
import os import os
@ -12,6 +14,7 @@ import threading
import urllib.parse import urllib.parse
from bisect import insort from bisect import insort
from importlib import reload from importlib import reload
from io import BytesIO
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union
from moto.cloudwatch.models import MetricDatum from moto.cloudwatch.models import MetricDatum
@ -2858,6 +2861,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
key_name: str, key_name: str,
select_query: str, select_query: str,
input_details: Dict[str, Any], input_details: Dict[str, Any],
output_details: Dict[str, Any],
) -> List[bytes]: ) -> List[bytes]:
""" """
Highly experimental. Please raise an issue if you find any inconsistencies/bugs. Highly experimental. Please raise an issue if you find any inconsistencies/bugs.
@ -2870,22 +2874,51 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
""" """
self.get_bucket(bucket_name) self.get_bucket(bucket_name)
key = self.get_object(bucket_name, key_name) key = self.get_object(bucket_name, key_name)
query_input = key.value.decode("utf-8") # type: ignore if key is None:
raise MissingKey(key=key_name)
if input_details.get("CompressionType") == "GZIP":
with gzip.open(BytesIO(key.value), "rt") as f:
query_input = f.read()
elif input_details.get("CompressionType") == "BZIP2":
query_input = bz2.decompress(key.value).decode("utf-8")
else:
query_input = key.value.decode("utf-8")
if "CSV" in input_details: if "CSV" in input_details:
# input is in CSV - we need to convert it to JSON before parsing # input is in CSV - we need to convert it to JSON before parsing
from py_partiql_parser._internal.csv_converter import ( # noqa # pylint: disable=unused-import from py_partiql_parser import csv_to_json
csv_to_json,
)
use_headers = input_details["CSV"].get("FileHeaderInfo", "") == "USE" use_headers = (input_details.get("CSV") or {}).get(
"FileHeaderInfo", ""
) == "USE"
query_input = csv_to_json(query_input, use_headers) query_input = csv_to_json(query_input, use_headers)
query_result = parse_query(query_input, select_query) query_result = parse_query(query_input, select_query) # type: ignore
record_delimiter = "\n"
if "JSON" in output_details:
record_delimiter = (output_details.get("JSON") or {}).get(
"RecordDelimiter"
) or "\n"
elif "CSV" in output_details:
record_delimiter = (output_details.get("CSV") or {}).get(
"RecordDelimiter"
) or "\n"
if "CSV" in output_details:
field_delim = (output_details.get("CSV") or {}).get("FieldDelimiter") or ","
from py_partiql_parser import json_to_csv
query_result = json_to_csv(query_result, field_delim, record_delimiter)
return [query_result.encode("utf-8")] # type: ignore
else:
from py_partiql_parser import SelectEncoder from py_partiql_parser import SelectEncoder
return [ return [
json.dumps(x, indent=None, separators=(",", ":"), cls=SelectEncoder).encode( (
"utf-8" json.dumps(x, indent=None, separators=(",", ":"), cls=SelectEncoder)
) + record_delimiter
).encode("utf-8")
for x in query_result for x in query_result
] ]

View File

@ -2291,9 +2291,9 @@ class S3Response(BaseResponse):
input_details = request["InputSerialization"] input_details = request["InputSerialization"]
output_details = request["OutputSerialization"] output_details = request["OutputSerialization"]
results = self.backend.select_object_content( results = self.backend.select_object_content(
bucket_name, key_name, select_query, input_details bucket_name, key_name, select_query, input_details, output_details
) )
return 200, {}, serialize_select(results, output_details) return 200, {}, serialize_select(results)
else: else:
raise NotImplementedError( raise NotImplementedError(

View File

@ -49,11 +49,8 @@ def _create_end_message() -> bytes:
return _create_message(content_type=None, event_type=b"End", payload=b"") return _create_message(content_type=None, event_type=b"End", payload=b"")
def serialize_select(data_list: List[bytes], output_details: Dict[str, Any]) -> bytes: def serialize_select(data_list: List[bytes]) -> bytes:
delimiter = (
(output_details.get("JSON") or {}).get("RecordDelimiter") or "\n"
).encode("utf-8")
response = b"" response = b""
for data in data_list: for data in data_list:
response += _create_data_message(data + delimiter) response += _create_data_message(data)
return response + _create_stats_message() + _create_end_message() return response + _create_stats_message() + _create_end_message()

View File

@ -54,7 +54,7 @@ all =
openapi-spec-validator>=0.5.0 openapi-spec-validator>=0.5.0
pyparsing>=3.0.7 pyparsing>=3.0.7
jsondiff>=1.1.2 jsondiff>=1.1.2
py-partiql-parser==0.5.1 py-partiql-parser==0.5.2
aws-xray-sdk!=0.96,>=0.93 aws-xray-sdk!=0.96,>=0.93
setuptools setuptools
multipart multipart
@ -69,7 +69,7 @@ proxy =
openapi-spec-validator>=0.5.0 openapi-spec-validator>=0.5.0
pyparsing>=3.0.7 pyparsing>=3.0.7
jsondiff>=1.1.2 jsondiff>=1.1.2
py-partiql-parser==0.5.1 py-partiql-parser==0.5.2
aws-xray-sdk!=0.96,>=0.93 aws-xray-sdk!=0.96,>=0.93
setuptools setuptools
multipart multipart
@ -84,7 +84,7 @@ server =
openapi-spec-validator>=0.5.0 openapi-spec-validator>=0.5.0
pyparsing>=3.0.7 pyparsing>=3.0.7
jsondiff>=1.1.2 jsondiff>=1.1.2
py-partiql-parser==0.5.1 py-partiql-parser==0.5.2
aws-xray-sdk!=0.96,>=0.93 aws-xray-sdk!=0.96,>=0.93
setuptools setuptools
flask!=2.2.0,!=2.2.1 flask!=2.2.0,!=2.2.1
@ -119,7 +119,7 @@ cloudformation =
openapi-spec-validator>=0.5.0 openapi-spec-validator>=0.5.0
pyparsing>=3.0.7 pyparsing>=3.0.7
jsondiff>=1.1.2 jsondiff>=1.1.2
py-partiql-parser==0.5.1 py-partiql-parser==0.5.2
aws-xray-sdk!=0.96,>=0.93 aws-xray-sdk!=0.96,>=0.93
setuptools setuptools
cloudfront = cloudfront =
@ -141,10 +141,10 @@ dms =
ds = ds =
dynamodb = dynamodb =
docker>=3.0.0 docker>=3.0.0
py-partiql-parser==0.5.1 py-partiql-parser==0.5.2
dynamodbstreams = dynamodbstreams =
docker>=3.0.0 docker>=3.0.0
py-partiql-parser==0.5.1 py-partiql-parser==0.5.2
ebs = ebs =
ec2 = ec2 =
ec2instanceconnect = ec2instanceconnect =
@ -208,15 +208,15 @@ resourcegroupstaggingapi =
openapi-spec-validator>=0.5.0 openapi-spec-validator>=0.5.0
pyparsing>=3.0.7 pyparsing>=3.0.7
jsondiff>=1.1.2 jsondiff>=1.1.2
py-partiql-parser==0.5.1 py-partiql-parser==0.5.2
route53 = route53 =
route53resolver = route53resolver =
s3 = s3 =
PyYAML>=5.1 PyYAML>=5.1
py-partiql-parser==0.5.1 py-partiql-parser==0.5.2
s3crc32c = s3crc32c =
PyYAML>=5.1 PyYAML>=5.1
py-partiql-parser==0.5.1 py-partiql-parser==0.5.2
crc32c crc32c
s3control = s3control =
sagemaker = sagemaker =

View File

@ -1,7 +1,10 @@
import bz2
import gzip
import json import json
import boto3 import boto3
import pytest import pytest
from botocore.exceptions import ClientError
from . import s3_aws_verified from . import s3_aws_verified
@ -45,6 +48,24 @@ def create_test_files(bucket_name):
Key="nested.json", Key="nested.json",
Body=json.dumps(NESTED_JSON), Body=json.dumps(NESTED_JSON),
) )
client.put_object(
Bucket=bucket_name,
Key="json.gzip",
Body=gzip.compress(json.dumps(NESTED_JSON).encode("utf-8")),
)
client.put_object(
Bucket=bucket_name,
Key="json.bz2",
Body=bz2.compress(json.dumps(NESTED_JSON).encode("utf-8")),
)
client.put_object(
Bucket=bucket_name,
Key="csv.gzip",
Body=gzip.compress(SIMPLE_CSV.encode("utf-8")),
)
client.put_object(
Bucket=bucket_name, Key="csv.bz2", Body=bz2.compress(SIMPLE_CSV.encode("utf-8"))
)
@pytest.mark.aws_verified @pytest.mark.aws_verified
@ -226,3 +247,113 @@ def test_nested_json__select_all(bucket_name=None):
assert records[-1] == "," assert records[-1] == ","
assert json.loads(records[:-1]) == NESTED_JSON assert json.loads(records[:-1]) == NESTED_JSON
@pytest.mark.aws_verified
@s3_aws_verified
def test_gzipped_json(bucket_name=None):
    """Count the records of the GZIP-compressed JSON fixture object."""
    s3 = boto3.client("s3", "us-east-1")
    create_test_files(bucket_name)

    response = s3.select_object_content(
        Bucket=bucket_name,
        Key="json.gzip",
        Expression="SELECT count(*) FROM S3Object",
        ExpressionType="SQL",
        InputSerialization={"JSON": {"Type": "DOCUMENT"}, "CompressionType": "GZIP"},
        OutputSerialization={"JSON": {"RecordDelimiter": ","}},
    )

    # The payload is an event stream; only the Records event is checked here.
    events = list(response["Payload"])
    expected_event = {"Records": {"Payload": b'{"_1":1},'}}
    assert expected_event in events
@pytest.mark.aws_verified
@s3_aws_verified
def test_bzipped_json(bucket_name=None):
    """Count the records of the BZIP2-compressed JSON fixture object."""
    s3 = boto3.client("s3", "us-east-1")
    create_test_files(bucket_name)

    select_kwargs = {
        "Bucket": bucket_name,
        "Key": "json.bz2",
        "Expression": "SELECT count(*) FROM S3Object",
        "ExpressionType": "SQL",
        "InputSerialization": {
            "JSON": {"Type": "DOCUMENT"},
            "CompressionType": "BZIP2",
        },
        "OutputSerialization": {"JSON": {"RecordDelimiter": ","}},
    }
    response = s3.select_object_content(**select_kwargs)

    # The payload is an event stream; only the Records event is checked here.
    events = list(response["Payload"])
    assert {"Records": {"Payload": b'{"_1":1},'}} in events
@pytest.mark.aws_verified
@s3_aws_verified
def test_bzipped_csv_to_csv(bucket_name=None):
    """Query the BZIP2-compressed CSV fixture and check the CSV-serialized output."""
    s3 = boto3.client("s3", "us-east-1")
    create_test_files(bucket_name)

    # Each scenario: (SQL expression, CSV output options, expected Records payload).
    # Scenarios run in order against the same compressed object.
    scenarios = [
        # Count records, custom record and field delimiters
        (
            "SELECT count(*) FROM S3Object",
            {"RecordDelimiter": "_", "FieldDelimiter": ":"},
            b"4_",
        ),
        # Count records, no output options given
        ("SELECT count(*) FROM S3Object", {}, b"4\n"),
        # Mirror records, no output options given
        ("SELECT * FROM S3Object", {}, b"a,b,c\ne,r,f\ny,u,i\nq,w,y\n"),
        # Mirror records, specifying output format
        (
            "SELECT * FROM S3Object",
            {"RecordDelimiter": "\n", "FieldDelimiter": ":"},
            b"a:b:c\ne:r:f\ny:u:i\nq:w:y\n",
        ),
    ]

    for expression, csv_output, expected_payload in scenarios:
        response = s3.select_object_content(
            Bucket=bucket_name,
            Key="csv.bz2",
            Expression=expression,
            ExpressionType="SQL",
            InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
            OutputSerialization={"CSV": csv_output},
        )
        events = list(response["Payload"])
        assert {"Records": {"Payload": expected_payload}} in events
@pytest.mark.aws_verified
@s3_aws_verified
def test_select_unknown_key(bucket_name=None):
    """Selecting from a key that was never uploaded raises a NoSuchKey ClientError."""
    s3 = boto3.client("s3", "us-east-1")
    select_kwargs = {
        "Bucket": bucket_name,
        "Key": "unknown",
        "Expression": "SELECT count(*) FROM S3Object",
        "ExpressionType": "SQL",
        "InputSerialization": {"CSV": {}, "CompressionType": "BZIP2"},
        "OutputSerialization": {
            "CSV": {"RecordDelimiter": "\n", "FieldDelimiter": ":"}
        },
    }

    with pytest.raises(ClientError) as raised:
        s3.select_object_content(**select_kwargs)

    # Inspect the captured error once the context manager has exited.
    error = raised.value.response["Error"]
    assert error["Code"] == "NoSuchKey"
    assert error["Message"] == "The specified key does not exist."
    assert error["Key"] == "unknown"