diff --git a/moto/sqs/exceptions.py b/moto/sqs/exceptions.py index 01123d777..77d7b9fb2 100644 --- a/moto/sqs/exceptions.py +++ b/moto/sqs/exceptions.py @@ -99,3 +99,28 @@ class InvalidAttributeName(RESTError): super(InvalidAttributeName, self).__init__( "InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name) ) + + +class InvalidParameterValue(RESTError): + code = 400 + + def __init__(self, message): + super(InvalidParameterValue, self).__init__("InvalidParameterValue", message) + + +class MissingParameter(RESTError): + code = 400 + + def __init__(self): + super(MissingParameter, self).__init__( + "MissingParameter", "The request must contain the parameter Actions." + ) + + +class OverLimit(RESTError): + code = 403 + + def __init__(self, count): + super(OverLimit, self).__init__( + "OverLimit", "{} Actions were found, maximum allowed is 7.".format(count) + ) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index 8b8263e3c..8fbe90108 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -30,6 +30,9 @@ from .exceptions import ( BatchEntryIdsNotDistinct, TooManyEntriesInBatchRequest, InvalidAttributeName, + InvalidParameterValue, + MissingParameter, + OverLimit, ) from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID @@ -183,6 +186,7 @@ class Queue(BaseModel): "MaximumMessageSize", "MessageRetentionPeriod", "QueueArn", + "Policy", "RedrivePolicy", "ReceiveMessageWaitTimeSeconds", "VisibilityTimeout", @@ -195,6 +199,8 @@ class Queue(BaseModel): "DeleteMessage", "GetQueueAttributes", "GetQueueUrl", + "ListDeadLetterSourceQueues", + "PurgeQueue", "ReceiveMessage", "SendMessage", ) @@ -273,7 +279,7 @@ class Queue(BaseModel): if key in bool_fields: value = value == "true" - if key == "RedrivePolicy" and value is not None: + if key in ["Policy", "RedrivePolicy"] and value is not None: continue setattr(self, camelcase_to_underscores(key), value) @@ -281,6 +287,9 @@ class Queue(BaseModel): if attributes.get("RedrivePolicy", None): self._setup_dlq(attributes["RedrivePolicy"]) + if attributes.get("Policy"): + self.policy = attributes["Policy"] + self.last_modified_timestamp = now def _setup_dlq(self, policy): @@ -472,6 +481,24 @@ class Queue(BaseModel): return self.name raise UnformattedGetAttTemplateException() + @property + def policy(self): + if self._policy_json.get("Statement"): + return json.dumps(self._policy_json) + else: + return None + + @policy.setter + def policy(self, policy): + if policy: + self._policy_json = json.loads(policy) + else: + self._policy_json = { + "Version": "2012-10-17", + "Id": "{}/SQSDefaultPolicy".format(self.queue_arn), + "Statement": [], + } + class SQSBackend(BaseBackend): def __init__(self, region_name): @@ -802,25 +829,75 @@ class SQSBackend(BaseBackend): def add_permission(self, queue_name, actions, account_ids, label): queue = self.get_queue(queue_name) - if actions is None or len(actions) == 0: - raise RESTError("InvalidParameterValue", "Need at least one Action") - if account_ids is None or len(account_ids) == 0: - raise RESTError("InvalidParameterValue", "Need at least one Account ID") + if not actions: + raise MissingParameter() - if not all([item in Queue.ALLOWED_PERMISSIONS for item in actions]): - raise RESTError("InvalidParameterValue", "Invalid permissions") + if not account_ids: + raise InvalidParameterValue( + "Value [] for parameter PrincipalId is invalid. Reason: Unable to verify." + ) - queue.permissions[label] = (account_ids, actions) + count = len(actions) + if count > 7: + raise OverLimit(count) + + invalid_action = next( + (action for action in actions if action not in Queue.ALLOWED_PERMISSIONS), + None, + ) + if invalid_action: + raise InvalidParameterValue( + "Value SQS:{} for parameter ActionName is invalid. " + "Reason: Only the queue owner is allowed to invoke this action.".format( + invalid_action + ) + ) + + policy = queue._policy_json + statement = next( + ( + statement + for statement in policy["Statement"] + if statement["Sid"] == label + ), + None, + ) + if statement: + raise InvalidParameterValue( + "Value {} for parameter Label is invalid. " + "Reason: Already exists.".format(label) + ) + + principals = [ + "arn:aws:iam::{}:root".format(account_id) for account_id in account_ids + ] + actions = ["SQS:{}".format(action) for action in actions] + + statement = { + "Sid": label, + "Effect": "Allow", + "Principal": {"AWS": principals[0] if len(principals) == 1 else principals}, + "Action": actions[0] if len(actions) == 1 else actions, + "Resource": queue.queue_arn, + } + + queue._policy_json["Statement"].append(statement) def remove_permission(self, queue_name, label): queue = self.get_queue(queue_name) - if label not in queue.permissions: - raise RESTError( - "InvalidParameterValue", "Permission doesnt exist for the given label" + statements = queue._policy_json["Statement"] + statements_new = [ + statement for statement in statements if statement["Sid"] != label + ] + + if len(statements) == len(statements_new): + raise InvalidParameterValue( + "Value {} for parameter Label is invalid. " + "Reason: can't find label on existing policy.".format(label) ) - del queue.permissions[label] + queue._policy_json["Statement"] = statements_new def tag_queue(self, queue_name, tags): queue = self.get_queue(queue_name) diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index 1eb511db0..93d388117 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -132,6 +132,35 @@ def test_create_queue_with_tags(): ) +@mock_sqs +def test_create_queue_with_policy(): + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue( + QueueName="test-queue", + Attributes={ + "Policy": json.dumps( + { + "Version": "2012-10-17", + "Id": "test", + "Statement": [{"Effect": "Allow", "Principal": "*", "Action": "*"}], + } + ) + }, + ) + queue_url = response["QueueUrl"] + + response = client.get_queue_attributes( + QueueUrl=queue_url, AttributeNames=["Policy"] + ) + json.loads(response["Attributes"]["Policy"]).should.equal( + { + "Version": "2012-10-17", + "Id": "test", + "Statement": [{"Effect": "Allow", "Principal": "*", "Action": "*"}], + } + ) + + @mock_sqs def test_get_queue_url(): client = boto3.client("sqs", region_name="us-east-1") @@ -1186,18 +1215,169 @@ def test_permissions(): Actions=["SendMessage"], ) - with assert_raises(ClientError): - client.add_permission( - QueueUrl=queue_url, - Label="account2", - AWSAccountIds=["222211111111"], - Actions=["SomeRubbish"], - ) + response = client.get_queue_attributes( + QueueUrl=queue_url, AttributeNames=["Policy"] + ) + policy = json.loads(response["Attributes"]["Policy"]) + policy["Version"].should.equal("2012-10-17") + policy["Id"].should.equal( + "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo/SQSDefaultPolicy" + ) + sorted(policy["Statement"], key=lambda x: x["Sid"]).should.equal( + [ + { + "Sid": "account1", + "Effect": "Allow", + "Principal": {"AWS": "arn:aws:iam::111111111111:root"}, + "Action": "SQS:*", + "Resource": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo", + }, + { + "Sid": "account2", + "Effect": "Allow", + "Principal": {"AWS": "arn:aws:iam::222211111111:root"}, + "Action": "SQS:SendMessage", + "Resource": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo", + }, + ] + ) client.remove_permission(QueueUrl=queue_url, Label="account2") - with assert_raises(ClientError): - client.remove_permission(QueueUrl=queue_url, Label="non_existent") + response = client.get_queue_attributes( + QueueUrl=queue_url, AttributeNames=["Policy"] + ) + json.loads(response["Attributes"]["Policy"]).should.equal( + { + "Version": "2012-10-17", + "Id": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo/SQSDefaultPolicy", + "Statement": [ + { + "Sid": "account1", + "Effect": "Allow", + "Principal": {"AWS": "arn:aws:iam::111111111111:root"}, + "Action": "SQS:*", + "Resource": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo", + }, + ], + } + ) + + +@mock_sqs +def test_add_permission_errors(): + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] + client.add_permission( + QueueUrl=queue_url, + Label="test", + AWSAccountIds=["111111111111"], + Actions=["ReceiveMessage"], + ) + + with assert_raises(ClientError) as e: + client.add_permission( + QueueUrl=queue_url, + Label="test", + AWSAccountIds=["111111111111"], + Actions=["ReceiveMessage", "SendMessage"], + ) + ex = e.exception + ex.operation_name.should.equal("AddPermission") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidParameterValue") + ex.response["Error"]["Message"].should.equal( + "Value test for parameter Label is invalid. " "Reason: Already exists." + ) + + with assert_raises(ClientError) as e: + client.add_permission( + QueueUrl=queue_url, + Label="test-2", + AWSAccountIds=["111111111111"], + Actions=["RemovePermission"], + ) + ex = e.exception + ex.operation_name.should.equal("AddPermission") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidParameterValue") + ex.response["Error"]["Message"].should.equal( + "Value SQS:RemovePermission for parameter ActionName is invalid. " + "Reason: Only the queue owner is allowed to invoke this action." + ) + + with assert_raises(ClientError) as e: + client.add_permission( + QueueUrl=queue_url, + Label="test-2", + AWSAccountIds=["111111111111"], + Actions=[], + ) + ex = e.exception + ex.operation_name.should.equal("AddPermission") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("MissingParameter") + ex.response["Error"]["Message"].should.equal( + "The request must contain the parameter Actions." + ) + + with assert_raises(ClientError) as e: + client.add_permission( + QueueUrl=queue_url, + Label="test-2", + AWSAccountIds=[], + Actions=["ReceiveMessage"], + ) + ex = e.exception + ex.operation_name.should.equal("AddPermission") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidParameterValue") + ex.response["Error"]["Message"].should.equal( + "Value [] for parameter PrincipalId is invalid. Reason: Unable to verify." + ) + + with assert_raises(ClientError) as e: + client.add_permission( + QueueUrl=queue_url, + Label="test-2", + AWSAccountIds=["111111111111"], + Actions=[ + "ChangeMessageVisibility", + "DeleteMessage", + "GetQueueAttributes", + "GetQueueUrl", + "ListDeadLetterSourceQueues", + "PurgeQueue", + "ReceiveMessage", + "SendMessage", + ], + ) + ex = e.exception + ex.operation_name.should.equal("AddPermission") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.response["Error"]["Code"].should.contain("OverLimit") + ex.response["Error"]["Message"].should.equal( + "8 Actions were found, maximum allowed is 7." + ) + + +@mock_sqs +def test_remove_permission_errors(): + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] + + with assert_raises(ClientError) as e: + client.remove_permission(QueueUrl=queue_url, Label="test") + ex = e.exception + ex.operation_name.should.equal("RemovePermission") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidParameterValue") + ex.response["Error"]["Message"].should.equal( + "Value test for parameter Label is invalid. " + "Reason: can't find label on existing policy." + ) @mock_sqs