From 7bdea2688b95cae35a3375374935c038f2a3b825 Mon Sep 17 00:00:00 2001 From: Viren Nadkarni Date: Wed, 31 May 2023 15:29:16 +0530 Subject: [PATCH] S3: Cross-account access for buckets (#6333) --- moto/s3/models.py | 40 +++++++++++-- moto/s3/responses.py | 22 ++++--- tests/test_s3/test_s3.py | 45 +++++++++++++- tests/test_s3/test_s3_file_handles.py | 85 ++++++++++++++++----------- 4 files changed, 143 insertions(+), 49 deletions(-) diff --git a/moto/s3/models.py b/moto/s3/models.py index ae744e532..5a456f4ae 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -937,6 +937,7 @@ class FakeBucket(CloudFormationModel): self.default_lock_days: Optional[int] = 0 self.default_lock_years: Optional[int] = 0 self.ownership_rule: Optional[Dict[str, Any]] = None + s3_backends.bucket_accounts[name] = account_id @property def location(self) -> str: @@ -1494,6 +1495,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): key.dispose() for part in bucket.multiparts.values(): part.dispose() + s3_backends.bucket_accounts.pop(bucket.name, None) # # Second, go through the list of instances # It may contain FakeKeys created earlier, which are no longer tracked @@ -1614,7 +1616,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): return metrics def create_bucket(self, bucket_name: str, region_name: str) -> FakeBucket: - if bucket_name in self.buckets: + if bucket_name in s3_backends.bucket_accounts.keys(): raise BucketAlreadyExists(bucket=bucket_name) if not MIN_BUCKET_NAME_LENGTH <= len(bucket_name) <= MAX_BUCKET_NAME_LENGTH: raise InvalidBucketName() @@ -1646,10 +1648,14 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): return list(self.buckets.values()) def get_bucket(self, bucket_name: str) -> FakeBucket: - try: + if bucket_name in self.buckets: return self.buckets[bucket_name] - except KeyError: - raise MissingBucket(bucket=bucket_name) + + if bucket_name in s3_backends.bucket_accounts: + account_id = s3_backends.bucket_accounts[bucket_name] + return s3_backends[account_id]["global"].get_bucket(bucket_name) + + raise MissingBucket(bucket=bucket_name) def head_bucket(self, bucket_name: str) -> FakeBucket: return self.get_bucket(bucket_name) @@ -1660,6 +1666,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): # Can't delete a bucket with keys return None else: + s3_backends.bucket_accounts.pop(bucket_name, None) return self.buckets.pop(bucket_name) def put_bucket_versioning(self, bucket_name: str, status: str) -> None: @@ -1957,6 +1964,7 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): if not key_is_clean: key_name = clean_key_name(key_name) bucket = self.get_bucket(bucket_name) + key = None if bucket: @@ -2497,6 +2505,28 @@ class S3Backend(BaseBackend, CloudWatchMetricProvider): ] -s3_backends = BackendDict( +class S3BackendDict(BackendDict): + """ + Encapsulation class to hold S3 backends. + + This is specialised to include additional attributes to help multi-account support in S3 + but is otherwise identical to the superclass. + """ + + def __init__( + self, + backend: Any, + service_name: str, + use_boto3_regions: bool = True, + additional_regions: Optional[List[str]] = None, + ): + super().__init__(backend, service_name, use_boto3_regions, additional_regions) + + # Maps bucket names to account IDs. This is used to locate the exact S3Backend + # holding the bucket and to maintain the common bucket namespace. + self.bucket_accounts: dict[str, str] = {} + + +s3_backends = S3BackendDict( S3Backend, service_name="s3", use_boto3_regions=False, additional_regions=["global"] ) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index 49b8be34b..c263039e7 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -909,15 +909,19 @@ class S3Response(BaseResponse): new_bucket = self.backend.create_bucket(bucket_name, region_name) except BucketAlreadyExists: new_bucket = self.backend.get_bucket(bucket_name) - if ( - new_bucket.region_name == DEFAULT_REGION_NAME - and region_name == DEFAULT_REGION_NAME - ): - # us-east-1 has different behavior - creating a bucket there is an idempotent operation - pass + if new_bucket.account_id == self.get_current_account(): + # special cases when the bucket belongs to self + if ( + new_bucket.region_name == DEFAULT_REGION_NAME + and region_name == DEFAULT_REGION_NAME + ): + # us-east-1 has different behavior - creating a bucket there is an idempotent operation + pass + else: + template = self.response_template(S3_DUPLICATE_BUCKET_ERROR) + return 409, {}, template.render(bucket_name=bucket_name) else: - template = self.response_template(S3_DUPLICATE_BUCKET_ERROR) - return 409, {}, template.render(bucket_name=bucket_name) + raise if "x-amz-acl" in request.headers: # TODO: Support the XML-based ACL format @@ -1519,7 +1523,7 @@ class S3Response(BaseResponse): acl = self._acl_from_headers(request.headers) if acl is None: - acl = self.backend.get_bucket(bucket_name).acl + acl = bucket.acl tagging = self._tagging_from_headers(request.headers) if "versionId" in query: diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 3931c274c..73fd49e24 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -17,7 +17,7 @@ import requests from moto.moto_api import state_manager from moto.s3.responses import DEFAULT_REGION_NAME -from unittest import SkipTest +from unittest import SkipTest, mock import pytest import sure # noqa # pylint: disable=unused-import @@ -3377,3 +3377,46 @@ def test_checksum_response(algorithm): ChecksumAlgorithm=algorithm, ) assert f"Checksum{algorithm}" in response + + +@mock_s3 +def test_cross_account_region_access(): + if settings.TEST_SERVER_MODE: + raise SkipTest("Multi-accounts env config only works serverside") + + client1 = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + client2 = boto3.client("s3", region_name=DEFAULT_REGION_NAME) + + account2 = "222222222222" + bucket_name = "cross-account-bucket" + key = "test-key" + + # Create a bucket in the default account + client1.create_bucket(Bucket=bucket_name) + client1.put_object(Bucket=bucket_name, Key=key, Body=b"data") + + with mock.patch.dict(os.environ, {"MOTO_ACCOUNT_ID": account2}): + # Ensure the bucket can be retrieved from another account + response = client2.list_objects(Bucket=bucket_name) + response.should.have.key("Contents").length_of(1) + response["Contents"][0]["Key"].should.equal(key) + + assert client2.get_object(Bucket=bucket_name, Key=key) + + assert client2.put_object(Bucket=bucket_name, Key=key, Body=b"kaytranada") + + # Ensure bucket namespace is shared across accounts + with pytest.raises(ClientError) as exc: + client2.create_bucket(Bucket=bucket_name) + exc.value.response["Error"]["Code"].should.equal("BucketAlreadyExists") + exc.value.response["Error"]["Message"].should.equal( + "The requested bucket name is not available. The bucket " + "namespace is shared by all users of the system. Please " + "select a different name and try again" + ) + + # Ensure bucket name can be reused if it is deleted + client1.delete_object(Bucket=bucket_name, Key=key) + client1.delete_bucket(Bucket=bucket_name) + with mock.patch.dict(os.environ, {"MOTO_ACCOUNT_ID": account2}): + assert client2.create_bucket(Bucket=bucket_name) diff --git a/tests/test_s3/test_s3_file_handles.py b/tests/test_s3/test_s3_file_handles.py index 70602145d..bd601b82b 100644 --- a/tests/test_s3/test_s3_file_handles.py +++ b/tests/test_s3/test_s3_file_handles.py @@ -5,10 +5,17 @@ import warnings from functools import wraps from moto import settings, mock_s3 from moto.dynamodb.models import DynamoDBBackend -from moto.s3 import models as s3model +from moto.s3 import models as s3model, s3_backends from moto.s3.responses import S3ResponseInstance from unittest import SkipTest, TestCase +from tests import DEFAULT_ACCOUNT_ID + + +TEST_BUCKET = "my-bucket" +TEST_BUCKET_VERSIONED = "versioned-bucket" +TEST_KEY = "my-key" + def verify_zero_warnings(f): @wraps(f) @@ -39,10 +46,20 @@ class TestS3FileHandleClosures(TestCase): def setUp(self) -> None: if settings.TEST_SERVER_MODE: raise SkipTest("No point in testing ServerMode, we're not using boto3") - self.s3 = s3model.S3Backend("us-west-1", "1234") - self.s3.create_bucket("my-bucket", "us-west-1") - self.s3.create_bucket("versioned-bucket", "us-west-1") - self.s3.put_object("my-bucket", "my-key", "x" * 10_000_000) + self.s3 = s3_backends[DEFAULT_ACCOUNT_ID]["global"] + self.s3.create_bucket(TEST_BUCKET, "us-west-1") + self.s3.create_bucket(TEST_BUCKET_VERSIONED, "us-west-1") + self.s3.put_object(TEST_BUCKET, TEST_KEY, "x" * 10_000_000) + + def tearDown(self) -> None: + for bucket_name in ( + TEST_BUCKET, + TEST_BUCKET_VERSIONED, + ): + keys = list(self.s3.get_bucket(bucket_name).keys.keys()) + for key in keys: + self.s3.delete_object(bucket_name, key) + self.s3.delete_bucket(bucket_name) @verify_zero_warnings def test_upload_large_file(self): @@ -52,28 +69,28 @@ class TestS3FileHandleClosures(TestCase): @verify_zero_warnings def test_delete_large_file(self): - self.s3.delete_object(bucket_name="my-bucket", key_name="my-key") + self.s3.delete_object(bucket_name=TEST_BUCKET, key_name=TEST_KEY) @verify_zero_warnings def test_overwriting_file(self): - self.s3.put_object("my-bucket", "my-key", "b" * 10_000_000) + self.s3.put_object(TEST_BUCKET, TEST_KEY, "b" * 10_000_000) @verify_zero_warnings def test_versioned_file(self): - self.s3.put_bucket_versioning("my-bucket", "Enabled") - self.s3.put_object("my-bucket", "my-key", "b" * 10_000_000) + self.s3.put_bucket_versioning(TEST_BUCKET, "Enabled") + self.s3.put_object(TEST_BUCKET, TEST_KEY, "b" * 10_000_000) @verify_zero_warnings def test_copy_object(self): - key = self.s3.get_object("my-bucket", "my-key") + key = self.s3.get_object(TEST_BUCKET, TEST_KEY) self.s3.copy_object( - src_key=key, dest_bucket_name="my-bucket", dest_key_name="key-2" + src_key=key, dest_bucket_name=TEST_BUCKET, dest_key_name="key-2" ) @verify_zero_warnings def test_part_upload(self): multipart_id = self.s3.create_multipart_upload( - bucket_name="my-bucket", + bucket_name=TEST_BUCKET, key_name="mp-key", metadata={}, storage_type="STANDARD", @@ -83,7 +100,7 @@ class TestS3FileHandleClosures(TestCase): kms_key_id=None, ) self.s3.upload_part( - bucket_name="my-bucket", + bucket_name=TEST_BUCKET, multipart_id=multipart_id, part_id=1, value="b" * 10_000_000, @@ -92,7 +109,7 @@ class TestS3FileHandleClosures(TestCase): @verify_zero_warnings def test_overwriting_part_upload(self): multipart_id = self.s3.create_multipart_upload( - bucket_name="my-bucket", + bucket_name=TEST_BUCKET, key_name="mp-key", metadata={}, storage_type="STANDARD", @@ -102,13 +119,13 @@ class TestS3FileHandleClosures(TestCase): kms_key_id=None, ) self.s3.upload_part( - bucket_name="my-bucket", + bucket_name=TEST_BUCKET, multipart_id=multipart_id, part_id=1, value="b" * 10_000_000, ) self.s3.upload_part( - bucket_name="my-bucket", + bucket_name=TEST_BUCKET, multipart_id=multipart_id, part_id=1, value="c" * 10_000_000, @@ -117,7 +134,7 @@ class TestS3FileHandleClosures(TestCase): @verify_zero_warnings def test_aborting_part_upload(self): multipart_id = self.s3.create_multipart_upload( - bucket_name="my-bucket", + bucket_name=TEST_BUCKET, key_name="mp-key", metadata={}, storage_type="STANDARD", @@ -127,19 +144,19 @@ class TestS3FileHandleClosures(TestCase): kms_key_id=None, ) self.s3.upload_part( - bucket_name="my-bucket", + bucket_name=TEST_BUCKET, multipart_id=multipart_id, part_id=1, value="b" * 10_000_000, ) self.s3.abort_multipart_upload( - bucket_name="my-bucket", multipart_id=multipart_id + bucket_name=TEST_BUCKET, multipart_id=multipart_id ) @verify_zero_warnings def test_completing_part_upload(self): multipart_id = self.s3.create_multipart_upload( - bucket_name="my-bucket", + bucket_name=TEST_BUCKET, key_name="mp-key", metadata={}, storage_type="STANDARD", @@ -149,7 +166,7 @@ class TestS3FileHandleClosures(TestCase): kms_key_id=None, ) etag = self.s3.upload_part( - bucket_name="my-bucket", + bucket_name=TEST_BUCKET, multipart_id=multipart_id, part_id=1, value="b" * 10_000_000, @@ -158,36 +175,36 @@ class TestS3FileHandleClosures(TestCase): mp_body = f"""{etag}1""" body = S3ResponseInstance._complete_multipart_body(mp_body) self.s3.complete_multipart_upload( - bucket_name="my-bucket", multipart_id=multipart_id, body=body + bucket_name=TEST_BUCKET, multipart_id=multipart_id, body=body ) @verify_zero_warnings def test_single_versioned_upload(self): - self.s3.put_object("versioned-bucket", "my-key", "x" * 10_000_000) + self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "x" * 10_000_000) @verify_zero_warnings def test_overwrite_versioned_upload(self): - self.s3.put_object("versioned-bucket", "my-key", "x" * 10_000_000) - self.s3.put_object("versioned-bucket", "my-key", "x" * 10_000_000) + self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "x" * 10_000_000) + self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "x" * 10_000_000) @verify_zero_warnings def test_multiple_versions_upload(self): - self.s3.put_object("versioned-bucket", "my-key", "x" * 10_000_000) - self.s3.put_object("versioned-bucket", "my-key", "y" * 10_000_000) - self.s3.put_object("versioned-bucket", "my-key", "z" * 10_000_000) + self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "x" * 10_000_000) + self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "y" * 10_000_000) + self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "z" * 10_000_000) @verify_zero_warnings def test_delete_versioned_upload(self): - self.s3.put_object("versioned-bucket", "my-key", "x" * 10_000_000) - self.s3.put_object("versioned-bucket", "my-key", "x" * 10_000_000) - self.s3.delete_object(bucket_name="my-bucket", key_name="my-key") + self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "x" * 10_000_000) + self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "x" * 10_000_000) + self.s3.delete_object(bucket_name=TEST_BUCKET, key_name=TEST_KEY) @verify_zero_warnings def test_delete_specific_version(self): - self.s3.put_object("versioned-bucket", "my-key", "x" * 10_000_000) - key = self.s3.put_object("versioned-bucket", "my-key", "y" * 10_000_000) + self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "x" * 10_000_000) + key = self.s3.put_object(TEST_BUCKET_VERSIONED, TEST_KEY, "y" * 10_000_000) self.s3.delete_object( - bucket_name="my-bucket", key_name="my-key", version_id=key._version_id + bucket_name=TEST_BUCKET, key_name=TEST_KEY, version_id=key._version_id ) @verify_zero_warnings