S3: select_object_content() now supports Compressed requests and CSV outputs (#7514)

Bert Blommers 2024-03-25 05:34:06 -01:00 committed by GitHub
parent f14749b6b5
commit ade1001a69
5 changed files with 190 additions and 29 deletions
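
For context, a minimal boto3 sketch of what this change enables (bucket and object names are hypothetical; the object is assumed to contain gzipped CSV data):

import boto3

client = boto3.client("s3", region_name="us-east-1")
content = client.select_object_content(
    Bucket="my-bucket",
    Key="data.csv.gz",
    Expression="SELECT * FROM S3Object",
    ExpressionType="SQL",
    # CompressionType (GZIP/BZIP2) on the input side and CSV on the
    # output side are the two capabilities added by this commit.
    InputSerialization={"CSV": {}, "CompressionType": "GZIP"},
    OutputSerialization={"CSV": {"FieldDelimiter": ":", "RecordDelimiter": "\n"}},
)
for event in content["Payload"]:
    if "Records" in event:
        print(event["Records"]["Payload"].decode("utf-8"))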

moto/s3/models.py View File

@@ -1,7 +1,9 @@
 import base64
+import bz2
 import codecs
 import copy
 import datetime
+import gzip
 import itertools
 import json
 import os
@@ -12,6 +14,7 @@ import threading
 import urllib.parse
 from bisect import insort
 from importlib import reload
+from io import BytesIO
 from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

 from moto.cloudwatch.models import MetricDatum
@@ -2858,6 +2861,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         key_name: str,
         select_query: str,
         input_details: Dict[str, Any],
+        output_details: Dict[str, Any],
     ) -> List[bytes]:
         """
         Highly experimental. Please raise an issue if you find any inconsistencies/bugs.
@@ -2870,22 +2874,51 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         """
         self.get_bucket(bucket_name)
         key = self.get_object(bucket_name, key_name)
-        query_input = key.value.decode("utf-8")  # type: ignore
+        if key is None:
+            raise MissingKey(key=key_name)
+        if input_details.get("CompressionType") == "GZIP":
+            with gzip.open(BytesIO(key.value), "rt") as f:
+                query_input = f.read()
+        elif input_details.get("CompressionType") == "BZIP2":
+            query_input = bz2.decompress(key.value).decode("utf-8")
+        else:
+            query_input = key.value.decode("utf-8")
         if "CSV" in input_details:
             # input is in CSV - we need to convert it to JSON before parsing
-            from py_partiql_parser._internal.csv_converter import (  # noqa # pylint: disable=unused-import
-                csv_to_json,
-            )
+            from py_partiql_parser import csv_to_json

-            use_headers = input_details["CSV"].get("FileHeaderInfo", "") == "USE"
+            use_headers = (input_details.get("CSV") or {}).get(
+                "FileHeaderInfo", ""
+            ) == "USE"
             query_input = csv_to_json(query_input, use_headers)
-        query_result = parse_query(query_input, select_query)
-        from py_partiql_parser import SelectEncoder
+        query_result = parse_query(query_input, select_query)  # type: ignore

-        return [
-            json.dumps(x, indent=None, separators=(",", ":"), cls=SelectEncoder).encode(
-                "utf-8"
-            )
-            for x in query_result
-        ]
+        record_delimiter = "\n"
+        if "JSON" in output_details:
+            record_delimiter = (output_details.get("JSON") or {}).get(
+                "RecordDelimiter"
+            ) or "\n"
+        elif "CSV" in output_details:
+            record_delimiter = (output_details.get("CSV") or {}).get(
+                "RecordDelimiter"
+            ) or "\n"
+
+        if "CSV" in output_details:
+            field_delim = (output_details.get("CSV") or {}).get("FieldDelimiter") or ","
+
+            from py_partiql_parser import json_to_csv
+
+            query_result = json_to_csv(query_result, field_delim, record_delimiter)
+            return [query_result.encode("utf-8")]  # type: ignore
+        else:
+            from py_partiql_parser import SelectEncoder
+
+            return [
+                (
+                    json.dumps(x, indent=None, separators=(",", ":"), cls=SelectEncoder)
+                    + record_delimiter
+                ).encode("utf-8")
+                for x in query_result
+            ]
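
The two decompression branches above behave slightly differently: gzip.open() in "rt" mode decodes to str as it reads, while bz2.decompress() returns bytes that need an explicit decode. A self-contained sketch of just that distinction:

import bz2
import gzip
from io import BytesIO

raw = b'{"a1": "b1", "a2": "b2"}'

# gzip.open in text mode ("rt") decompresses and decodes in one step ...
with gzip.open(BytesIO(gzip.compress(raw)), "rt") as f:
    assert f.read() == raw.decode("utf-8")

# ... whereas bz2.decompress returns raw bytes, hence the .decode() above
assert bz2.decompress(bz2.compress(raw)) == raw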

moto/s3/responses.py View File

@@ -2291,9 +2291,9 @@ class S3Response(BaseResponse):
             input_details = request["InputSerialization"]
             output_details = request["OutputSerialization"]
             results = self.backend.select_object_content(
-                bucket_name, key_name, select_query, input_details
+                bucket_name, key_name, select_query, input_details, output_details
             )
-            return 200, {}, serialize_select(results, output_details)
+            return 200, {}, serialize_select(results)
         else:
             raise NotImplementedError(

moto/s3/select_object_content.py View File

@@ -49,11 +49,8 @@ def _create_end_message() -> bytes:
     return _create_message(content_type=None, event_type=b"End", payload=b"")


-def serialize_select(data_list: List[bytes], output_details: Dict[str, Any]) -> bytes:
-    delimiter = (
-        (output_details.get("JSON") or {}).get("RecordDelimiter") or "\n"
-    ).encode("utf-8")
+def serialize_select(data_list: List[bytes]) -> bytes:
     response = b""
     for data in data_list:
-        response += _create_data_message(data + delimiter)
+        response += _create_data_message(data)
     return response + _create_stats_message() + _create_end_message()
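
Since the backend now appends the record delimiter itself, serialize_select only has to frame each pre-delimited payload as a Records event before the closing Stats and End events. On the caller's side, botocore decodes that event stream back into dicts, so the records can be reassembled with a loop like this (a minimal sketch; content is the response of a select_object_content call):

def collect_records(content):
    # Concatenate the payload bytes of every Records event; Stats and
    # End events carry metadata rather than rows, so skip them.
    payload = b""
    for event in content["Payload"]:
        if "Records" in event:
            payload += event["Records"]["Payload"]
    return payload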

setup.cfg View File

@@ -54,7 +54,7 @@ all =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
     multipart
@@ -69,7 +69,7 @@ proxy =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
     multipart
@@ -84,7 +84,7 @@ server =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
     flask!=2.2.0,!=2.2.1
@@ -119,7 +119,7 @@ cloudformation =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
 cloudfront =
@@ -141,10 +141,10 @@ dms =
 ds =
 dynamodb =
     docker>=3.0.0
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 dynamodbstreams =
     docker>=3.0.0
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 ebs =
 ec2 =
 ec2instanceconnect =
@@ -208,15 +208,15 @@ resourcegroupstaggingapi =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 route53 =
 route53resolver =
 s3 =
     PyYAML>=5.1
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 s3crc32c =
     PyYAML>=5.1
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     crc32c
 s3control =
 sagemaker =
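
The pin on py-partiql-parser moves from 0.5.1 to 0.5.2 in every extra that bundles it, presumably because the helpers used above are only importable from the package's top level as of that release. A quick sanity check, assuming those exports:

# Assumption: py-partiql-parser 0.5.2 exports these names at top level,
# as the new imports in moto/s3/models.py suggest.
from py_partiql_parser import SelectEncoder, csv_to_json, json_to_csv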

tests/test_s3/test_s3_select.py View File

@@ -1,7 +1,10 @@
+import bz2
+import gzip
 import json

 import boto3
 import pytest
 from botocore.exceptions import ClientError

 from . import s3_aws_verified
@@ -45,6 +48,24 @@ def create_test_files(bucket_name):
     client.put_object(
         Bucket=bucket_name,
         Key="nested.json",
         Body=json.dumps(NESTED_JSON),
     )
+    client.put_object(
+        Bucket=bucket_name,
+        Key="json.gzip",
+        Body=gzip.compress(json.dumps(NESTED_JSON).encode("utf-8")),
+    )
+    client.put_object(
+        Bucket=bucket_name,
+        Key="json.bz2",
+        Body=bz2.compress(json.dumps(NESTED_JSON).encode("utf-8")),
+    )
+    client.put_object(
+        Bucket=bucket_name,
+        Key="csv.gzip",
+        Body=gzip.compress(SIMPLE_CSV.encode("utf-8")),
+    )
+    client.put_object(
+        Bucket=bucket_name, Key="csv.bz2", Body=bz2.compress(SIMPLE_CSV.encode("utf-8"))
+    )


 @pytest.mark.aws_verified
@@ -226,3 +247,113 @@ def test_nested_json__select_all(bucket_name=None):
     assert records[-1] == ","
     assert json.loads(records[:-1]) == NESTED_JSON
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_gzipped_json(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    create_test_files(bucket_name)
+
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="json.gzip",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"JSON": {"Type": "DOCUMENT"}, "CompressionType": "GZIP"},
+        OutputSerialization={"JSON": {"RecordDelimiter": ","}},
+    )
+
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b'{"_1":1},'}} in result
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_bzipped_json(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    create_test_files(bucket_name)
+
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="json.bz2",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"JSON": {"Type": "DOCUMENT"}, "CompressionType": "BZIP2"},
+        OutputSerialization={"JSON": {"RecordDelimiter": ","}},
+    )
+
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b'{"_1":1},'}} in result
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_bzipped_csv_to_csv(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    create_test_files(bucket_name)
+
+    # Count Records
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {"RecordDelimiter": "_", "FieldDelimiter": ":"}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"4_"}} in result
+
+    # Count Records
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"4\n"}} in result
+
+    # Mirror records
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT * FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"a,b,c\ne,r,f\ny,u,i\nq,w,y\n"}} in result
+
+    # Mirror records, specifying output format
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT * FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {"RecordDelimiter": "\n", "FieldDelimiter": ":"}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"a:b:c\ne:r:f\ny:u:i\nq:w:y\n"}} in result
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_select_unknown_key(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+
+    with pytest.raises(ClientError) as exc:
+        client.select_object_content(
+            Bucket=bucket_name,
+            Key="unknown",
+            Expression="SELECT count(*) FROM S3Object",
+            ExpressionType="SQL",
+            InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+            OutputSerialization={
+                "CSV": {"RecordDelimiter": "\n", "FieldDelimiter": ":"}
+            },
+        )
+    err = exc.value.response["Error"]
+    assert err["Code"] == "NoSuchKey"
+    assert err["Message"] == "The specified key does not exist."
+    assert err["Key"] == "unknown"