S3: select_object_content() now supports Compressed requests and CSV outputs (#7514)
This commit is contained in:
parent f14749b6b5
commit ade1001a69
@@ -1,7 +1,9 @@
 import base64
+import bz2
 import codecs
 import copy
 import datetime
+import gzip
 import itertools
 import json
 import os
@@ -12,6 +14,7 @@ import threading
 import urllib.parse
 from bisect import insort
 from importlib import reload
+from io import BytesIO
 from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

 from moto.cloudwatch.models import MetricDatum
@@ -2858,6 +2861,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         key_name: str,
         select_query: str,
         input_details: Dict[str, Any],
+        output_details: Dict[str, Any],
     ) -> List[bytes]:
         """
         Highly experimental. Please raise an issue if you find any inconsistencies/bugs.
@@ -2870,24 +2874,53 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         """
         self.get_bucket(bucket_name)
         key = self.get_object(bucket_name, key_name)
-        query_input = key.value.decode("utf-8")  # type: ignore
+        if key is None:
+            raise MissingKey(key=key_name)
+        if input_details.get("CompressionType") == "GZIP":
+            with gzip.open(BytesIO(key.value), "rt") as f:
+                query_input = f.read()
+        elif input_details.get("CompressionType") == "BZIP2":
+            query_input = bz2.decompress(key.value).decode("utf-8")
+        else:
+            query_input = key.value.decode("utf-8")
         if "CSV" in input_details:
             # input is in CSV - we need to convert it to JSON before parsing
-            from py_partiql_parser._internal.csv_converter import (  # noqa # pylint: disable=unused-import
-                csv_to_json,
-            )
+            from py_partiql_parser import csv_to_json

-            use_headers = input_details["CSV"].get("FileHeaderInfo", "") == "USE"
+            use_headers = (input_details.get("CSV") or {}).get(
+                "FileHeaderInfo", ""
+            ) == "USE"
             query_input = csv_to_json(query_input, use_headers)
-        query_result = parse_query(query_input, select_query)
-        from py_partiql_parser import SelectEncoder
+        query_result = parse_query(query_input, select_query)  # type: ignore

-        return [
-            json.dumps(x, indent=None, separators=(",", ":"), cls=SelectEncoder).encode(
-                "utf-8"
-            )
-            for x in query_result
-        ]
+        record_delimiter = "\n"
+        if "JSON" in output_details:
+            record_delimiter = (output_details.get("JSON") or {}).get(
+                "RecordDelimiter"
+            ) or "\n"
+        elif "CSV" in output_details:
+            record_delimiter = (output_details.get("CSV") or {}).get(
+                "RecordDelimiter"
+            ) or "\n"
+
+        if "CSV" in output_details:
+            field_delim = (output_details.get("CSV") or {}).get("FieldDelimiter") or ","
+
+            from py_partiql_parser import json_to_csv
+
+            query_result = json_to_csv(query_result, field_delim, record_delimiter)
+            return [query_result.encode("utf-8")]  # type: ignore
+
+        else:
+            from py_partiql_parser import SelectEncoder
+
+            return [
+                (
+                    json.dumps(x, indent=None, separators=(",", ":"), cls=SelectEncoder)
+                    + record_delimiter
+                ).encode("utf-8")
+                for x in query_result
+            ]

     def restore_object(
         self, bucket_name: str, key_name: str, days: Optional[str], type_: Optional[str]
@@ -2291,9 +2291,9 @@ class S3Response(BaseResponse):
             input_details = request["InputSerialization"]
             output_details = request["OutputSerialization"]
             results = self.backend.select_object_content(
-                bucket_name, key_name, select_query, input_details
+                bucket_name, key_name, select_query, input_details, output_details
             )
-            return 200, {}, serialize_select(results, output_details)
+            return 200, {}, serialize_select(results)

         else:
             raise NotImplementedError(
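
The handler change above is plumbing: OutputSerialization, already parsed from the request body, is now forwarded to the backend, and serialize_select no longer needs it. For orientation, the relevant slice of a parsed SelectObjectContent request looks roughly like this (key names per the AWS API; the exact values are illustrative):

request = {
    "Expression": "SELECT * FROM S3Object",
    "ExpressionType": "SQL",
    "InputSerialization": {"CSV": {}, "CompressionType": "BZIP2"},
    "OutputSerialization": {"CSV": {"FieldDelimiter": ":", "RecordDelimiter": "\n"}},
}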
@@ -49,11 +49,8 @@ def _create_end_message() -> bytes:
     return _create_message(content_type=None, event_type=b"End", payload=b"")


-def serialize_select(data_list: List[bytes], output_details: Dict[str, Any]) -> bytes:
-    delimiter = (
-        (output_details.get("JSON") or {}).get("RecordDelimiter") or "\n"
-    ).encode("utf-8")
+def serialize_select(data_list: List[bytes]) -> bytes:
     response = b""
     for data in data_list:
-        response += _create_data_message(data + delimiter)
+        response += _create_data_message(data)
     return response + _create_stats_message() + _create_end_message()
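
With the record delimiter now baked into each record by the backend, serialize_select only frames the pre-delimited bytes as Records events, followed by the Stats and End events. On the client side, botocore exposes that stream as an iterable of event dicts; a small consumption sketch (bucket and key names are illustrative), matching what the tests below assert against:

import boto3

s3 = boto3.client("s3", "us-east-1")
content = s3.select_object_content(
    Bucket="my-bucket",  # illustrative
    Key="csv.bz2",       # illustrative
    Expression="SELECT * FROM S3Object",
    ExpressionType="SQL",
    InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
    OutputSerialization={"CSV": {}},
)

# Records events carry the (already delimited) result bytes; a Stats and
# an End event close the stream.
records = b"".join(
    event["Records"]["Payload"] for event in content["Payload"] if "Records" in event
)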
setup.cfg (18 changes)

@@ -54,7 +54,7 @@ all =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
     multipart
@@ -69,7 +69,7 @@ proxy =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
     multipart
@@ -84,7 +84,7 @@ server =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
     flask!=2.2.0,!=2.2.1
@@ -119,7 +119,7 @@ cloudformation =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     aws-xray-sdk!=0.96,>=0.93
     setuptools
 cloudfront =
@@ -141,10 +141,10 @@ dms =
 ds =
 dynamodb =
     docker>=3.0.0
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 dynamodbstreams =
     docker>=3.0.0
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 ebs =
 ec2 =
 ec2instanceconnect =
@@ -208,15 +208,15 @@ resourcegroupstaggingapi =
     openapi-spec-validator>=0.5.0
     pyparsing>=3.0.7
     jsondiff>=1.1.2
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 route53 =
 route53resolver =
 s3 =
     PyYAML>=5.1
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
 s3crc32c =
     PyYAML>=5.1
-    py-partiql-parser==0.5.1
+    py-partiql-parser==0.5.2
     crc32c
 s3control =
 sagemaker =
@@ -1,7 +1,10 @@
+import bz2
+import gzip
 import json
+
 import boto3
 import pytest
 from botocore.exceptions import ClientError

 from . import s3_aws_verified

@@ -45,6 +48,24 @@ def create_test_files(bucket_name):
         Key="nested.json",
         Body=json.dumps(NESTED_JSON),
     )
+    client.put_object(
+        Bucket=bucket_name,
+        Key="json.gzip",
+        Body=gzip.compress(json.dumps(NESTED_JSON).encode("utf-8")),
+    )
+    client.put_object(
+        Bucket=bucket_name,
+        Key="json.bz2",
+        Body=bz2.compress(json.dumps(NESTED_JSON).encode("utf-8")),
+    )
+    client.put_object(
+        Bucket=bucket_name,
+        Key="csv.gzip",
+        Body=gzip.compress(SIMPLE_CSV.encode("utf-8")),
+    )
+    client.put_object(
+        Bucket=bucket_name, Key="csv.bz2", Body=bz2.compress(SIMPLE_CSV.encode("utf-8"))
+    )


 @pytest.mark.aws_verified
@@ -226,3 +247,113 @@ def test_nested_json__select_all(bucket_name=None):
     assert records[-1] == ","

     assert json.loads(records[:-1]) == NESTED_JSON
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_gzipped_json(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    create_test_files(bucket_name)
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="json.gzip",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"JSON": {"Type": "DOCUMENT"}, "CompressionType": "GZIP"},
+        OutputSerialization={"JSON": {"RecordDelimiter": ","}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b'{"_1":1},'}} in result
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_bzipped_json(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    create_test_files(bucket_name)
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="json.bz2",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"JSON": {"Type": "DOCUMENT"}, "CompressionType": "BZIP2"},
+        OutputSerialization={"JSON": {"RecordDelimiter": ","}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b'{"_1":1},'}} in result
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_bzipped_csv_to_csv(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    create_test_files(bucket_name)
+
+    # Count Records
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {"RecordDelimiter": "_", "FieldDelimiter": ":"}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"4_"}} in result
+
+    # Count Records
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT count(*) FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"4\n"}} in result
+
+    # Mirror records
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT * FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"a,b,c\ne,r,f\ny,u,i\nq,w,y\n"}} in result
+
+    # Mirror records, specifying output format
+    content = client.select_object_content(
+        Bucket=bucket_name,
+        Key="csv.bz2",
+        Expression="SELECT * FROM S3Object",
+        ExpressionType="SQL",
+        InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+        OutputSerialization={"CSV": {"RecordDelimiter": "\n", "FieldDelimiter": ":"}},
+    )
+    result = list(content["Payload"])
+    assert {"Records": {"Payload": b"a:b:c\ne:r:f\ny:u:i\nq:w:y\n"}} in result
+
+
+@pytest.mark.aws_verified
+@s3_aws_verified
+def test_select_unknown_key(bucket_name=None):
+    client = boto3.client("s3", "us-east-1")
+    with pytest.raises(ClientError) as exc:
+        client.select_object_content(
+            Bucket=bucket_name,
+            Key="unknown",
+            Expression="SELECT count(*) FROM S3Object",
+            ExpressionType="SQL",
+            InputSerialization={"CSV": {}, "CompressionType": "BZIP2"},
+            OutputSerialization={
+                "CSV": {"RecordDelimiter": "\n", "FieldDelimiter": ":"}
+            },
+        )
+    err = exc.value.response["Error"]
+    assert err["Code"] == "NoSuchKey"
+    assert err["Message"] == "The specified key does not exist."
+    assert err["Key"] == "unknown"
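
The CSV expectations in test_bzipped_csv_to_csv follow directly from the json_to_csv call introduced in the backend: each parsed record becomes one row, fields joined by FieldDelimiter and rows terminated by RecordDelimiter. A hedged sketch of that mapping (the record shape is an assumption about parse_query output, not a documented py-partiql-parser contract):

from py_partiql_parser import json_to_csv

# Illustrative records, shaped like SELECT * over a headerless CSV
records = [{"_1": "a", "_2": "b", "_3": "c"}, {"_1": "e", "_2": "r", "_3": "f"}]

# With FieldDelimiter=":" and RecordDelimiter="\n" this should print
# "a:b:c" and "e:r:f" on separate lines, matching the assertions above
print(json_to_csv(records, ":", "\n"))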