diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index 9454dc123..c9882c70f 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -5567,7 +5567,7 @@ ## s3
-63% implemented +64% implemented - [X] abort_multipart_upload - [X] complete_multipart_upload @@ -5662,7 +5662,7 @@ - [ ] put_object_tagging - [ ] put_public_access_block - [ ] restore_object -- [ ] select_object_content +- [X] select_object_content - [X] upload_part - [ ] upload_part_copy - [ ] write_get_object_response diff --git a/docs/docs/services/s3.rst b/docs/docs/services/s3.rst index c05e20bb1..a0a7920d5 100644 --- a/docs/docs/services/s3.rst +++ b/docs/docs/services/s3.rst @@ -144,7 +144,17 @@ s3 - [ ] put_object_tagging - [ ] put_public_access_block - [ ] restore_object -- [ ] select_object_content +- [X] select_object_content + + Highly experimental. Please raise an issue if you find any inconsistencies/bugs. + + Known missing features: + - Function aliases (count(*) as cnt) + - Most functions (only count() is supported) + - Result is always in JSON + - FieldDelimiters and RecordDelimiters are ignored + + - [X] upload_part - [ ] upload_part_copy - [ ] write_get_object_response diff --git a/moto/s3/models.py b/moto/s3/models.py index 31efe31ad..a7918f3d5 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -52,6 +52,7 @@ from moto.s3.exceptions import ( ) from .cloud_formation import cfn_to_api_encryption, is_replacement_update from . import notifications +from .select_object_content import parse_query, csv_to_json from .utils import clean_key_name, _VersionedKeyStore, undo_clean_key_name from .utils import ARCHIVE_STORAGE_CLASSES, STORAGE_CLASS from ..events.notifications import send_notification as events_send_notification @@ -2308,6 +2309,35 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): bucket = self.get_bucket(bucket_name) return bucket.notification_configuration + def select_object_content( + self, + bucket_name: str, + key_name: str, + select_query: str, + input_details: Dict[str, Any], + output_details: Dict[str, Any], # pylint: disable=unused-argument + ): + """ + Highly experimental. 
Please raise an issue if you find any inconsistencies/bugs. + + Known missing features: + - Function aliases (count(*) as cnt) + - Most functions (only count() is supported) + - Result is always in JSON + - FieldDelimiters and RecordDelimiters are ignored + """ + self.get_bucket(bucket_name) + key = self.get_object(bucket_name, key_name) + query_input = key.value.decode("utf-8") + if "CSV" in input_details: + # input is in CSV - we need to convert it to JSON before parsing + use_headers = input_details["CSV"].get("FileHeaderInfo", "") == "USE" + query_input = csv_to_json(query_input, use_headers) + return [ + json.dumps(x, indent=None, separators=(",", ":")).encode("utf-8") + for x in parse_query(query_input, select_query) + ] + s3_backends = BackendDict( S3Backend, service_name="s3", use_boto3_regions=False, additional_regions=["global"] diff --git a/moto/s3/responses.py b/moto/s3/responses.py index c9f93393b..62ba97692 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -54,6 +54,7 @@ from .exceptions import ( ) from .models import s3_backends, S3Backend from .models import get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey +from .select_object_content import serialize_select from .utils import ( bucket_name_from_url, metadata_from_headers, @@ -134,6 +135,7 @@ ACTION_MAP = { "uploads": "PutObject", "restore": "RestoreObject", "uploadId": "PutObject", + "select": "SelectObject", }, }, "CONTROL": { @@ -2120,6 +2122,15 @@ class S3Response(BaseResponse): r = 200 key.restore(int(days)) return r, {}, "" + elif "select" in query: + request = xmltodict.parse(body)["SelectObjectContentRequest"] + select_query = request["Expression"] + input_details = request["InputSerialization"] + output_details = request["OutputSerialization"] + results = self.backend.select_object_content( + bucket_name, key_name, select_query, input_details, output_details + ) + return 200, {}, serialize_select(results) else: raise NotImplementedError( diff --git 
a/moto/s3/select_object_content.py b/moto/s3/select_object_content.py new file mode 100644 index 000000000..9a8964941 --- /dev/null +++ b/moto/s3/select_object_content.py @@ -0,0 +1,56 @@ +import binascii +import struct +from typing import List +from py_partiql_parser import Parser +from py_partiql_parser._internal.csv_converter import ( # noqa # pylint: disable=unused-import + csv_to_json, +) + + +def parse_query(text_input, query): + return Parser(source_data={"s3object": text_input}).parse(query) + + +def _create_header(key: bytes, value: bytes): + return struct.pack("b", len(key)) + key + struct.pack("!bh", 7, len(value)) + value + + +def _create_message(content_type, event_type, payload): + headers = _create_header(b":message-type", b"event") + headers += _create_header(b":event-type", event_type) + if content_type is not None: + headers += _create_header(b":content-type", content_type) + + headers_length = struct.pack("!I", len(headers)) + total_length = struct.pack("!I", len(payload) + len(headers) + 16) + prelude = total_length + headers_length + + prelude_crc = struct.pack("!I", binascii.crc32(total_length + headers_length)) + message_crc = struct.pack( + "!I", binascii.crc32(prelude + prelude_crc + headers + payload) + ) + + return prelude + prelude_crc + headers + payload + message_crc + + +def _create_stats_message(): + stats = b"""<Stats><BytesScanned>24</BytesScanned><BytesProcessed>24</BytesProcessed><BytesReturned>22</BytesReturned></Stats>""" + return _create_message(content_type=b"text/xml", event_type=b"Stats", payload=stats) + + +def _create_data_message(payload: bytes): + # https://docs.aws.amazon.com/AmazonS3/latest/API/RESTSelectObjectAppendix.html + return _create_message( + content_type=b"application/octet-stream", event_type=b"Records", payload=payload + ) + + +def _create_end_message(): + return _create_message(content_type=None, event_type=b"End", payload=b"") + + +def serialize_select(data_list: List[bytes]): + response = b"" + for data in data_list: + response += _create_data_message(data + b",") + return response + _create_stats_message() + 
_create_end_message() diff --git a/setup.cfg b/setup.cfg index 6f2cb44c0..0b75c56c0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,6 +52,7 @@ all = openapi-spec-validator>=0.2.8 pyparsing>=3.0.7 jsondiff>=1.1.2 + py-partiql-parser==0.1.0 aws-xray-sdk!=0.96,>=0.93 setuptools server = @@ -65,6 +66,7 @@ server = openapi-spec-validator>=0.2.8 pyparsing>=3.0.7 jsondiff>=1.1.2 + py-partiql-parser==0.1.0 aws-xray-sdk!=0.96,>=0.93 setuptools flask!=2.2.0,!=2.2.1 @@ -98,6 +100,7 @@ cloudformation = openapi-spec-validator>=0.2.8 pyparsing>=3.0.7 jsondiff>=1.1.2 + py-partiql-parser==0.1.0 aws-xray-sdk!=0.96,>=0.93 setuptools cloudfront = @@ -174,7 +177,9 @@ resourcegroups = resourcegroupstaggingapi = route53 = route53resolver = sshpubkeys>=3.1.0 -s3 = PyYAML>=5.1 +s3 = + PyYAML>=5.1 + py-partiql-parser==0.1.0 s3control = sagemaker = sdb = diff --git a/tests/test_s3/test_s3_select.py b/tests/test_s3/test_s3_select.py new file mode 100644 index 000000000..aea6cc9eb --- /dev/null +++ b/tests/test_s3/test_s3_select.py @@ -0,0 +1,114 @@ +import boto3 +import json +import pytest +from moto import mock_s3 +from unittest import TestCase +from uuid import uuid4 + + +SIMPLE_JSON = {"a1": "b1", "a2": "b2"} +SIMPLE_JSON2 = {"a1": "b2", "a3": "b3"} +SIMPLE_LIST = [SIMPLE_JSON, SIMPLE_JSON2] +SIMPLE_CSV = """a,b,c +e,r,f +y,u,i +q,w,y""" + + +@mock_s3 +class TestS3Select(TestCase): + def setUp(self) -> None: + self.client = boto3.client("s3", "us-east-1") + self.bucket_name = str(uuid4()) + self.client.create_bucket(Bucket=self.bucket_name) + self.client.put_object( + Bucket=self.bucket_name, Key="simple.json", Body=json.dumps(SIMPLE_JSON) + ) + self.client.put_object( + Bucket=self.bucket_name, Key="list.json", Body=json.dumps(SIMPLE_LIST) + ) + self.client.put_object( + Bucket=self.bucket_name, Key="simple_csv", Body=SIMPLE_CSV + ) + + def tearDown(self) -> None: + self.client.delete_object(Bucket=self.bucket_name, Key="list.json") + self.client.delete_object(Bucket=self.bucket_name, 
Key="simple.json") + self.client.delete_object(Bucket=self.bucket_name, Key="simple_csv") + self.client.delete_bucket(Bucket=self.bucket_name) + + def test_query_all(self): + x = self.client.select_object_content( + Bucket=self.bucket_name, + Key="simple.json", + Expression="SELECT * FROM S3Object", + ExpressionType="SQL", + InputSerialization={"JSON": {"Type": "DOCUMENT"}}, + OutputSerialization={"JSON": {"RecordDelimiter": ","}}, + ) + result = list(x["Payload"]) + result.should.contain({"Records": {"Payload": b'{"a1":"b1","a2":"b2"},'}}) + result.should.contain( + { + "Stats": { + "Details": { + "BytesScanned": 24, + "BytesProcessed": 24, + "BytesReturned": 22, + } + } + } + ) + result.should.contain({"End": {}}) + + def test_count_function(self): + x = self.client.select_object_content( + Bucket=self.bucket_name, + Key="simple.json", + Expression="SELECT count(*) FROM S3Object", + ExpressionType="SQL", + InputSerialization={"JSON": {"Type": "DOCUMENT"}}, + OutputSerialization={"JSON": {"RecordDelimiter": ","}}, + ) + result = list(x["Payload"]) + result.should.contain({"Records": {"Payload": b'{"_1":1},'}}) + + @pytest.mark.xfail(message="Not yet implemented in our parser") + def test_count_as(self): + x = self.client.select_object_content( + Bucket=self.bucket_name, + Key="simple.json", + Expression="SELECT count(*) as cnt FROM S3Object", + ExpressionType="SQL", + InputSerialization={"JSON": {"Type": "DOCUMENT"}}, + OutputSerialization={"JSON": {"RecordDelimiter": ","}}, + ) + result = list(x["Payload"]) + result.should.contain({"Records": {"Payload": b'{"cnt":1},'}}) + + @pytest.mark.xfail(message="Not yet implemented in our parser") + def test_count_list_as(self): + x = self.client.select_object_content( + Bucket=self.bucket_name, + Key="list.json", + Expression="SELECT count(*) as cnt FROM S3Object", + ExpressionType="SQL", + InputSerialization={"JSON": {"Type": "DOCUMENT"}}, + OutputSerialization={"JSON": {"RecordDelimiter": ","}}, + ) + result = 
list(x["Payload"]) + result.should.contain({"Records": {"Payload": b'{"cnt":1},'}}) + + def test_count_csv(self): + x = self.client.select_object_content( + Bucket=self.bucket_name, + Key="simple_csv", + Expression="SELECT count(*) FROM S3Object", + ExpressionType="SQL", + InputSerialization={ + "CSV": {"FileHeaderInfo": "USE", "FieldDelimiter": ","} + }, + OutputSerialization={"JSON": {"RecordDelimiter": ","}}, + ) + result = list(x["Payload"]) + result.should.contain({"Records": {"Payload": b'{"_1":3},'}})