SNS: Cross-account access for topics (#6330)

This commit is contained in:
Viren Nadkarni 2023-05-21 16:44:38 +05:30 committed by GitHub
parent b853593259
commit 8ca9c17e5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 52 deletions

View File

@ -4,7 +4,6 @@ import re
import json import json
import sys import sys
import warnings import warnings
from collections import namedtuple
from datetime import datetime from datetime import datetime
from enum import Enum, unique from enum import Enum, unique
from json import JSONDecodeError from json import JSONDecodeError
@ -27,6 +26,7 @@ from moto.events.exceptions import (
IllegalStatusException, IllegalStatusException,
) )
from moto.moto_api._internal import mock_random as random from moto.moto_api._internal import mock_random as random
from moto.utilities.arns import parse_arn
from moto.utilities.paginator import paginate from moto.utilities.paginator import paginate
from moto.utilities.tagging_service import TaggingService from moto.utilities.tagging_service import TaggingService
@ -37,10 +37,6 @@ UNDEFINED = object()
class Rule(CloudFormationModel): class Rule(CloudFormationModel):
Arn = namedtuple(
"Arn", ["account", "region", "service", "resource_type", "resource_id"]
)
def __init__( def __init__(
self, self,
name: str, name: str,
@ -123,13 +119,13 @@ class Rule(CloudFormationModel):
# - SQS Queue + FIFO Queue # - SQS Queue + FIFO Queue
# - Cross-region/account EventBus # - Cross-region/account EventBus
for target in self.targets: for target in self.targets:
arn = self._parse_arn(target["Arn"]) arn = parse_arn(target["Arn"])
if arn.service == "logs" and arn.resource_type == "log-group": if arn.service == "logs" and arn.resource_type == "log-group":
self._send_to_cw_log_group(arn.resource_id, event) self._send_to_cw_log_group(arn.resource_id, event)
elif arn.service == "events" and not arn.resource_type: elif arn.service == "events" and not arn.resource_type:
input_template = json.loads(target["InputTransformer"]["InputTemplate"]) input_template = json.loads(target["InputTransformer"]["InputTemplate"])
archive_arn = self._parse_arn(input_template["archive-arn"]) archive_arn = parse_arn(input_template["archive-arn"])
self._send_to_events_archive(archive_arn.resource_id, event) self._send_to_events_archive(archive_arn.resource_id, event)
elif arn.service == "sqs": elif arn.service == "sqs":
@ -149,33 +145,6 @@ class Rule(CloudFormationModel):
else: else:
raise NotImplementedError(f"Expr not defined for {type(self)}") raise NotImplementedError(f"Expr not defined for {type(self)}")
def _parse_arn(self, arn: str) -> Arn:
# http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html
# this method needs probably some more fine tuning,
# when also other targets are supported
_, _, service, region, account, resource = arn.split(":", 5)
if ":" in resource and "/" in resource:
if resource.index(":") < resource.index("/"):
resource_type, resource_id = resource.split(":", 1)
else:
resource_type, resource_id = resource.split("/", 1)
elif ":" in resource:
resource_type, resource_id = resource.split(":", 1)
elif "/" in resource:
resource_type, resource_id = resource.split("/", 1)
else:
resource_type = None
resource_id = resource
return self.Arn(
account=account,
region=region,
service=service,
resource_type=resource_type,
resource_id=resource_id,
)
def _send_to_cw_log_group(self, name: str, event: Dict[str, Any]) -> None: def _send_to_cw_log_group(self, name: str, event: Dict[str, Any]) -> None:
from moto.logs import logs_backends from moto.logs import logs_backends

View File

@ -34,7 +34,7 @@ from .utils import (
is_e164, is_e164,
FilterPolicyMatcher, FilterPolicyMatcher,
) )
from moto.utilities.arns import parse_arn
DEFAULT_PAGE_SIZE = 100 DEFAULT_PAGE_SIZE = 100
MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB
@ -486,13 +486,15 @@ class SNSBackend(BaseBackend):
try: try:
topic = self.get_topic(arn) topic = self.get_topic(arn)
self.delete_topic_subscriptions(topic) self.delete_topic_subscriptions(topic)
self.topics.pop(arn) parsed_arn = parse_arn(arn)
sns_backends[parsed_arn.account][parsed_arn.region].topics.pop(arn, None)
except KeyError: except KeyError:
raise SNSNotFoundError(f"Topic with arn {arn} not found") raise SNSNotFoundError(f"Topic with arn {arn} not found")
def get_topic(self, arn: str) -> Topic: def get_topic(self, arn: str) -> Topic:
parsed_arn = parse_arn(arn)
try: try:
return self.topics[arn] return sns_backends[parsed_arn.account][parsed_arn.region].topics[arn]
except KeyError: except KeyError:
raise SNSNotFoundError(f"Topic with arn {arn} not found") raise SNSNotFoundError(f"Topic with arn {arn} not found")
@ -932,10 +934,8 @@ class SNSBackend(BaseBackend):
aws_account_ids: List[str], aws_account_ids: List[str],
action_names: List[str], action_names: List[str],
) -> None: ) -> None:
if topic_arn not in self.topics: topic = self.get_topic(topic_arn)
raise SNSNotFoundError("Topic does not exist") policy = topic._policy_json
policy = self.topics[topic_arn]._policy_json
statement = next( statement = next(
( (
statement statement
@ -964,18 +964,16 @@ class SNSBackend(BaseBackend):
"Resource": topic_arn, "Resource": topic_arn,
} }
self.topics[topic_arn]._policy_json["Statement"].append(statement) topic._policy_json["Statement"].append(statement)
def remove_permission(self, topic_arn: str, label: str) -> None: def remove_permission(self, topic_arn: str, label: str) -> None:
if topic_arn not in self.topics: topic = self.get_topic(topic_arn)
raise SNSNotFoundError("Topic does not exist") statements = topic._policy_json["Statement"]
statements = self.topics[topic_arn]._policy_json["Statement"]
statements = [ statements = [
statement for statement in statements if statement["Sid"] != label statement for statement in statements if statement["Sid"] != label
] ]
self.topics[topic_arn]._policy_json["Statement"] = statements topic._policy_json["Statement"] = statements
def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]: def list_tags_for_resource(self, resource_arn: str) -> Dict[str, str]:
if resource_arn not in self.topics: if resource_arn not in self.topics:

33
moto/utilities/arns.py Normal file
View File

@ -0,0 +1,33 @@
from collections import namedtuple
Arn = namedtuple(
"Arn", ["account", "region", "service", "resource_type", "resource_id"]
)
def parse_arn(arn: str) -> Arn:
# http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html
# this method needs probably some more fine tuning,
# when also other targets are supported
_, _, service, region, account, resource = arn.split(":", 5)
if ":" in resource and "/" in resource:
if resource.index(":") < resource.index("/"):
resource_type, resource_id = resource.split(":", 1)
else:
resource_type, resource_id = resource.split("/", 1)
elif ":" in resource:
resource_type, resource_id = resource.split(":", 1)
elif "/" in resource:
resource_type, resource_id = resource.split("/", 1)
else:
resource_type = None
resource_id = resource
return Arn(
account=account,
region=region,
service=service,
resource_type=resource_type,
resource_id=resource_id,
)

View File

@ -96,9 +96,9 @@ def test_create_topic_should_be_indempodent():
@mock_sns @mock_sns
def test_get_missing_topic(): def test_get_missing_topic():
conn = boto3.client("sns", region_name="us-east-1") conn = boto3.client("sns", region_name="us-east-1")
conn.get_topic_attributes.when.called_with(TopicArn="a-fake-arn").should.throw( conn.get_topic_attributes.when.called_with(
ClientError TopicArn="arn:aws:sns:us-east-1:424242424242:a-fake-arn"
) ).should.throw(ClientError)
@mock_sns @mock_sns
@ -360,7 +360,10 @@ def test_add_permission_errors():
Label="test-2", Label="test-2",
AWSAccountId=["999999999999"], AWSAccountId=["999999999999"],
ActionName=["AddPermission"], ActionName=["AddPermission"],
).should.throw(ClientError, "Topic does not exist") ).should.throw(
ClientError,
f"An error occurred (NotFound) when calling the AddPermission operation: Topic with arn {topic_arn + '-not-existing'} not found",
)
client.add_permission.when.called_with( client.add_permission.when.called_with(
TopicArn=topic_arn, TopicArn=topic_arn,
@ -383,7 +386,10 @@ def test_remove_permission_errors():
client.remove_permission.when.called_with( client.remove_permission.when.called_with(
TopicArn=topic_arn + "-not-existing", Label="test" TopicArn=topic_arn + "-not-existing", Label="test"
).should.throw(ClientError, "Topic does not exist") ).should.throw(
ClientError,
f"An error occurred (NotFound) when calling the RemovePermission operation: Topic with arn {topic_arn + '-not-existing'} not found",
)
@mock_sns @mock_sns