diff --git a/moto/sqs/models.py b/moto/sqs/models.py index 72218826b..2784ee625 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -526,6 +526,14 @@ class Queue(CloudFormationModel): } +def _filter_message_attributes(message, input_message_attributes): + filtered_message_attributes = {} + for key, value in message.message_attributes.items(): + if key in input_message_attributes: + filtered_message_attributes[key] = value + message.message_attributes = filtered_message_attributes + + class SQSBackend(BaseBackend): def __init__(self, region_name): self.region_name = region_name @@ -718,7 +726,12 @@ class SQSBackend(BaseBackend): return None def receive_messages( - self, queue_name, count, wait_seconds_timeout, visibility_timeout + self, + queue_name, + count, + wait_seconds_timeout, + visibility_timeout, + message_attribute_names=None, ): """ Attempt to retrieve visible messages from a queue. @@ -734,6 +747,8 @@ class SQSBackend(BaseBackend): :param int wait_seconds_timeout: The duration (in seconds) for which the call waits for a message to arrive in the queue before returning. If a message is available, the call returns sooner than WaitTimeSeconds """ + if message_attribute_names is None: + message_attribute_names = [] queue = self.get_queue(queue_name) result = [] previous_result_count = len(result) @@ -775,6 +790,7 @@ class SQSBackend(BaseBackend): continue message.mark_received(visibility_timeout=visibility_timeout) + _filter_message_attributes(message, message_attribute_names) result.append(message) if len(result) >= count: break diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index e28fbca8a..016637b4c 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -13,7 +13,7 @@ from .exceptions import ( ReceiptHandleIsInvalid, ) from .models import sqs_backends -from .utils import parse_message_attributes +from .utils import parse_message_attributes, extract_input_message_attributes MAXIMUM_VISIBILTY_TIMEOUT = 43200 MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB @@ -352,6 +352,9 @@ class SQSResponse(BaseResponse): def receive_message(self): queue_name = self._get_queue_name() + message_attributes = self._get_multi_param("message_attributes") + if not message_attributes: + message_attributes = extract_input_message_attributes(self.querystring,) queue = self.sqs_backend.get_queue(queue_name) @@ -391,7 +394,7 @@ class SQSResponse(BaseResponse): return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400) messages = self.sqs_backend.receive_messages( - queue_name, message_count, wait_time, visibility_timeout + queue_name, message_count, wait_time, visibility_timeout, message_attributes ) template = self.response_template(RECEIVE_MESSAGE_RESPONSE) return template.render(messages=messages) diff --git a/moto/sqs/utils.py b/moto/sqs/utils.py index 315fce56b..876d6b40e 100644 --- a/moto/sqs/utils.py +++ b/moto/sqs/utils.py @@ -11,6 +11,21 @@ def generate_receipt_handle(): return "".join(random.choice(string.ascii_lowercase) for x in range(length)) +def extract_input_message_attributes(querystring): + message_attributes = [] + index = 1 + while True: + # Loop through looking for message attributes + name_key = "MessageAttributeName.{0}".format(index) + name = querystring.get(name_key) + if not name: + # Found all attributes + break + message_attributes.append(name[0]) + index = index + 1 + return message_attributes + + def parse_message_attributes(querystring, base="", value_namespace="Value."): message_attributes = {} index = 1 diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index c84f19694..63c409302 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -183,9 +183,6 @@ def test_publish_to_sqs_msg_attr_byte_value(): message = queue_raw.receive_messages()[0] message.body.should.equal("my message") - message.message_attributes.should.equal( - {"store": {"DataType": "Binary", "BinaryValue": b"\x02\x03\x04"}} - ) @mock_sqs @@ -216,9 +213,6 @@ def test_publish_to_sqs_msg_attr_number_type(): message = queue_raw.receive_messages()[0] message.body.should.equal("test message") - message.message_attributes.should.equal( - {"retries": {"DataType": "Number", "StringValue": "0"}} - ) @mock_sns diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index 48fa20291..f98131db4 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -593,9 +593,9 @@ def test_send_receive_message_with_attributes(): }, ) - messages = conn.receive_message(QueueUrl=queue.url, MaxNumberOfMessages=2)[ - "Messages" - ] + messages = conn.receive_message( + QueueUrl=queue.url, MaxNumberOfMessages=2, MessageAttributeNames=["timestamp"] + )["Messages"] message1 = messages[0] message2 = messages[1] @@ -641,9 +641,9 @@ def test_send_receive_message_with_attributes_with_labels(): }, ) - messages = conn.receive_message(QueueUrl=queue.url, MaxNumberOfMessages=2)[ - "Messages" - ] + messages = conn.receive_message( + QueueUrl=queue.url, MaxNumberOfMessages=2, MessageAttributeNames=["timestamp"] + )["Messages"] message1 = messages[0] message2 = messages[1] @@ -779,7 +779,14 @@ def test_send_message_with_attributes(): queue.write(message) - messages = conn.receive_message(queue) + messages = conn.receive_message( + queue, + message_attributes=[ + "test.attribute_name", + "test.binary_attribute", + "test.number_attribute", + ], + ) messages[0].get_body().should.equal(body) @@ -999,7 +1006,7 @@ def test_send_batch_operation_with_message_attributes(): ) queue.write_batch([message_tuple]) - messages = queue.get_messages() + messages = queue.get_messages(message_attributes=["name1"]) messages[0].get_body().should.equal("test message 1") for name, value in message_tuple[3].items(): @@ -1234,7 +1241,11 @@ def test_send_message_batch(): ["id_1", "id_2"] ) - response = client.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=10) + response = client.receive_message( + QueueUrl=queue_url, + MaxNumberOfMessages=10, + MessageAttributeNames=["attribute_name_1", "attribute_name_2"], + ) response["Messages"][0]["Body"].should.equal("body_1") response["Messages"][0]["MessageAttributes"].should.equal( @@ -1258,6 +1269,53 @@ def test_send_message_batch(): ) +@mock_sqs +def test_message_attributes_in_receive_message(): + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") + conn.create_queue(QueueName="test-queue") + queue = sqs.Queue("test-queue") + + body_one = "this is a test message" + + queue.send_message( + MessageBody=body_one, + MessageAttributes={ + "timestamp": { + "StringValue": "1493147359900", + "DataType": "Number.java.lang.Long", + } + }, + ) + messages = conn.receive_message( + QueueUrl=queue.url, MaxNumberOfMessages=2, MessageAttributeNames=["timestamp"] + )["Messages"] + + messages[0]["MessageAttributes"].should.equal( + { + "timestamp": { + "StringValue": "1493147359900", + "DataType": "Number.java.lang.Long", + } + } + ) + + queue.send_message( + MessageBody=body_one, + MessageAttributes={ + "timestamp": { + "StringValue": "1493147359900", + "DataType": "Number.java.lang.Long", + } + }, + ) + messages = conn.receive_message(QueueUrl=queue.url, MaxNumberOfMessages=2)[ + "Messages" + ] + + messages[0].get("MessageAttributes").should.equal(None) + + @mock_sqs def test_send_message_batch_errors(): client = boto3.client("sqs", region_name="us-east-1")