diff --git a/moto/s3/models.py b/moto/s3/models.py
index a65f278b3..3837e7e59 100644
--- a/moto/s3/models.py
+++ b/moto/s3/models.py
@@ -1613,6 +1613,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         super().__init__(region_name, account_id)
         self.buckets: Dict[str, FakeBucket] = {}
         self.tagger = TaggingService()
+        self._pagination_tokens: Dict[str, str] = {}
 
     def reset(self) -> None:
         # For every key and multipart, Moto opens a TemporaryFile to write the value of those keys
@@ -2442,8 +2443,13 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         return multipart.set_part(part_id, src_value)
 
     def list_objects(
-        self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str]
-    ) -> Tuple[Set[FakeKey], Set[str]]:
+        self,
+        bucket: FakeBucket,
+        prefix: Optional[str],
+        delimiter: Optional[str],
+        marker: Optional[str],
+        max_keys: Optional[int],
+    ) -> Tuple[Set[FakeKey], Set[str], bool, Optional[str]]:
         key_results = set()
         folder_results = set()
         if prefix:
@@ -2474,16 +2480,89 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
             folder_name for folder_name in sorted(folder_results, key=lambda key: key)
         ]
 
-        return key_results, folder_results
+        if marker:
+            # NextMarker values handed out by _truncate_result are hashed tokens,
+            # but clients may also pass a plain key name; fall back to the raw value.
+            limit = self._pagination_tokens.get(marker, marker)
+            key_results = self._get_results_from_token(key_results, limit)
+
+        key_results, is_truncated, next_marker = self._truncate_result(
+            key_results, max_keys
+        )
+
+        return key_results, folder_results, is_truncated, next_marker
 
     def list_objects_v2(
-        self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str]
-    ) -> Set[Union[FakeKey, str]]:
-        result_keys, result_folders = self.list_objects(bucket, prefix, delimiter)
+        self,
+        bucket: FakeBucket,
+        prefix: Optional[str],
+        delimiter: Optional[str],
+        continuation_token: Optional[str],
+        start_after: Optional[str],
+        max_keys: int,
+    ) -> Tuple[Set[Union[FakeKey, str]], bool, Optional[str]]:
+        """
+        List keys and common prefixes together, in lexicographical order.
+
+        Returns the page of results, the truncation flag, and the token to pass
+        as ``continuation_token`` for the next page (None on the last page).
+        """
+        # Fetch everything (max_keys=None): pagination has to be applied to the
+        # combined, sorted list of keys and folders, not to the keys alone.
+        result_keys, result_folders, _, _ = self.list_objects(
+            bucket, prefix, delimiter, marker=None, max_keys=None
+        )
         # sort the combination of folders and keys into lexicographical order
         all_keys = result_keys + result_folders  # type: ignore
         all_keys.sort(key=self._get_name)
-        return all_keys
+
+        if continuation_token or start_after:
+            # Unknown tokens fall back to being treated as a key name rather than
+            # resolving to None, which cannot be compared against key names.
+            limit = (
+                self._pagination_tokens.get(continuation_token, continuation_token)
+                if continuation_token
+                else start_after
+            )
+            all_keys = self._get_results_from_token(all_keys, limit)
+
+        truncated_keys, is_truncated, next_continuation_token = self._truncate_result(
+            all_keys, max_keys
+        )
+
+        return truncated_keys, is_truncated, next_continuation_token
+
+    def _get_results_from_token(self, result_keys: Any, token: Any) -> Any:
+        """Drop leading entries of the sorted result_keys not greater than token."""
+        continuation_index = 0
+        for key in result_keys:
+            if (key.name if isinstance(key, FakeKey) else key) > token:
+                break
+            continuation_index += 1
+        return result_keys[continuation_index:]
+
+    def _truncate_result(self, result_keys: Any, max_keys: Optional[int]) -> Any:
+        """
+        Cut result_keys down to max_keys entries; None means no limit.
+
+        When results are cut off, the name of the last returned entry is stored
+        under its md5_hash hex digest, and that digest is returned as next token.
+        """
+        if max_keys == 0:
+            result_keys = []
+            is_truncated = True
+            next_continuation_token = None
+        elif isinstance(max_keys, int) and len(result_keys) > max_keys:
+            is_truncated = "true"  # type: ignore
+            result_keys = result_keys[:max_keys]
+            item = result_keys[-1]
+            key_id = item.name if isinstance(item, FakeKey) else item
+            next_continuation_token = md5_hash(key_id.encode("utf-8")).hexdigest()
+            self._pagination_tokens[next_continuation_token] = key_id
+        else:
+            is_truncated = "false"  # type: ignore
+            next_continuation_token = None
+        return result_keys, is_truncated, next_continuation_token
 
     @staticmethod
     def _get_name(key: Union[str, FakeKey]) -> str:
diff --git a/moto/s3/responses.py b/moto/s3/responses.py
index 1b189521f..bacea5426 100644
--- a/moto/s3/responses.py
+++ b/moto/s3/responses.py
@@ -690,16 +690,19 @@ class S3Response(BaseResponse):
         delimiter = querystring.get("delimiter", [None])[0]
         max_keys = int(querystring.get("max-keys", [1000])[0])
         marker = querystring.get("marker", [None])[0]
-        result_keys, result_folders = self.backend.list_objects(
-            bucket, prefix, delimiter
-        )
         encoding_type = querystring.get("encoding-type", [None])[0]
 
-        if marker:
-            result_keys = self._get_results_from_token(result_keys, marker)
-
-        result_keys, is_truncated, next_marker = self._truncate_result(
-            result_keys, max_keys
+        (
+            result_keys,
+            result_folders,
+            is_truncated,
+            next_marker,
+        ) = self.backend.list_objects(
+            bucket=bucket,
+            prefix=prefix,
+            delimiter=delimiter,
+            marker=marker,
+            max_keys=max_keys,
         )
 
         template = self.response_template(S3_BUCKET_GET_RESPONSE)
@@ -746,20 +749,25 @@ class S3Response(BaseResponse):
         if prefix and isinstance(prefix, bytes):
             prefix = prefix.decode("utf-8")
         delimiter = querystring.get("delimiter", [None])[0]
-        all_keys = self.backend.list_objects_v2(bucket, prefix, delimiter)
 
         fetch_owner = querystring.get("fetch-owner", [False])[0]
         max_keys = int(querystring.get("max-keys", [1000])[0])
         start_after = querystring.get("start-after", [None])[0]
         encoding_type = querystring.get("encoding-type", [None])[0]
 
-        if continuation_token or start_after:
-            limit = continuation_token or start_after
-            all_keys = self._get_results_from_token(all_keys, limit)
-
-        truncated_keys, is_truncated, next_continuation_token = self._truncate_result(
-            all_keys, max_keys
+        (
+            truncated_keys,
+            is_truncated,
+            next_continuation_token,
+        ) = self.backend.list_objects_v2(
+            bucket=bucket,
+            prefix=prefix,
+            delimiter=delimiter,
+            continuation_token=continuation_token,
+            start_after=start_after,
+            max_keys=max_keys,
         )
+
         result_keys, result_folders = self._split_truncated_keys(truncated_keys)
 
         key_count = len(result_keys) + len(result_folders)
@@ -796,29 +804,6 @@ class S3Response(BaseResponse):
                 result_folders.append(key)
         return result_keys, result_folders
 
-    def _get_results_from_token(self, result_keys: Any, token: Any) -> Any:
-        continuation_index = 0
-        for key in result_keys:
-            if (key.name if isinstance(key, FakeKey) else key) > token:
-                break
-            continuation_index += 1
-        return result_keys[continuation_index:]
-
-    def _truncate_result(self, result_keys: Any, max_keys: int) -> Any:
-        if max_keys == 0:
-            result_keys = []
-            is_truncated = True
-            next_continuation_token = None
-        elif len(result_keys) > max_keys:
-            is_truncated = "true"  # type: ignore
-            result_keys = result_keys[:max_keys]
-            item = result_keys[-1]
-            next_continuation_token = item.name if isinstance(item, FakeKey) else item
-        else:
-            is_truncated = "false"  # type: ignore
-            next_continuation_token = None
-        return result_keys, is_truncated, next_continuation_token
-
     def _body_contains_location_constraint(self, body: bytes) -> bool:
         if body:
             try:
diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py
index 550dd720f..bf8510863 100644
--- a/tests/test_s3/test_s3.py
+++ b/tests/test_s3/test_s3.py
@@ -1441,7 +1441,7 @@ def test_list_objects_v2_truncate_combined_keys_and_folders():
     assert len(resp["CommonPrefixes"]) == 1
     assert resp["CommonPrefixes"][0]["Prefix"] == "1/"
 
-    last_tail = resp["NextContinuationToken"]
+    last_tail = resp["Contents"][-1]["Key"]
     resp = s3_client.list_objects_v2(
         Bucket="mybucket", MaxKeys=2, Prefix="", Delimiter="/", StartAfter=last_tail
     )