AWSLambda - Policy improvements (#4949)

This commit is contained in:
Bert Blommers 2022-03-19 12:00:39 -01:00 committed by GitHub
parent 411ce71d3a
commit de990b07f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 218 additions and 131 deletions

View File

@ -1095,9 +1095,18 @@ class LambdaStorage(object):
return self._arns.get(arn, None)
def get_function_by_name_or_arn(self, name_or_arn, qualifier=None):
return self.get_function_by_name(name_or_arn, qualifier) or self.get_arn(
fn = self.get_function_by_name(name_or_arn, qualifier) or self.get_arn(
name_or_arn
)
if fn is None:
if name_or_arn.startswith("arn:aws"):
arn = name_or_arn
else:
arn = make_function_arn(self.region_name, ACCOUNT_ID, name_or_arn)
if qualifier:
arn = f"{arn}:{qualifier}"
raise UnknownFunctionException(arn)
return fn
def put_function(self, fn):
"""
@ -1127,12 +1136,6 @@ class LambdaStorage(object):
def publish_function(self, name_or_arn, description=""):
function = self.get_function_by_name_or_arn(name_or_arn)
if not function:
if name_or_arn.startswith("arn:aws"):
arn = name_or_arn
else:
arn = make_function_arn(self.region_name, ACCOUNT_ID, name_or_arn)
raise UnknownFunctionException(arn)
name = function.function_name
if name not in self._functions:
return None
@ -1150,23 +1153,32 @@ class LambdaStorage(object):
return fn
def del_function(self, name_or_arn, qualifier=None):
function = self.get_function_by_name_or_arn(name_or_arn)
if function:
name = function.function_name
if not qualifier:
# Something is still reffing this so delete all arns
latest = self._functions[name]["latest"].function_arn
del self._arns[latest]
function = self.get_function_by_name_or_arn(name_or_arn, qualifier)
name = function.function_name
if not qualifier:
# Something is still reffing this so delete all arns
latest = self._functions[name]["latest"].function_arn
del self._arns[latest]
for fn in self._functions[name]["versions"]:
del self._arns[fn.function_arn]
for fn in self._functions[name]["versions"]:
del self._arns[fn.function_arn]
del self._functions[name]
elif qualifier == "$LATEST":
self._functions[name]["latest"] = None
# If theres no functions left
if (
not self._functions[name]["versions"]
and not self._functions[name]["latest"]
):
del self._functions[name]
return True
elif qualifier == "$LATEST":
self._functions[name]["latest"] = None
else:
fn = self.get_function_by_name(name, qualifier)
if fn:
self._functions[name]["versions"].remove(fn)
# If theres no functions left
if (
@ -1175,24 +1187,6 @@ class LambdaStorage(object):
):
del self._functions[name]
return True
else:
fn = self.get_function_by_name(name, qualifier)
if fn:
self._functions[name]["versions"].remove(fn)
# If theres no functions left
if (
not self._functions[name]["versions"]
and not self._functions[name]["latest"]
):
del self._functions[name]
return True
return False
def all(self):
result = []
@ -1488,7 +1482,7 @@ class LambdaBackend(BaseBackend):
return self._lambdas.get_arn(function_arn)
def delete_function(self, function_name, qualifier=None):
return self._lambdas.del_function(function_name, qualifier)
self._lambdas.del_function(function_name, qualifier)
def list_functions(self, func_version=None):
if func_version == "ALL":
@ -1601,31 +1595,20 @@ class LambdaBackend(BaseBackend):
return func.invoke(json.dumps(event), {}, {})
def list_tags(self, resource):
return self.get_function_by_arn(resource).tags
return self._lambdas.get_function_by_name_or_arn(resource).tags
def tag_resource(self, resource, tags):
fn = self.get_function_by_arn(resource)
if not fn:
return False
fn = self._lambdas.get_function_by_name_or_arn(resource)
fn.tags.update(tags)
return True
def untag_resource(self, resource, tagKeys):
fn = self.get_function_by_arn(resource)
if fn:
for key in tagKeys:
try:
del fn.tags[key]
except KeyError:
pass
# Don't care
return True
return False
fn = self._lambdas.get_function_by_name_or_arn(resource)
for key in tagKeys:
fn.tags.pop(key, None)
def add_permission(self, function_name, raw):
fn = self.get_function(function_name)
fn.policy.add_statement(raw)
def add_permission(self, function_name, qualifier, raw):
fn = self.get_function(function_name, qualifier)
fn.policy.add_statement(raw, qualifier)
def remove_permission(self, function_name, sid, revision=""):
fn = self.get_function(function_name)

View File

@ -29,7 +29,7 @@ class Policy:
}
# adds the raw JSON statement to the policy
def add_statement(self, raw):
def add_statement(self, raw, qualifier=None):
policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException(
@ -40,6 +40,10 @@ class Policy:
# Remove #LATEST from the Resource (Lambda ARN)
if policy.statements[0].get("Resource", "").endswith("$LATEST"):
policy.statements[0]["Resource"] = policy.statements[0]["Resource"][0:-8]
if qualifier:
policy.statements[0]["Resource"] = (
policy.statements[0]["Resource"] + ":" + qualifier
)
self.statements.append(policy.statements[0])
self.revision = str(uuid.uuid4())

View File

@ -193,12 +193,10 @@ class LambdaResponse(BaseResponse):
def _add_policy(self, request):
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = unquote(path.split("/")[-2])
if self.lambda_backend.get_function(function_name):
statement = self.body
self.lambda_backend.add_permission(function_name, statement)
return 200, {}, json.dumps({"Statement": statement})
else:
return 404, {}, "{}"
qualifier = self.querystring.get("Qualifier", [None])[0]
statement = self.body
self.lambda_backend.add_permission(function_name, qualifier, statement)
return 200, {}, json.dumps({"Statement": statement})
def _get_policy(self, request):
path = request.path if hasattr(request, "path") else path_url(request.url)
@ -257,12 +255,9 @@ class LambdaResponse(BaseResponse):
function_name = unquote(self.path.rsplit("/", 3)[-3])
fn = self.lambda_backend.get_function(function_name, None)
if fn:
payload = fn.invoke(self.body, self.headers, response_headers)
response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload
else:
return 404, response_headers, "{}"
payload = fn.invoke(self.body, self.headers, response_headers)
response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload
def _list_functions(self):
querystring = self.querystring
@ -331,20 +326,15 @@ class LambdaResponse(BaseResponse):
description = self._get_param("Description")
fn = self.lambda_backend.publish_function(function_name, description)
if fn:
config = fn.get_configuration()
return 201, {}, json.dumps(config)
else:
return 404, {}, "{}"
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _delete_function(self):
function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param("Qualifier", None)
if self.lambda_backend.delete_function(function_name, qualifier):
return 204, {}, ""
else:
return 404, {}, "{}"
self.lambda_backend.delete_function(function_name, qualifier)
return 204, {}, ""
@staticmethod
def _set_configuration_qualifier(configuration, qualifier):
@ -360,14 +350,11 @@ class LambdaResponse(BaseResponse):
fn = self.lambda_backend.get_function(function_name, qualifier)
if fn:
code = fn.get_code()
code["Configuration"] = self._set_configuration_qualifier(
code["Configuration"], qualifier
)
return 200, {}, json.dumps(code)
else:
return 404, {"x-amzn-ErrorType": "ResourceNotFoundException"}, "{}"
code = fn.get_code()
code["Configuration"] = self._set_configuration_qualifier(
code["Configuration"], qualifier
)
return 200, {}, json.dumps(code)
def _get_function_configuration(self):
function_name = unquote(self.path.rsplit("/", 2)[-2])
@ -375,13 +362,10 @@ class LambdaResponse(BaseResponse):
fn = self.lambda_backend.get_function(function_name, qualifier)
if fn:
configuration = self._set_configuration_qualifier(
fn.get_configuration(), qualifier
)
return 200, {}, json.dumps(configuration)
else:
return 404, {"x-amzn-ErrorType": "ResourceNotFoundException"}, "{}"
configuration = self._set_configuration_qualifier(
fn.get_configuration(), qualifier
)
return 200, {}, json.dumps(configuration)
def _get_aws_region(self, full_url):
region = self.region_regex.search(full_url)
@ -393,28 +377,21 @@ class LambdaResponse(BaseResponse):
def _list_tags(self):
function_arn = unquote(self.path.rsplit("/", 1)[-1])
fn = self.lambda_backend.get_function_by_arn(function_arn)
if fn:
return 200, {}, json.dumps({"Tags": fn.tags})
else:
return 404, {}, "{}"
tags = self.lambda_backend.list_tags(function_arn)
return 200, {}, json.dumps({"Tags": tags})
def _tag_resource(self):
function_arn = unquote(self.path.rsplit("/", 1)[-1])
if self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"]):
return 200, {}, "{}"
else:
return 404, {}, "{}"
self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"])
return 200, {}, "{}"
def _untag_resource(self):
function_arn = unquote(self.path.rsplit("/", 1)[-1])
tag_keys = self.querystring["tagKeys"]
if self.lambda_backend.untag_resource(function_arn, tag_keys):
return 204, {}, "{}"
else:
return 404, {}, "{}"
self.lambda_backend.untag_resource(function_arn, tag_keys)
return 204, {}, "{}"
def _put_configuration(self):
function_name = unquote(self.path.rsplit("/", 2)[-2])

View File

@ -643,8 +643,9 @@ class Source(ConfigEmptyDictable):
# operations, only load it if needed.
from moto.awslambda import lambda_backends
lambda_func = lambda_backends[region].get_function(source_identifier)
if not lambda_func:
try:
lambda_backends[region].get_function(source_identifier)
except Exception:
raise InsufficientPermissionsException(
f"The AWS Lambda function {source_identifier} cannot be "
f"invoked. Check the specified function ARN, and check the "

View File

@ -888,11 +888,10 @@ class LogsBackend(BaseBackend):
lambda_backends,
)
lambda_func = lambda_backends[self.region_name].get_function(
destination_arn
)
try:
lambda_backends[self.region_name].get_function(destination_arn)
# no specific permission check implemented
if not lambda_func:
except Exception:
raise InvalidParameterException(
"Could not execute the lambda function. Make sure you "
"have given CloudWatch Logs permission to execute your "

View File

@ -537,8 +537,9 @@ class SecretsManagerBackend(BaseBackend):
request_headers = {}
response_headers = {}
func = lambda_backend.get_function(secret.rotation_lambda_arn)
if not func:
try:
func = lambda_backend.get_function(secret.rotation_lambda_arn)
except Exception:
msg = "Resource not found for ARN '{}'.".format(
secret.rotation_lambda_arn
)

View File

@ -5,6 +5,7 @@ import pytest
from botocore.exceptions import ClientError
from moto import mock_lambda, mock_s3
from moto.core import ACCOUNT_ID
from uuid import uuid4
from .utilities import get_role_name, get_test_zip_file1
@ -27,10 +28,6 @@ def test_add_function_permission(key):
Role=(get_role_name()),
Handler="lambda_function.handler",
Code={"ZipFile": zip_content},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
name_or_arn = f[key]
@ -40,9 +37,6 @@ def test_add_function_permission(key):
Action="lambda:InvokeFunction",
Principal="432143214321",
SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld",
SourceAccount="123412341234",
EventSourceToken="blah",
Qualifier="2",
)
assert "Statement" in response
res = json.loads(response["Statement"])
@ -70,13 +64,10 @@ def test_get_function_policy(key):
conn.add_permission(
FunctionName=name_or_arn,
StatementId="1",
StatementId="2",
Action="lambda:InvokeFunction",
Principal="432143214321",
SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld",
SourceAccount="123412341234",
EventSourceToken="blah",
Qualifier="2",
Principal="lambda.amazonaws.com",
SourceArn=f"arn:aws:lambda:us-west-2:{ACCOUNT_ID}:function:helloworld",
)
response = conn.get_policy(FunctionName=name_or_arn)
@ -84,11 +75,125 @@ def test_get_function_policy(key):
assert "Policy" in response
res = json.loads(response["Policy"])
assert res["Statement"][0]["Action"] == "lambda:InvokeFunction"
assert res["Statement"][0]["Principal"] == {"Service": "lambda.amazonaws.com"}
assert (
res["Statement"][0]["Resource"]
== f"arn:aws:lambda:us-west-2:123456789012:function:{function_name}"
)
@mock_lambda
def test_get_policy_with_qualifier():
# assert that the resource within the statement ends with :qualifier
conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1()
function_name = str(uuid4())[0:6]
conn.create_function(
FunctionName=function_name,
Runtime="python3.7",
Role=get_role_name(),
Handler="lambda_function.handler",
Code={"ZipFile": zip_content},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
zip_content_two = get_test_zip_file1()
conn.update_function_code(
FunctionName=function_name, ZipFile=zip_content_two, Publish=True
)
conn.add_permission(
FunctionName=function_name,
StatementId="1",
Action="lambda:InvokeFunction",
Principal="lambda.amazonaws.com",
SourceArn=f"arn:aws:lambda:us-west-2:{ACCOUNT_ID}:function:helloworld",
Qualifier="2",
)
response = conn.get_policy(FunctionName=function_name, Qualifier="2")
assert "Policy" in response
res = json.loads(response["Policy"])
assert res["Statement"][0]["Action"] == "lambda:InvokeFunction"
assert res["Statement"][0]["Principal"] == {"Service": "lambda.amazonaws.com"}
assert (
res["Statement"][0]["Resource"]
== f"arn:aws:lambda:us-west-2:123456789012:function:{function_name}:2"
)
@mock_lambda
def test_add_permission_with_unknown_qualifier():
# assert that the resource within the statement ends with :qualifier
conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1()
function_name = str(uuid4())[0:6]
conn.create_function(
FunctionName=function_name,
Runtime="python3.7",
Role=get_role_name(),
Handler="lambda_function.handler",
Code={"ZipFile": zip_content},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
with pytest.raises(ClientError) as exc:
conn.add_permission(
FunctionName=function_name,
StatementId="2",
Action="lambda:InvokeFunction",
Principal="lambda.amazonaws.com",
SourceArn=f"arn:aws:lambda:us-west-2:{ACCOUNT_ID}:function:helloworld",
Qualifier="5",
)
err = exc.value.response["Error"]
err["Code"].should.equal("ResourceNotFoundException")
err["Message"].should.equal(
f"Function not found: arn:aws:lambda:us-west-2:{ACCOUNT_ID}:function:{function_name}:5"
)
@pytest.mark.parametrize("key", ["FunctionName", "FunctionArn"])
@mock_lambda
def test_remove_function_permission(key):
conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1()
function_name = str(uuid4())[0:6]
f = conn.create_function(
FunctionName=function_name,
Runtime="python2.7",
Role=(get_role_name()),
Handler="lambda_function.handler",
Code={"ZipFile": zip_content},
)
name_or_arn = f[key]
conn.add_permission(
FunctionName=name_or_arn,
StatementId="1",
Action="lambda:InvokeFunction",
Principal="432143214321",
SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld",
)
remove = conn.remove_permission(FunctionName=name_or_arn, StatementId="1")
remove["ResponseMetadata"]["HTTPStatusCode"].should.equal(204)
policy = conn.get_policy(FunctionName=name_or_arn)["Policy"]
policy = json.loads(policy)
policy["Statement"].should.equal([])
@pytest.mark.parametrize("key", ["FunctionName", "FunctionArn"])
@mock_lambda
def test_remove_function_permission__with_qualifier(key):
conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1()
function_name = str(uuid4())[0:6]
@ -105,6 +210,12 @@ def test_remove_function_permission(key):
)
name_or_arn = f[key]
# Ensure Qualifier=2 exists
zip_content_two = get_test_zip_file1()
conn.update_function_code(
FunctionName=function_name, ZipFile=zip_content_two, Publish=True
)
conn.add_permission(
FunctionName=name_or_arn,
StatementId="1",
@ -134,4 +245,6 @@ def test_get_unknown_policy():
conn.get_policy(FunctionName="unknown")
err = exc.value.response["Error"]
err["Code"].should.equal("ResourceNotFoundException")
err["Message"].should.equal("Function not found: unknown")
err["Message"].should.equal(
"Function not found: arn:aws:lambda:us-west-2:123456789012:function:unknown"
)

View File

@ -365,9 +365,18 @@ def test_delete_subscription_filter_errors():
)
@mock_lambda
@mock_logs
def test_put_subscription_filter_errors():
# given
client_lambda = boto3.client("lambda", "us-east-1")
function_arn = client_lambda.create_function(
FunctionName="test",
Runtime="python3.8",
Role=_get_role_name("us-east-1"),
Handler="lambda_function.lambda_handler",
Code={"ZipFile": _get_test_zip_file()},
)["FunctionArn"]
client = boto3.client("logs", "us-east-1")
log_group_name = "/test"
client.create_log_group(logGroupName=log_group_name)
@ -378,7 +387,7 @@ def test_put_subscription_filter_errors():
logGroupName="not-existing-log-group",
filterName="test",
filterPattern="",
destinationArn="arn:aws:lambda:us-east-1:123456789012:function:test",
destinationArn=function_arn,
)
# then