S3: list_objects_v2() should have a hashed NextContinuationToken (#7187)

Bert Blommers 2024-01-14 13:02:33 +00:00
parent 1f1e0caca3
commit 5aa3cc9d73
3 changed files with 91 additions and 46 deletions
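
A note on the client-visible effect: with this change, list_objects_v2() returns an opaque, hashed NextContinuationToken instead of echoing the last key name, so callers must pass it back verbatim via ContinuationToken. A minimal sketch of that flow against moto (bucket and key names are illustrative; the decorator is mock_aws on moto 5.x, mock_s3 on older releases):

    import boto3
    from moto import mock_aws  # use mock_s3 on moto 4.x

    @mock_aws
    def paginate_with_hashed_token() -> None:
        s3 = boto3.client("s3", region_name="us-east-1")
        s3.create_bucket(Bucket="mybucket")  # illustrative bucket name
        for i in range(5):
            s3.put_object(Bucket="mybucket", Key=f"key_{i}", Body=b"data")

        token, seen = None, []
        while True:
            kwargs = {"Bucket": "mybucket", "MaxKeys": 2}
            if token:
                kwargs["ContinuationToken"] = token  # opaque md5 hex digest, not a key name
            resp = s3.list_objects_v2(**kwargs)
            seen += [obj["Key"] for obj in resp.get("Contents", [])]
            if not resp["IsTruncated"]:
                break
            token = resp["NextContinuationToken"]
        assert seen == [f"key_{i}" for i in range(5)]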

View File

@@ -1613,6 +1613,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         super().__init__(region_name, account_id)
         self.buckets: Dict[str, FakeBucket] = {}
         self.tagger = TaggingService()
+        self._pagination_tokens: Dict[str, str] = {}

     def reset(self) -> None:
         # For every key and multipart, Moto opens a TemporaryFile to write the value of those keys
@@ -2442,8 +2443,13 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
         return multipart.set_part(part_id, src_value)

     def list_objects(
-        self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str]
-    ) -> Tuple[Set[FakeKey], Set[str]]:
+        self,
+        bucket: FakeBucket,
+        prefix: Optional[str],
+        delimiter: Optional[str],
+        marker: Optional[str],
+        max_keys: int,
+    ) -> Tuple[Set[FakeKey], Set[str], bool, Optional[str]]:
         key_results = set()
         folder_results = set()
         if prefix:
@@ -2474,16 +2480,70 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider):
             folder_name for folder_name in sorted(folder_results, key=lambda key: key)
         ]

-        return key_results, folder_results
+        if marker:
+            limit = self._pagination_tokens.get(marker)
+            key_results = self._get_results_from_token(key_results, limit)
+
+        key_results, is_truncated, next_marker = self._truncate_result(
+            key_results, max_keys
+        )
+
+        return key_results, folder_results, is_truncated, next_marker

     def list_objects_v2(
-        self, bucket: FakeBucket, prefix: Optional[str], delimiter: Optional[str]
-    ) -> Set[Union[FakeKey, str]]:
-        result_keys, result_folders = self.list_objects(bucket, prefix, delimiter)
+        self,
+        bucket: FakeBucket,
+        prefix: Optional[str],
+        delimiter: Optional[str],
+        continuation_token: Optional[str],
+        start_after: Optional[str],
+        max_keys: int,
+    ) -> Tuple[Set[Union[FakeKey, str]], bool, Optional[str]]:
+        result_keys, result_folders, _, _ = self.list_objects(
+            bucket, prefix, delimiter, marker=None, max_keys=1000
+        )

         # sort the combination of folders and keys into lexicographical order
         all_keys = result_keys + result_folders  # type: ignore
         all_keys.sort(key=self._get_name)

-        return all_keys
+        if continuation_token or start_after:
+            limit = (
+                self._pagination_tokens.get(continuation_token)
+                if continuation_token
+                else start_after
+            )
+            all_keys = self._get_results_from_token(all_keys, limit)
+
+        truncated_keys, is_truncated, next_continuation_token = self._truncate_result(
+            all_keys, max_keys
+        )
+
+        return truncated_keys, is_truncated, next_continuation_token
+
+    def _get_results_from_token(self, result_keys: Any, token: Any) -> Any:
+        continuation_index = 0
+        for key in result_keys:
+            if (key.name if isinstance(key, FakeKey) else key) > token:
+                break
+            continuation_index += 1
+        return result_keys[continuation_index:]
+
+    def _truncate_result(self, result_keys: Any, max_keys: int) -> Any:
+        if max_keys == 0:
+            result_keys = []
+            is_truncated = True
+            next_continuation_token = None
+        elif len(result_keys) > max_keys:
+            is_truncated = "true"  # type: ignore
+            result_keys = result_keys[:max_keys]
+            item = result_keys[-1]
+            key_id = item.name if isinstance(item, FakeKey) else item
+            next_continuation_token = md5_hash(key_id.encode("utf-8")).hexdigest()
+            self._pagination_tokens[next_continuation_token] = key_id
+        else:
+            is_truncated = "false"  # type: ignore
+            next_continuation_token = None
+        return result_keys, is_truncated, next_continuation_token

     @staticmethod
     def _get_name(key: Union[str, FakeKey]) -> str:

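The hashed token round-trips through the backend's _pagination_tokens dict: _truncate_result() hashes the last key returned and stores hash -> key name, and _get_results_from_token() later resumes after that key. A standalone sketch of that round trip, using hashlib directly in place of moto's md5_hash helper and plain strings in place of FakeKey objects:

    import hashlib
    from typing import Dict, List, Optional, Tuple

    _pagination_tokens: Dict[str, str] = {}  # token -> last key name, as in S3Backend

    def truncate(keys: List[str], max_keys: int) -> Tuple[List[str], bool, Optional[str]]:
        # Mirrors _truncate_result: hash the last returned key and remember the mapping.
        if len(keys) <= max_keys:
            return keys, False, None
        page = keys[:max_keys]
        token = hashlib.md5(page[-1].encode("utf-8")).hexdigest()
        _pagination_tokens[token] = page[-1]
        return page, True, token

    def resume(keys: List[str], token: str) -> List[str]:
        # Mirrors _get_results_from_token: keep only keys after the one the token points at.
        last_key = _pagination_tokens[token]
        return [k for k in keys if k > last_key]

    keys = [f"key_{i}" for i in range(5)]
    page1, truncated, token = truncate(keys, 2)     # ['key_0', 'key_1'], True, <md5 hex digest>
    page2, _, _ = truncate(resume(keys, token), 2)  # ['key_2', 'key_3']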
View File

@@ -690,16 +690,19 @@ class S3Response(BaseResponse):
         delimiter = querystring.get("delimiter", [None])[0]
         max_keys = int(querystring.get("max-keys", [1000])[0])
         marker = querystring.get("marker", [None])[0]
-        result_keys, result_folders = self.backend.list_objects(
-            bucket, prefix, delimiter
-        )
         encoding_type = querystring.get("encoding-type", [None])[0]

-        if marker:
-            result_keys = self._get_results_from_token(result_keys, marker)
-
-        result_keys, is_truncated, next_marker = self._truncate_result(
-            result_keys, max_keys
+        (
+            result_keys,
+            result_folders,
+            is_truncated,
+            next_marker,
+        ) = self.backend.list_objects(
+            bucket=bucket,
+            prefix=prefix,
+            delimiter=delimiter,
+            marker=marker,
+            max_keys=max_keys,
         )

         template = self.response_template(S3_BUCKET_GET_RESPONSE)
@@ -746,20 +749,25 @@
         if prefix and isinstance(prefix, bytes):
             prefix = prefix.decode("utf-8")
         delimiter = querystring.get("delimiter", [None])[0]
-        all_keys = self.backend.list_objects_v2(bucket, prefix, delimiter)

         fetch_owner = querystring.get("fetch-owner", [False])[0]
         max_keys = int(querystring.get("max-keys", [1000])[0])
         start_after = querystring.get("start-after", [None])[0]
         encoding_type = querystring.get("encoding-type", [None])[0]

-        if continuation_token or start_after:
-            limit = continuation_token or start_after
-            all_keys = self._get_results_from_token(all_keys, limit)
-
-        truncated_keys, is_truncated, next_continuation_token = self._truncate_result(
-            all_keys, max_keys
+        (
+            truncated_keys,
+            is_truncated,
+            next_continuation_token,
+        ) = self.backend.list_objects_v2(
+            bucket=bucket,
+            prefix=prefix,
+            delimiter=delimiter,
+            continuation_token=continuation_token,
+            start_after=start_after,
+            max_keys=max_keys,
         )

         result_keys, result_folders = self._split_truncated_keys(truncated_keys)
         key_count = len(result_keys) + len(result_folders)
@@ -796,29 +804,6 @@
                 result_folders.append(key)
         return result_keys, result_folders

-    def _get_results_from_token(self, result_keys: Any, token: Any) -> Any:
-        continuation_index = 0
-        for key in result_keys:
-            if (key.name if isinstance(key, FakeKey) else key) > token:
-                break
-            continuation_index += 1
-        return result_keys[continuation_index:]
-
-    def _truncate_result(self, result_keys: Any, max_keys: int) -> Any:
-        if max_keys == 0:
-            result_keys = []
-            is_truncated = True
-            next_continuation_token = None
-        elif len(result_keys) > max_keys:
-            is_truncated = "true"  # type: ignore
-            result_keys = result_keys[:max_keys]
-            item = result_keys[-1]
-            next_continuation_token = item.name if isinstance(item, FakeKey) else item
-        else:
-            is_truncated = "false"  # type: ignore
-            next_continuation_token = None
-        return result_keys, is_truncated, next_continuation_token
-
     def _body_contains_location_constraint(self, body: bytes) -> bool:
         if body:
             try:

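On the wire, the parameters read by these handlers come straight from the ListObjectsV2 query string (list-type=2, prefix, delimiter, max-keys, start-after, continuation-token). A sketch of the corresponding raw request, assuming a standalone moto server is listening on localhost:5000 and the bucket already exists (endpoint, port, and bucket name are illustrative):

    import requests

    resp = requests.get(
        "http://localhost:5000/mybucket",  # path-style addressing against the moto server
        params={
            "list-type": "2",        # selects the ListObjectsV2 handler
            "prefix": "",
            "delimiter": "/",
            "max-keys": "2",         # parsed into max_keys above
            "start-after": "key_0",  # parsed into start_after
            # "continuation-token": "<NextContinuationToken from a previous page>",
        },
    )
    print(resp.status_code, resp.text[:200])  # XML ListBucketResult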
View File

@@ -1441,7 +1441,7 @@ def test_list_objects_v2_truncate_combined_keys_and_folders():
     assert len(resp["CommonPrefixes"]) == 1
     assert resp["CommonPrefixes"][0]["Prefix"] == "1/"

-    last_tail = resp["NextContinuationToken"]
+    last_tail = resp["Contents"][-1]["Key"]
     resp = s3_client.list_objects_v2(
         Bucket="mybucket", MaxKeys=2, Prefix="", Delimiter="/", StartAfter=last_tail
     )
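
The test change reflects the new opacity of the token: StartAfter expects a real key name, so the test now derives it from the last Contents entry instead of reusing NextContinuationToken. A short sketch of both continuation styles, assuming an s3_client and bucket set up as in the test and a truncated page that contains at least one key:

    resp = s3_client.list_objects_v2(Bucket="mybucket", MaxKeys=2, Prefix="", Delimiter="/")
    if resp["IsTruncated"]:
        # StartAfter takes a key name, so use the last key actually returned ...
        by_start_after = s3_client.list_objects_v2(
            Bucket="mybucket", MaxKeys=2, StartAfter=resp["Contents"][-1]["Key"]
        )
        # ... while ContinuationToken accepts the opaque hashed token as-is.
        by_token = s3_client.list_objects_v2(
            Bucket="mybucket", MaxKeys=2, ContinuationToken=resp["NextContinuationToken"]
        )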