Lambda+SQS: Allow FIFO queues to be used as event sources (#5998)

This commit is contained in:
Bert Blommers 2023-03-01 14:03:20 -01:00 committed by GitHub
parent a59f921036
commit c5a91e6cc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 84 deletions

View File

@ -1643,11 +1643,6 @@ class LambdaBackend(BaseBackend):
raise RESTError(
"ResourceConflictException", "The resource already exists."
)
if queue.fifo_queue:
raise RESTError(
"InvalidParameterValueException", f"{queue.queue_arn} is FIFO"
)
else:
spec.update({"FunctionArn": func.function_arn})
esm = EventSourceMapping(spec)
self._event_source_mappings[esm.uuid] = esm
@ -1656,6 +1651,7 @@ class LambdaBackend(BaseBackend):
queue.lambda_event_source_mappings[esm.function_arn] = esm
return esm
ddbstream_backend = dynamodbstreams_backends[self.account_id][self.region_name]
ddb_backend = dynamodb_backends[self.account_id][self.region_name]
for stream in json.loads(ddbstream_backend.list_streams())["Streams"]:
@ -1792,6 +1788,14 @@ class LambdaBackend(BaseBackend):
}
]
}
if queue_arn.endswith(".fifo"):
# Messages from FIFO queue have additional attributes
event["Records"][0]["attributes"].update(
{
"MessageGroupId": message.group_id,
"MessageDeduplicationId": message.deduplication_id,
}
)
request_headers: Dict[str, Any] = {}
response_headers: Dict[str, Any] = {}

View File

@ -1,66 +0,0 @@
import json
import time
import uuid
import boto3
import sure # noqa # pylint: disable=unused-import
from moto import mock_sqs, mock_lambda, mock_logs
from tests.test_awslambda.test_lambda import get_test_zip_file1, get_role_name
@mock_logs
@mock_lambda
@mock_sqs
def test_invoke_function_from_sqs_exception():
logs_conn = boto3.client("logs", region_name="us-east-1")
sqs = boto3.resource("sqs", region_name="us-east-1")
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client("lambda", region_name="us-east-1")
func = conn.create_function(
FunctionName="testFunction",
Runtime="python2.7",
Role=get_role_name(),
Handler="lambda_function.lambda_handler",
Code={"ZipFile": get_test_zip_file1()},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"]
)
assert response["EventSourceArn"] == queue.attributes["QueueArn"]
assert response["State"] == "Enabled"
entries = [
{
"Id": "1",
"MessageBody": json.dumps({"uuid": str(uuid.uuid4()), "test": "test"}),
}
]
queue.send_messages(Entries=entries)
start = time.time()
while (time.time() - start) < 30:
result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction")
log_streams = result.get("logStreams")
if not log_streams:
time.sleep(1)
continue
assert len(log_streams) >= 1
result = logs_conn.get_log_events(
logGroupName="/aws/lambda/testFunction",
logStreamName=log_streams[0]["logStreamName"],
)
for event in result.get("events"):
if "custom log event" in event["message"]:
return
time.sleep(1)
assert False, "Test Failed"

View File

@ -5,19 +5,22 @@ import uuid
from moto import mock_lambda, mock_sqs, mock_logs
from tests.test_awslambda.test_lambda import get_test_zip_file1, get_role_name
from tests.test_awslambda.utilities import get_test_zip_file_print_event
@mock_logs
@mock_lambda
@mock_sqs
def test_invoke_function_from_sqs_exception():
def test_invoke_function_from_sqs_queue():
logs_conn = boto3.client("logs", region_name="us-east-1")
sqs = boto3.resource("sqs", region_name="us-east-1")
queue = sqs.create_queue(QueueName="test-sqs-queue1")
queue_name = str(uuid.uuid4())[0:6]
queue = sqs.create_queue(QueueName=queue_name)
fn_name = str(uuid.uuid4())[0:6]
conn = boto3.client("lambda", region_name="us-east-1")
func = conn.create_function(
FunctionName="testFunction",
FunctionName=fn_name,
Runtime="python2.7",
Role=get_role_name(),
Handler="lambda_function.lambda_handler",
@ -46,20 +49,89 @@ def test_invoke_function_from_sqs_exception():
start = time.time()
while (time.time() - start) < 30:
result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction")
result = logs_conn.describe_log_streams(logGroupName=f"/aws/lambda/{fn_name}")
log_streams = result.get("logStreams")
if not log_streams:
time.sleep(1)
time.sleep(0.5)
continue
assert len(log_streams) >= 1
result = logs_conn.get_log_events(
logGroupName="/aws/lambda/testFunction",
logGroupName=f"/aws/lambda/{fn_name}",
logStreamName=log_streams[0]["logStreamName"],
)
for event in result.get("events"):
if "custom log event" in event["message"]:
return
time.sleep(1)
time.sleep(0.5)
assert False, "Test Failed"
@mock_logs
@mock_lambda
@mock_sqs
def test_invoke_function_from_sqs_fifo_queue():
"""
Create a FIFO Queue
Create a Lambda Function
Send a message to the queue
Verify the Lambda has been invoked with the correct message body
(By checking the resulting log files)
"""
logs_conn = boto3.client("logs", region_name="us-east-1")
sqs = boto3.resource("sqs", region_name="us-east-1")
queue_name = str(uuid.uuid4())[0:6] + ".fifo"
queue = sqs.create_queue(QueueName=queue_name, Attributes={"FifoQueue": "true"})
fn_name = str(uuid.uuid4())[0:6]
conn = boto3.client("lambda", region_name="us-east-1")
func = conn.create_function(
FunctionName=fn_name,
Runtime="python3.7",
Role=get_role_name(),
Handler="lambda_function.lambda_handler",
Code={"ZipFile": get_test_zip_file_print_event()},
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"]
)
assert response["EventSourceArn"] == queue.attributes["QueueArn"]
assert response["State"] == "Enabled"
entries = [{"Id": "1", "MessageBody": "some body", "MessageGroupId": "mg1"}]
queue.send_messages(Entries=entries)
start = time.time()
while (time.time() - start) < 30:
result = logs_conn.describe_log_streams(logGroupName=f"/aws/lambda/{fn_name}")
log_streams = result.get("logStreams")
if not log_streams:
time.sleep(0.5)
continue
assert len(log_streams) >= 1
result = logs_conn.get_log_events(
logGroupName=f"/aws/lambda/{fn_name}",
logStreamName=log_streams[0]["logStreamName"],
)
for event in result.get("events"):
try:
body = json.loads(event.get("message"))
atts = body["Records"][0]["attributes"]
atts.should.have.key("MessageGroupId").equals("mg1")
atts.should.have.key("MessageDeduplicationId")
return
except: # noqa: E722 Do not use bare except
pass
time.sleep(0.5)
assert False, "Test Failed"