SQS: Support JSON format (#6331)

This commit is contained in:
Bert Blommers 2023-05-24 21:19:40 +00:00 committed by GitHub
parent a28dd1abc8
commit 42b03d36ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 456 additions and 123 deletions

39
.github/workflows/test_sqs.yml vendored Normal file
View File

@ -0,0 +1,39 @@
# Run separate test cases to verify SQS works with multiple botocore versions (/multiple response-types, QUERY and JSON)
#
name: "SQS Tests"
on:
pull_request:
types: [ labeled ]
jobs:
test:
if: ${{ github.event.label.name == 'service-sqs' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [ "3.11" ]
botocore-version: ["1.29.126", "1.29.127", "1.29.128"]
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Update pip
run: |
python -m pip install --upgrade pip
- name: Install project dependencies
run: |
pip install -r requirements-dev.txt
pip install botocore==${{ matrix.botocore-version }}
- name: Run tests
run: |
pytest -sv tests/test_sqs

View File

@ -84,6 +84,11 @@ class RESTError(HTTPException):
) -> str: ) -> str:
return self.description # type: ignore[return-value] return self.description # type: ignore[return-value]
def to_json(self) -> "JsonRESTError":
err = JsonRESTError(error_type=self.error_type, message=self.message)
err.code = self.code
return err
class DryRunClientError(RESTError): class DryRunClientError(RESTError):
code = 412 code = 412

View File

@ -162,6 +162,10 @@ class Message(BaseModel):
def body(self) -> str: def body(self) -> str:
return escape(self._body).replace('"', """).replace("\r", "
") return escape(self._body).replace('"', """).replace("\r", "
")
@property
def original_body(self) -> str:
return self._body
def mark_sent(self, delay_seconds: Optional[int] = None) -> None: def mark_sent(self, delay_seconds: Optional[int] = None) -> None:
self.sent_timestamp = int(unix_time_millis()) # type: ignore self.sent_timestamp = int(unix_time_millis()) # type: ignore
if delay_seconds: if delay_seconds:
@ -1030,7 +1034,7 @@ class SQSBackend(BaseBackend):
errors.append( errors.append(
{ {
"Id": receipt_and_id["msg_user_id"], "Id": receipt_and_id["msg_user_id"],
"SenderFault": "true", "SenderFault": True,
"Code": "ReceiptHandleIsInvalid", "Code": "ReceiptHandleIsInvalid",
"Message": f'The input receipt handle "{receipt_and_id["receipt_handle"]}" is not a valid receipt handle.', "Message": f'The input receipt handle "{receipt_and_id["receipt_handle"]}" is not a valid receipt handle.',
} }

View File

@ -1,11 +1,18 @@
import json
import re import re
from typing import Any, Dict, Optional, Tuple, Union from functools import wraps
from typing import Any, Callable, Dict, Optional, Union
from moto.core.common_types import TYPE_RESPONSE from moto.core.common_types import TYPE_RESPONSE
from moto.core.exceptions import RESTError from moto.core.exceptions import JsonRESTError
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import underscores_to_camelcase, camelcase_to_pascal from moto.core.utils import (
underscores_to_camelcase,
camelcase_to_pascal,
camelcase_to_underscores,
)
from moto.utilities.aws_headers import amz_crc32, amzn_request_id from moto.utilities.aws_headers import amz_crc32, amzn_request_id
from moto.utilities.constants import JSON_TYPES
from urllib.parse import urlparse from urllib.parse import urlparse
from .constants import ( from .constants import (
@ -14,12 +21,36 @@ from .constants import (
MAXIMUM_VISIBILITY_TIMEOUT, MAXIMUM_VISIBILITY_TIMEOUT,
) )
from .exceptions import ( from .exceptions import (
RESTError,
EmptyBatchRequest, EmptyBatchRequest,
InvalidAttributeName, InvalidAttributeName,
BatchEntryIdsNotDistinct, BatchEntryIdsNotDistinct,
) )
from .models import sqs_backends, SQSBackend from .models import sqs_backends, SQSBackend
from .utils import parse_message_attributes, extract_input_message_attributes from .utils import (
parse_message_attributes,
extract_input_message_attributes,
validate_message_attributes,
)
def jsonify_error(
method: Callable[["SQSResponse"], Union[str, TYPE_RESPONSE]]
) -> Callable[["SQSResponse"], Union[str, TYPE_RESPONSE]]:
"""
The decorator to convert an RESTError to JSON, if necessary
"""
@wraps(method)
def f(self: "SQSResponse") -> Union[str, TYPE_RESPONSE]:
try:
return method(self)
except RESTError as e:
if self.is_json():
raise e.to_json()
raise e
return f
class SQSResponse(BaseResponse): class SQSResponse(BaseResponse):
@ -29,27 +60,52 @@ class SQSResponse(BaseResponse):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(service_name="sqs") super().__init__(service_name="sqs")
def is_json(self) -> bool:
"""
botocore 1.29.127 changed the wire-format to SQS
This means three things:
- The Content-Type is set to JSON
- The input-parameters are in different formats
- The output is in a different format
The change has been reverted for now, but it will be re-introduced later:
https://github.com/boto/botocore/pull/2931
"""
return self.headers.get("Content-Type") in JSON_TYPES
@property @property
def sqs_backend(self) -> SQSBackend: def sqs_backend(self) -> SQSBackend:
return sqs_backends[self.current_account][self.region] return sqs_backends[self.current_account][self.region]
@property @property
def attribute(self) -> Any: # type: ignore[misc] def attribute(self) -> Any: # type: ignore[misc]
if not hasattr(self, "_attribute"): try:
self._attribute = self._get_map_prefix( assert self.is_json()
"Attribute", key_end=".Name", value_end=".Value" return json.loads(self.body).get("Attributes", {})
) except: # noqa: E722 Do not use bare except
return self._attribute if not hasattr(self, "_attribute"):
self._attribute = self._get_map_prefix(
"Attribute", key_end=".Name", value_end=".Value"
)
return self._attribute
@property @property
def tags(self) -> Dict[str, str]: def tags(self) -> Dict[str, str]:
if not hasattr(self, "_tags"): if not hasattr(self, "_tags"):
self._tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") if self.is_json():
self._tags = self._get_param("tags")
else:
self._tags = self._get_map_prefix(
"Tag", key_end=".Key", value_end=".Value"
)
return self._tags return self._tags
def _get_queue_name(self) -> str: def _get_queue_name(self) -> str:
try: try:
queue_url = self.querystring.get("QueueUrl")[0] # type: ignore if self.is_json():
queue_url = self._get_param("QueueUrl")
else:
queue_url = self.querystring.get("QueueUrl")[0] # type: ignore
if queue_url.startswith("http://") or queue_url.startswith("https://"): if queue_url.startswith("http://") or queue_url.startswith("https://"):
return queue_url.split("/")[-1] return queue_url.split("/")[-1]
else: else:
@ -67,7 +123,10 @@ class SQSResponse(BaseResponse):
if timeout is not None: if timeout is not None:
visibility_timeout = int(timeout) visibility_timeout = int(timeout)
else: else:
visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0]) # type: ignore if self.is_json():
visibility_timeout = self._get_param("VisibilityTimeout")
else:
visibility_timeout = int(self.querystring.get("VisibilityTimeout")[0]) # type: ignore
if visibility_timeout > MAXIMUM_VISIBILITY_TIMEOUT: if visibility_timeout > MAXIMUM_VISIBILITY_TIMEOUT:
raise ValueError raise ValueError
@ -85,45 +144,66 @@ class SQSResponse(BaseResponse):
return 404, headers, response return 404, headers, response
return status_code, headers, body return status_code, headers, body
def _error( def _error(self, code: str, message: str, status: int = 400) -> TYPE_RESPONSE:
self, code: str, message: str, status: int = 400 if self.is_json():
) -> Tuple[str, Dict[str, int]]: err = JsonRESTError(error_type=code, message=message)
err.code = status
raise err
template = self.response_template(ERROR_TEMPLATE) template = self.response_template(ERROR_TEMPLATE)
return template.render(code=code, message=message), dict(status=status) return status, {"status": status}, template.render(code=code, message=message)
@jsonify_error
def create_queue(self) -> str: def create_queue(self) -> str:
request_url = urlparse(self.uri) request_url = urlparse(self.uri)
queue_name = self._get_param("QueueName") queue_name = self._get_param("QueueName")
queue = self.sqs_backend.create_queue(queue_name, self.tags, **self.attribute) queue = self.sqs_backend.create_queue(queue_name, self.tags, **self.attribute)
if self.is_json():
return json.dumps({"QueueUrl": queue.url(request_url)})
template = self.response_template(CREATE_QUEUE_RESPONSE) template = self.response_template(CREATE_QUEUE_RESPONSE)
return template.render(queue_url=queue.url(request_url)) return template.render(queue_url=queue.url(request_url))
@jsonify_error
def get_queue_url(self) -> str: def get_queue_url(self) -> str:
request_url = urlparse(self.uri) request_url = urlparse(self.uri)
queue_name = self._get_param("QueueName") queue_name = self._get_param("QueueName")
queue = self.sqs_backend.get_queue_url(queue_name) queue = self.sqs_backend.get_queue_url(queue_name)
if self.is_json():
return json.dumps({"QueueUrl": queue.url(request_url)})
template = self.response_template(GET_QUEUE_URL_RESPONSE) template = self.response_template(GET_QUEUE_URL_RESPONSE)
return template.render(queue_url=queue.url(request_url)) return template.render(queue_url=queue.url(request_url))
@jsonify_error
def list_queues(self) -> str: def list_queues(self) -> str:
request_url = urlparse(self.uri) request_url = urlparse(self.uri)
queue_name_prefix = self._get_param("QueueNamePrefix") queue_name_prefix = self._get_param("QueueNamePrefix")
queues = self.sqs_backend.list_queues(queue_name_prefix) queues = self.sqs_backend.list_queues(queue_name_prefix)
if self.is_json():
if queues:
return json.dumps(
{"QueueUrls": [queue.url(request_url) for queue in queues]}
)
else:
return "{}"
template = self.response_template(LIST_QUEUES_RESPONSE) template = self.response_template(LIST_QUEUES_RESPONSE)
return template.render(queues=queues, request_url=request_url) return template.render(queues=queues, request_url=request_url)
def change_message_visibility(self) -> Union[str, Tuple[str, Dict[str, int]]]: @jsonify_error
def change_message_visibility(self) -> Union[str, TYPE_RESPONSE]:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
receipt_handle = self._get_param("ReceiptHandle") receipt_handle = self._get_param("ReceiptHandle")
try: try:
visibility_timeout = self._get_validated_visibility_timeout() visibility_timeout = self._get_validated_visibility_timeout()
except ValueError: except ValueError:
return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400) return 400, {}, ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE
self.sqs_backend.change_message_visibility( self.sqs_backend.change_message_visibility(
queue_name=queue_name, queue_name=queue_name,
@ -131,36 +211,70 @@ class SQSResponse(BaseResponse):
visibility_timeout=visibility_timeout, visibility_timeout=visibility_timeout,
) )
if self.is_json():
return "{}"
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE) template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE)
return template.render() return template.render()
@jsonify_error
def change_message_visibility_batch(self) -> str: def change_message_visibility_batch(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry") if self.is_json():
entries = [
{camelcase_to_underscores(key): value for key, value in entr.items()}
for entr in self._get_param("Entries")
]
else:
entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry")
success, error = self.sqs_backend.change_message_visibility_batch( success, error = self.sqs_backend.change_message_visibility_batch(
queue_name, entries queue_name, entries
) )
if self.is_json():
return json.dumps(
{"Successful": [{"Id": _id} for _id in success], "Failed": error}
)
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE) template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE)
return template.render(success=success, errors=error) return template.render(success=success, errors=error)
@jsonify_error
def get_queue_attributes(self) -> str: def get_queue_attributes(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
if self.querystring.get("AttributeNames"): if not self.is_json() and self.querystring.get("AttributeNames"):
raise InvalidAttributeName("") raise InvalidAttributeName("")
# if connecting to AWS via boto, then 'AttributeName' is just a normal parameter if self.is_json():
attribute_names = self._get_multi_param( attribute_names = self._get_param("AttributeNames")
"AttributeName" if attribute_names == [] or (attribute_names and "" in attribute_names):
) or self.querystring.get("AttributeName") raise InvalidAttributeName("")
else:
# if connecting to AWS via boto, then 'AttributeName' is just a normal parameter
attribute_names = self._get_multi_param(
"AttributeName"
) or self.querystring.get("AttributeName")
attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) # type: ignore attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) # type: ignore
if self.is_json():
if len(attributes) == 0:
return "{}"
return json.dumps(
{
"Attributes": {
key: str(value)
for key, value in attributes.items()
if value is not None
}
}
)
template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE) template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE)
return template.render(attributes=attributes) return template.render(attributes=attributes)
@jsonify_error
def set_queue_attributes(self) -> str: def set_queue_attributes(self) -> str:
# TODO validate self.get_param('QueueUrl') # TODO validate self.get_param('QueueUrl')
attribute = self.attribute attribute = self.attribute
@ -177,6 +291,7 @@ class SQSResponse(BaseResponse):
return SET_QUEUE_ATTRIBUTE_RESPONSE return SET_QUEUE_ATTRIBUTE_RESPONSE
@jsonify_error
def delete_queue(self) -> str: def delete_queue(self) -> str:
# TODO validate self.get_param('QueueUrl') # TODO validate self.get_param('QueueUrl')
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
@ -186,19 +301,32 @@ class SQSResponse(BaseResponse):
template = self.response_template(DELETE_QUEUE_RESPONSE) template = self.response_template(DELETE_QUEUE_RESPONSE)
return template.render() return template.render()
def send_message(self) -> Union[str, Tuple[str, Dict[str, int]]]: @jsonify_error
def send_message(self) -> Union[str, TYPE_RESPONSE]:
message = self._get_param("MessageBody") message = self._get_param("MessageBody")
delay_seconds = int(self._get_param("DelaySeconds", 0)) delay_seconds = int(self._get_param("DelaySeconds", 0))
message_group_id = self._get_param("MessageGroupId") message_group_id = self._get_param("MessageGroupId")
message_dedupe_id = self._get_param("MessageDeduplicationId") message_dedupe_id = self._get_param("MessageDeduplicationId")
if len(message) > MAXIMUM_MESSAGE_LENGTH: if len(message) > MAXIMUM_MESSAGE_LENGTH:
return ERROR_TOO_LONG_RESPONSE, dict(status=400) return self._error(
"InvalidParameterValue",
message="One or more parameters are invalid. Reason: Message must be shorter than 262144 bytes.",
)
message_attributes = parse_message_attributes(self.querystring) if self.is_json():
system_message_attributes = parse_message_attributes( message_attributes = self._get_param("MessageAttributes")
self.querystring, key="MessageSystemAttribute" self.normalize_json_msg_attributes(message_attributes)
) else:
message_attributes = parse_message_attributes(self.querystring)
if self.is_json():
system_message_attributes = self._get_param("MessageSystemAttributes")
self.normalize_json_msg_attributes(system_message_attributes)
else:
system_message_attributes = parse_message_attributes(
self.querystring, key="MessageSystemAttribute"
)
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
@ -215,9 +343,30 @@ class SQSResponse(BaseResponse):
except RESTError as err: except RESTError as err:
return self._error(err.error_type, err.message) return self._error(err.error_type, err.message)
if self.is_json():
resp = {
"MD5OfMessageBody": message.body_md5,
"MessageId": message.id,
}
if len(message.message_attributes) > 0:
resp["MD5OfMessageAttributes"] = message.attribute_md5
return json.dumps(resp)
template = self.response_template(SEND_MESSAGE_RESPONSE) template = self.response_template(SEND_MESSAGE_RESPONSE)
return template.render(message=message, message_attributes=message_attributes) return template.render(message=message, message_attributes=message_attributes)
def normalize_json_msg_attributes(self, message_attributes: Dict[str, Any]) -> None:
for key, value in (message_attributes or {}).items():
if "BinaryValue" in value:
message_attributes[key]["binary_value"] = value.pop("BinaryValue")
if "StringValue" in value:
message_attributes[key]["string_value"] = value.pop("StringValue")
if "DataType" in value:
message_attributes[key]["data_type"] = value.pop("DataType")
validate_message_attributes(message_attributes)
@jsonify_error
def send_message_batch(self) -> str: def send_message_batch(self) -> str:
""" """
The querystring comes like this The querystring comes like this
@ -234,39 +383,33 @@ class SQSResponse(BaseResponse):
self.sqs_backend.get_queue(queue_name) self.sqs_backend.get_queue(queue_name)
if self.querystring.get("Entries"): if not self.is_json() and self.querystring.get("Entries"):
raise EmptyBatchRequest() raise EmptyBatchRequest()
if self.is_json():
entries = {} entries = {
for key, value in self.querystring.items(): str(idx): entry for idx, entry in enumerate(self._get_param("Entries"))
match = re.match(r"^SendMessageBatchRequestEntry\.(\d+)\.Id", key) }
if match: else:
index = match.group(1) entries = {
str(idx): entry
message_attributes = parse_message_attributes( for idx, entry in enumerate(
self.querystring, self._get_multi_param("SendMessageBatchRequestEntry")
base=f"SendMessageBatchRequestEntry.{index}.",
) )
}
for entry in entries.values():
if "MessageAttribute" in entry:
entry["MessageAttributes"] = {
val["Name"]: val["Value"]
for val in entry.pop("MessageAttribute")
}
entries[index] = { for entry in entries.values():
"Id": value[0], if "MessageAttributes" in entry:
"MessageBody": self.querystring.get( # type: ignore self.normalize_json_msg_attributes(entry["MessageAttributes"])
f"SendMessageBatchRequestEntry.{index}.MessageBody" else:
)[0], entry["MessageAttributes"] = {}
"DelaySeconds": self.querystring.get( if "DelaySeconds" not in entry:
f"SendMessageBatchRequestEntry.{index}.DelaySeconds", entry["DelaySeconds"] = None
[None],
)[0],
"MessageAttributes": message_attributes,
"MessageGroupId": self.querystring.get(
f"SendMessageBatchRequestEntry.{index}.MessageGroupId",
[None],
)[0],
"MessageDeduplicationId": self.querystring.get(
f"SendMessageBatchRequestEntry.{index}.MessageDeduplicationId",
[None],
)[0],
}
if entries == {}: if entries == {}:
raise EmptyBatchRequest() raise EmptyBatchRequest()
@ -286,16 +429,38 @@ class SQSResponse(BaseResponse):
} }
) )
if self.is_json():
resp: Dict[str, Any] = {"Successful": [], "Failed": errors}
for msg in messages:
msg_dict = {
"Id": msg.user_id, # type: ignore
"MessageId": msg.id,
"MD5OfMessageBody": msg.body_md5,
}
if len(msg.message_attributes) > 0:
msg_dict["MD5OfMessageAttributes"] = msg.attribute_md5
resp["Successful"].append(msg_dict)
return json.dumps(resp)
template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE) template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE)
return template.render(messages=messages, errors=errors) return template.render(messages=messages, errors=errors)
@jsonify_error
def delete_message(self) -> str: def delete_message(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
receipt_handle = self.querystring.get("ReceiptHandle")[0] # type: ignore if self.is_json():
receipt_handle = self._get_param("ReceiptHandle")
else:
receipt_handle = self.querystring.get("ReceiptHandle")[0] # type: ignore
self.sqs_backend.delete_message(queue_name, receipt_handle) self.sqs_backend.delete_message(queue_name, receipt_handle)
if self.is_json():
return "{}"
template = self.response_template(DELETE_MESSAGE_RESPONSE) template = self.response_template(DELETE_MESSAGE_RESPONSE)
return template.render() return template.render()
@jsonify_error
def delete_message_batch(self) -> str: def delete_message_batch(self) -> str:
""" """
The querystring comes like this The querystring comes like this
@ -308,21 +473,17 @@ class SQSResponse(BaseResponse):
""" """
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
receipts = [] if self.is_json():
receipts = self._get_param("Entries")
else:
receipts = self._get_multi_param("DeleteMessageBatchRequestEntry")
for index in range(1, 11): for r in receipts:
# Loop through looking for messages for key in list(r.keys()):
receipt_key = f"DeleteMessageBatchRequestEntry.{index}.ReceiptHandle" if key == "Id":
receipt_handle = self.querystring.get(receipt_key) r["msg_user_id"] = r.pop(key)
if not receipt_handle: else:
# Found all messages r[camelcase_to_underscores(key)] = r.pop(key)
break
message_user_id_key = f"DeleteMessageBatchRequestEntry.{index}.Id"
message_user_id = self.querystring.get(message_user_id_key)[0] # type: ignore
receipts.append(
{"receipt_handle": receipt_handle[0], "msg_user_id": message_user_id}
)
receipt_seen = set() receipt_seen = set()
for receipt_and_id in receipts: for receipt_and_id in receipts:
@ -333,27 +494,45 @@ class SQSResponse(BaseResponse):
success, errors = self.sqs_backend.delete_message_batch(queue_name, receipts) success, errors = self.sqs_backend.delete_message_batch(queue_name, receipts)
if self.is_json():
return json.dumps(
{"Successful": [{"Id": _id} for _id in success], "Failed": errors}
)
template = self.response_template(DELETE_MESSAGE_BATCH_RESPONSE) template = self.response_template(DELETE_MESSAGE_BATCH_RESPONSE)
return template.render(success=success, errors=errors) return template.render(success=success, errors=errors)
@jsonify_error
def purge_queue(self) -> str: def purge_queue(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
self.sqs_backend.purge_queue(queue_name) self.sqs_backend.purge_queue(queue_name)
template = self.response_template(PURGE_QUEUE_RESPONSE) template = self.response_template(PURGE_QUEUE_RESPONSE)
return template.render() return template.render()
def receive_message(self) -> Union[str, Tuple[str, Dict[str, int]]]: @jsonify_error
def receive_message(self) -> Union[str, TYPE_RESPONSE]:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
message_attributes = self._get_multi_param("message_attributes") if self.is_json():
message_attributes = self._get_param("MessageAttributeNames")
else:
message_attributes = self._get_multi_param("message_attributes")
if not message_attributes: if not message_attributes:
message_attributes = extract_input_message_attributes(self.querystring) message_attributes = extract_input_message_attributes(self.querystring)
attribute_names = self._get_multi_param("AttributeName") if self.is_json():
attribute_names = self._get_param("AttributeNames", [])
else:
attribute_names = self._get_multi_param("AttributeName")
queue = self.sqs_backend.get_queue(queue_name) queue = self.sqs_backend.get_queue(queue_name)
try: try:
message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) # type: ignore if self.is_json():
message_count = self._get_param(
"MaxNumberOfMessages", DEFAULT_RECEIVED_MESSAGES
)
else:
message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) # type: ignore
except TypeError: except TypeError:
message_count = DEFAULT_RECEIVED_MESSAGES message_count = DEFAULT_RECEIVED_MESSAGES
@ -367,7 +546,10 @@ class SQSResponse(BaseResponse):
) )
try: try:
wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) # type: ignore if self.is_json():
wait_time = int(self._get_param("WaitTimeSeconds"))
else:
wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) # type: ignore
except TypeError: except TypeError:
wait_time = int(queue.receive_message_wait_time_seconds) # type: ignore wait_time = int(queue.receive_message_wait_time_seconds) # type: ignore
@ -385,7 +567,7 @@ class SQSResponse(BaseResponse):
except TypeError: except TypeError:
visibility_timeout = queue.visibility_timeout # type: ignore visibility_timeout = queue.visibility_timeout # type: ignore
except ValueError: except ValueError:
return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400) return 400, {}, ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE
messages = self.sqs_backend.receive_message( messages = self.sqs_backend.receive_message(
queue_name, message_count, wait_time, visibility_timeout, message_attributes queue_name, message_count, wait_time, visibility_timeout, message_attributes
@ -406,22 +588,98 @@ class SQSResponse(BaseResponse):
if any(x in ["All", pascalcase_name] for x in attribute_names): if any(x in ["All", pascalcase_name] for x in attribute_names):
attributes[attribute] = True attributes[attribute] = True
if self.is_json():
msgs = []
for message in messages:
msg: Dict[str, Any] = {
"MessageId": message.id,
"ReceiptHandle": message.receipt_handle,
"MD5OfBody": message.body_md5,
"Body": message.original_body,
"Attributes": {},
"MessageAttributes": {},
}
if len(message.message_attributes) > 0:
msg["MD5OfMessageAttributes"] = message.attribute_md5
if attributes["sender_id"]:
msg["Attributes"]["SenderId"] = message.sender_id
if attributes["sent_timestamp"]:
msg["Attributes"]["SentTimestamp"] = str(message.sent_timestamp)
if attributes["approximate_receive_count"]:
msg["Attributes"]["ApproximateReceiveCount"] = str(
message.approximate_receive_count
)
if attributes["approximate_first_receive_timestamp"]:
msg["Attributes"]["ApproximateFirstReceiveTimestamp"] = str(
message.approximate_first_receive_timestamp
)
if attributes["message_deduplication_id"]:
msg["Attributes"][
"MessageDeduplicationId"
] = message.deduplication_id
if attributes["message_group_id"] and message.group_id is not None:
msg["Attributes"]["MessageGroupId"] = message.group_id
if message.system_attributes and message.system_attributes.get(
"AWSTraceHeader"
):
msg["Attributes"]["AWSTraceHeader"] = message.system_attributes[
"AWSTraceHeader"
].get("string_value")
if (
attributes["sequence_number"]
and message.sequence_number is not None
):
msg["Attributes"]["SequenceNumber"] = message.sequence_number
for name, value in message.message_attributes.items():
msg["MessageAttributes"][name] = {"DataType": value["data_type"]}
if "Binary" in value["data_type"]:
msg["MessageAttributes"][name]["BinaryValue"] = value[
"binary_value"
]
else:
msg["MessageAttributes"][name]["StringValue"] = value[
"string_value"
]
if len(msg["Attributes"]) == 0:
msg.pop("Attributes")
if len(msg["MessageAttributes"]) == 0:
msg.pop("MessageAttributes")
msgs.append(msg)
return json.dumps({"Messages": msgs} if msgs else {})
template = self.response_template(RECEIVE_MESSAGE_RESPONSE) template = self.response_template(RECEIVE_MESSAGE_RESPONSE)
return template.render(messages=messages, attributes=attributes) return template.render(messages=messages, attributes=attributes)
@jsonify_error
def list_dead_letter_source_queues(self) -> str: def list_dead_letter_source_queues(self) -> str:
request_url = urlparse(self.uri) request_url = urlparse(self.uri)
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
source_queue_urls = self.sqs_backend.list_dead_letter_source_queues(queue_name) queues = self.sqs_backend.list_dead_letter_source_queues(queue_name)
if self.is_json():
return json.dumps(
{"queueUrls": [queue.url(request_url) for queue in queues]}
)
template = self.response_template(LIST_DEAD_LETTER_SOURCE_QUEUES_RESPONSE) template = self.response_template(LIST_DEAD_LETTER_SOURCE_QUEUES_RESPONSE)
return template.render(queues=source_queue_urls, request_url=request_url) return template.render(queues=queues, request_url=request_url)
@jsonify_error
def add_permission(self) -> str: def add_permission(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
actions = self._get_multi_param("ActionName") actions = (
account_ids = self._get_multi_param("AWSAccountId") self._get_param("Actions")
if self.is_json()
else self._get_multi_param("ActionName")
)
account_ids = (
self._get_param("AWSAccountIds")
if self.is_json()
else self._get_multi_param("AWSAccountId")
)
label = self._get_param("Label") label = self._get_param("Label")
self.sqs_backend.add_permission(queue_name, actions, account_ids, label) self.sqs_backend.add_permission(queue_name, actions, account_ids, label)
@ -429,6 +687,7 @@ class SQSResponse(BaseResponse):
template = self.response_template(ADD_PERMISSION_RESPONSE) template = self.response_template(ADD_PERMISSION_RESPONSE)
return template.render() return template.render()
@jsonify_error
def remove_permission(self) -> str: def remove_permission(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
label = self._get_param("Label") label = self._get_param("Label")
@ -438,21 +697,36 @@ class SQSResponse(BaseResponse):
template = self.response_template(REMOVE_PERMISSION_RESPONSE) template = self.response_template(REMOVE_PERMISSION_RESPONSE)
return template.render() return template.render()
@jsonify_error
def tag_queue(self) -> str: def tag_queue(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") if self.is_json():
tags = self._get_param("Tags")
else:
tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value")
self.sqs_backend.tag_queue(queue_name, tags) self.sqs_backend.tag_queue(queue_name, tags)
if self.is_json():
return "{}"
template = self.response_template(TAG_QUEUE_RESPONSE) template = self.response_template(TAG_QUEUE_RESPONSE)
return template.render() return template.render()
@jsonify_error
def untag_queue(self) -> str: def untag_queue(self) -> str:
queue_name = self._get_queue_name() queue_name = self._get_queue_name()
tag_keys = self._get_multi_param("TagKey") tag_keys = (
self._get_param("TagKeys")
if self.is_json()
else self._get_multi_param("TagKey")
)
self.sqs_backend.untag_queue(queue_name, tag_keys) self.sqs_backend.untag_queue(queue_name, tag_keys)
if self.is_json():
return "{}"
template = self.response_template(UNTAG_QUEUE_RESPONSE) template = self.response_template(UNTAG_QUEUE_RESPONSE)
return template.render() return template.render()
@ -461,6 +735,9 @@ class SQSResponse(BaseResponse):
queue = self.sqs_backend.list_queue_tags(queue_name) queue = self.sqs_backend.list_queue_tags(queue_name)
if self.is_json():
return json.dumps({"Tags": queue.tags})
template = self.response_template(LIST_QUEUE_TAGS_RESPONSE) template = self.response_template(LIST_QUEUE_TAGS_RESPONSE)
return template.render(tags=queue.tags) return template.render(tags=queue.tags)
@ -665,7 +942,7 @@ DELETE_MESSAGE_BATCH_RESPONSE = """<DeleteMessageBatchResponse>
<Id>{{ error_dict['Id'] }}</Id> <Id>{{ error_dict['Id'] }}</Id>
<Code>{{ error_dict['Code'] }}</Code> <Code>{{ error_dict['Code'] }}</Code>
<Message>{{ error_dict['Message'] }}</Message> <Message>{{ error_dict['Message'] }}</Message>
<SenderFault>{{ error_dict['SenderFault'] }}</SenderFault> <SenderFault>{{ 'true' if error_dict['SenderFault'] else 'false' }}</SenderFault>
</BatchResultErrorEntry> </BatchResultErrorEntry>
{% endfor %} {% endfor %}
</DeleteMessageBatchResult> </DeleteMessageBatchResult>
@ -756,16 +1033,6 @@ LIST_QUEUE_TAGS_RESPONSE = """<ListQueueTagsResponse>
</ResponseMetadata> </ResponseMetadata>
</ListQueueTagsResponse>""" </ListQueueTagsResponse>"""
ERROR_TOO_LONG_RESPONSE = """<ErrorResponse xmlns="http://queue.amazonaws.com/doc/2012-11-05/">
<Error>
<Type>Sender</Type>
<Code>InvalidParameterValue</Code>
<Message>One or more parameters are invalid. Reason: Message must be shorter than 262144 bytes.</Message>
<Detail/>
</Error>
<RequestId>6fde8d1e-52cd-4581-8cd9-c512f4c64223</RequestId>
</ErrorResponse>"""
ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE = ( ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE = (
f"Invalid request, maximum visibility timeout is {MAXIMUM_VISIBILITY_TIMEOUT}" f"Invalid request, maximum visibility timeout is {MAXIMUM_VISIBILITY_TIMEOUT}"
) )

View File

@ -43,38 +43,51 @@ def parse_message_attributes(
break break
data_type_key = base + f"{key}.{index}.{value_namespace}DataType" data_type_key = base + f"{key}.{index}.{value_namespace}DataType"
data_type = querystring.get(data_type_key) data_type = querystring.get(data_type_key, [None])[0]
data_type_parts = (data_type or "").split(".")[0]
type_prefix = "String"
if data_type_parts == "Binary":
type_prefix = "Binary"
value_key = base + f"{key}.{index}.{value_namespace}{type_prefix}Value"
value = querystring.get(value_key, [None])[0]
message_attributes[name[0]] = {
"data_type": data_type,
type_prefix.lower() + "_value": value,
}
index += 1
validate_message_attributes(message_attributes)
return message_attributes
def validate_message_attributes(message_attributes: Dict[str, Any]) -> None:
for name, value in (message_attributes or {}).items():
data_type = value["data_type"]
if not data_type: if not data_type:
raise MessageAttributesInvalid( raise MessageAttributesInvalid(
f"The message attribute '{name[0]}' must contain non-empty message attribute value." f"The message attribute '{name}' must contain non-empty message attribute value."
) )
data_type_parts = data_type[0].split(".") data_type_parts = data_type.split(".")[0]
if data_type_parts[0] not in [ if data_type_parts not in [
"String", "String",
"Binary", "Binary",
"Number", "Number",
]: ]:
raise MessageAttributesInvalid( raise MessageAttributesInvalid(
f"The message attribute '{name[0]}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String." f"The message attribute '{name}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String."
) )
type_prefix = "String" possible_value_fields = ["string_value", "binary_value"]
if data_type_parts[0] == "Binary": for field in possible_value_fields:
type_prefix = "Binary" if field in value and value[field] is None:
raise MessageAttributesInvalid(
value_key = base + f"{key}.{index}.{value_namespace}{type_prefix}Value" f"The message attribute '{name}' must contain non-empty message attribute value for message attribute type '{data_type}'."
value = querystring.get(value_key) )
if not value:
raise MessageAttributesInvalid(
f"The message attribute '{name[0]}' must contain non-empty message attribute value for message attribute type '{data_type[0]}'."
)
message_attributes[name[0]] = {
"data_type": data_type[0],
type_prefix.lower() + "_value": value[0],
}
index += 1
return message_attributes

View File

@ -0,0 +1,5 @@
APPLICATION_AMZ_JSON_1_0 = "application/x-amz-json-1.0"
APPLICATION_AMZ_JSON_1_1 = "application/x-amz-json-1.1"
APPLICATION_JSON = "application/json"
JSON_TYPES = [APPLICATION_JSON, APPLICATION_AMZ_JSON_1_0, APPLICATION_AMZ_JSON_1_1]