AWSLambda - Policy improvements (#4949)

This commit is contained in:
Bert Blommers 2022-03-19 12:00:39 -01:00 committed by GitHub
parent 411ce71d3a
commit de990b07f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 218 additions and 131 deletions

View File

@ -1095,9 +1095,18 @@ class LambdaStorage(object):
return self._arns.get(arn, None) return self._arns.get(arn, None)
def get_function_by_name_or_arn(self, name_or_arn, qualifier=None): def get_function_by_name_or_arn(self, name_or_arn, qualifier=None):
return self.get_function_by_name(name_or_arn, qualifier) or self.get_arn( fn = self.get_function_by_name(name_or_arn, qualifier) or self.get_arn(
name_or_arn name_or_arn
) )
if fn is None:
if name_or_arn.startswith("arn:aws"):
arn = name_or_arn
else:
arn = make_function_arn(self.region_name, ACCOUNT_ID, name_or_arn)
if qualifier:
arn = f"{arn}:{qualifier}"
raise UnknownFunctionException(arn)
return fn
def put_function(self, fn): def put_function(self, fn):
""" """
@ -1127,12 +1136,6 @@ class LambdaStorage(object):
def publish_function(self, name_or_arn, description=""): def publish_function(self, name_or_arn, description=""):
function = self.get_function_by_name_or_arn(name_or_arn) function = self.get_function_by_name_or_arn(name_or_arn)
if not function:
if name_or_arn.startswith("arn:aws"):
arn = name_or_arn
else:
arn = make_function_arn(self.region_name, ACCOUNT_ID, name_or_arn)
raise UnknownFunctionException(arn)
name = function.function_name name = function.function_name
if name not in self._functions: if name not in self._functions:
return None return None
@ -1150,8 +1153,7 @@ class LambdaStorage(object):
return fn return fn
def del_function(self, name_or_arn, qualifier=None): def del_function(self, name_or_arn, qualifier=None):
function = self.get_function_by_name_or_arn(name_or_arn) function = self.get_function_by_name_or_arn(name_or_arn, qualifier)
if function:
name = function.function_name name = function.function_name
if not qualifier: if not qualifier:
# Something is still reffing this so delete all arns # Something is still reffing this so delete all arns
@ -1163,8 +1165,6 @@ class LambdaStorage(object):
del self._functions[name] del self._functions[name]
return True
elif qualifier == "$LATEST": elif qualifier == "$LATEST":
self._functions[name]["latest"] = None self._functions[name]["latest"] = None
@ -1175,8 +1175,6 @@ class LambdaStorage(object):
): ):
del self._functions[name] del self._functions[name]
return True
else: else:
fn = self.get_function_by_name(name, qualifier) fn = self.get_function_by_name(name, qualifier)
if fn: if fn:
@ -1189,10 +1187,6 @@ class LambdaStorage(object):
): ):
del self._functions[name] del self._functions[name]
return True
return False
def all(self): def all(self):
result = [] result = []
@ -1488,7 +1482,7 @@ class LambdaBackend(BaseBackend):
return self._lambdas.get_arn(function_arn) return self._lambdas.get_arn(function_arn)
def delete_function(self, function_name, qualifier=None): def delete_function(self, function_name, qualifier=None):
return self._lambdas.del_function(function_name, qualifier) self._lambdas.del_function(function_name, qualifier)
def list_functions(self, func_version=None): def list_functions(self, func_version=None):
if func_version == "ALL": if func_version == "ALL":
@ -1601,31 +1595,20 @@ class LambdaBackend(BaseBackend):
return func.invoke(json.dumps(event), {}, {}) return func.invoke(json.dumps(event), {}, {})
def list_tags(self, resource): def list_tags(self, resource):
return self.get_function_by_arn(resource).tags return self._lambdas.get_function_by_name_or_arn(resource).tags
def tag_resource(self, resource, tags): def tag_resource(self, resource, tags):
fn = self.get_function_by_arn(resource) fn = self._lambdas.get_function_by_name_or_arn(resource)
if not fn:
return False
fn.tags.update(tags) fn.tags.update(tags)
return True
def untag_resource(self, resource, tagKeys): def untag_resource(self, resource, tagKeys):
fn = self.get_function_by_arn(resource) fn = self._lambdas.get_function_by_name_or_arn(resource)
if fn:
for key in tagKeys: for key in tagKeys:
try: fn.tags.pop(key, None)
del fn.tags[key]
except KeyError:
pass
# Don't care
return True
return False
def add_permission(self, function_name, raw): def add_permission(self, function_name, qualifier, raw):
fn = self.get_function(function_name) fn = self.get_function(function_name, qualifier)
fn.policy.add_statement(raw) fn.policy.add_statement(raw, qualifier)
def remove_permission(self, function_name, sid, revision=""): def remove_permission(self, function_name, sid, revision=""):
fn = self.get_function(function_name) fn = self.get_function(function_name)

View File

@ -29,7 +29,7 @@ class Policy:
} }
# adds the raw JSON statement to the policy # adds the raw JSON statement to the policy
def add_statement(self, raw): def add_statement(self, raw, qualifier=None):
policy = json.loads(raw, object_hook=self.decode_policy) policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision: if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException( raise PreconditionFailedException(
@ -40,6 +40,10 @@ class Policy:
# Remove #LATEST from the Resource (Lambda ARN) # Remove #LATEST from the Resource (Lambda ARN)
if policy.statements[0].get("Resource", "").endswith("$LATEST"): if policy.statements[0].get("Resource", "").endswith("$LATEST"):
policy.statements[0]["Resource"] = policy.statements[0]["Resource"][0:-8] policy.statements[0]["Resource"] = policy.statements[0]["Resource"][0:-8]
if qualifier:
policy.statements[0]["Resource"] = (
policy.statements[0]["Resource"] + ":" + qualifier
)
self.statements.append(policy.statements[0]) self.statements.append(policy.statements[0])
self.revision = str(uuid.uuid4()) self.revision = str(uuid.uuid4())

View File

@ -193,12 +193,10 @@ class LambdaResponse(BaseResponse):
def _add_policy(self, request): def _add_policy(self, request):
path = request.path if hasattr(request, "path") else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = unquote(path.split("/")[-2]) function_name = unquote(path.split("/")[-2])
if self.lambda_backend.get_function(function_name): qualifier = self.querystring.get("Qualifier", [None])[0]
statement = self.body statement = self.body
self.lambda_backend.add_permission(function_name, statement) self.lambda_backend.add_permission(function_name, qualifier, statement)
return 200, {}, json.dumps({"Statement": statement}) return 200, {}, json.dumps({"Statement": statement})
else:
return 404, {}, "{}"
def _get_policy(self, request): def _get_policy(self, request):
path = request.path if hasattr(request, "path") else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
@ -257,12 +255,9 @@ class LambdaResponse(BaseResponse):
function_name = unquote(self.path.rsplit("/", 3)[-3]) function_name = unquote(self.path.rsplit("/", 3)[-3])
fn = self.lambda_backend.get_function(function_name, None) fn = self.lambda_backend.get_function(function_name, None)
if fn:
payload = fn.invoke(self.body, self.headers, response_headers) payload = fn.invoke(self.body, self.headers, response_headers)
response_headers["Content-Length"] = str(len(payload)) response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload return 202, response_headers, payload
else:
return 404, response_headers, "{}"
def _list_functions(self): def _list_functions(self):
querystring = self.querystring querystring = self.querystring
@ -331,20 +326,15 @@ class LambdaResponse(BaseResponse):
description = self._get_param("Description") description = self._get_param("Description")
fn = self.lambda_backend.publish_function(function_name, description) fn = self.lambda_backend.publish_function(function_name, description)
if fn:
config = fn.get_configuration() config = fn.get_configuration()
return 201, {}, json.dumps(config) return 201, {}, json.dumps(config)
else:
return 404, {}, "{}"
def _delete_function(self): def _delete_function(self):
function_name = unquote(self.path.rsplit("/", 1)[-1]) function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param("Qualifier", None) qualifier = self._get_param("Qualifier", None)
if self.lambda_backend.delete_function(function_name, qualifier): self.lambda_backend.delete_function(function_name, qualifier)
return 204, {}, "" return 204, {}, ""
else:
return 404, {}, "{}"
@staticmethod @staticmethod
def _set_configuration_qualifier(configuration, qualifier): def _set_configuration_qualifier(configuration, qualifier):
@ -360,14 +350,11 @@ class LambdaResponse(BaseResponse):
fn = self.lambda_backend.get_function(function_name, qualifier) fn = self.lambda_backend.get_function(function_name, qualifier)
if fn:
code = fn.get_code() code = fn.get_code()
code["Configuration"] = self._set_configuration_qualifier( code["Configuration"] = self._set_configuration_qualifier(
code["Configuration"], qualifier code["Configuration"], qualifier
) )
return 200, {}, json.dumps(code) return 200, {}, json.dumps(code)
else:
return 404, {"x-amzn-ErrorType": "ResourceNotFoundException"}, "{}"
def _get_function_configuration(self): def _get_function_configuration(self):
function_name = unquote(self.path.rsplit("/", 2)[-2]) function_name = unquote(self.path.rsplit("/", 2)[-2])
@ -375,13 +362,10 @@ class LambdaResponse(BaseResponse):
fn = self.lambda_backend.get_function(function_name, qualifier) fn = self.lambda_backend.get_function(function_name, qualifier)
if fn:
configuration = self._set_configuration_qualifier( configuration = self._set_configuration_qualifier(
fn.get_configuration(), qualifier fn.get_configuration(), qualifier
) )
return 200, {}, json.dumps(configuration) return 200, {}, json.dumps(configuration)
else:
return 404, {"x-amzn-ErrorType": "ResourceNotFoundException"}, "{}"
def _get_aws_region(self, full_url): def _get_aws_region(self, full_url):
region = self.region_regex.search(full_url) region = self.region_regex.search(full_url)
@ -393,28 +377,21 @@ class LambdaResponse(BaseResponse):
def _list_tags(self): def _list_tags(self):
function_arn = unquote(self.path.rsplit("/", 1)[-1]) function_arn = unquote(self.path.rsplit("/", 1)[-1])
fn = self.lambda_backend.get_function_by_arn(function_arn) tags = self.lambda_backend.list_tags(function_arn)
if fn: return 200, {}, json.dumps({"Tags": tags})
return 200, {}, json.dumps({"Tags": fn.tags})
else:
return 404, {}, "{}"
def _tag_resource(self): def _tag_resource(self):
function_arn = unquote(self.path.rsplit("/", 1)[-1]) function_arn = unquote(self.path.rsplit("/", 1)[-1])
if self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"]): self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"])
return 200, {}, "{}" return 200, {}, "{}"
else:
return 404, {}, "{}"
def _untag_resource(self): def _untag_resource(self):
function_arn = unquote(self.path.rsplit("/", 1)[-1]) function_arn = unquote(self.path.rsplit("/", 1)[-1])
tag_keys = self.querystring["tagKeys"] tag_keys = self.querystring["tagKeys"]
if self.lambda_backend.untag_resource(function_arn, tag_keys): self.lambda_backend.untag_resource(function_arn, tag_keys)
return 204, {}, "{}" return 204, {}, "{}"
else:
return 404, {}, "{}"
def _put_configuration(self): def _put_configuration(self):
function_name = unquote(self.path.rsplit("/", 2)[-2]) function_name = unquote(self.path.rsplit("/", 2)[-2])

View File

@ -643,8 +643,9 @@ class Source(ConfigEmptyDictable):
# operations, only load it if needed. # operations, only load it if needed.
from moto.awslambda import lambda_backends from moto.awslambda import lambda_backends
lambda_func = lambda_backends[region].get_function(source_identifier) try:
if not lambda_func: lambda_backends[region].get_function(source_identifier)
except Exception:
raise InsufficientPermissionsException( raise InsufficientPermissionsException(
f"The AWS Lambda function {source_identifier} cannot be " f"The AWS Lambda function {source_identifier} cannot be "
f"invoked. Check the specified function ARN, and check the " f"invoked. Check the specified function ARN, and check the "

View File

@ -888,11 +888,10 @@ class LogsBackend(BaseBackend):
lambda_backends, lambda_backends,
) )
lambda_func = lambda_backends[self.region_name].get_function( try:
destination_arn lambda_backends[self.region_name].get_function(destination_arn)
)
# no specific permission check implemented # no specific permission check implemented
if not lambda_func: except Exception:
raise InvalidParameterException( raise InvalidParameterException(
"Could not execute the lambda function. Make sure you " "Could not execute the lambda function. Make sure you "
"have given CloudWatch Logs permission to execute your " "have given CloudWatch Logs permission to execute your "

View File

@ -537,8 +537,9 @@ class SecretsManagerBackend(BaseBackend):
request_headers = {} request_headers = {}
response_headers = {} response_headers = {}
try:
func = lambda_backend.get_function(secret.rotation_lambda_arn) func = lambda_backend.get_function(secret.rotation_lambda_arn)
if not func: except Exception:
msg = "Resource not found for ARN '{}'.".format( msg = "Resource not found for ARN '{}'.".format(
secret.rotation_lambda_arn secret.rotation_lambda_arn
) )

View File

@ -5,6 +5,7 @@ import pytest
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto import mock_lambda, mock_s3 from moto import mock_lambda, mock_s3
from moto.core import ACCOUNT_ID
from uuid import uuid4 from uuid import uuid4
from .utilities import get_role_name, get_test_zip_file1 from .utilities import get_role_name, get_test_zip_file1
@ -27,10 +28,6 @@ def test_add_function_permission(key):
Role=(get_role_name()), Role=(get_role_name()),
Handler="lambda_function.handler", Handler="lambda_function.handler",
Code={"ZipFile": zip_content}, Code={"ZipFile": zip_content},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
) )
name_or_arn = f[key] name_or_arn = f[key]
@ -40,9 +37,6 @@ def test_add_function_permission(key):
Action="lambda:InvokeFunction", Action="lambda:InvokeFunction",
Principal="432143214321", Principal="432143214321",
SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld", SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld",
SourceAccount="123412341234",
EventSourceToken="blah",
Qualifier="2",
) )
assert "Statement" in response assert "Statement" in response
res = json.loads(response["Statement"]) res = json.loads(response["Statement"])
@ -70,13 +64,10 @@ def test_get_function_policy(key):
conn.add_permission( conn.add_permission(
FunctionName=name_or_arn, FunctionName=name_or_arn,
StatementId="1", StatementId="2",
Action="lambda:InvokeFunction", Action="lambda:InvokeFunction",
Principal="432143214321", Principal="lambda.amazonaws.com",
SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld", SourceArn=f"arn:aws:lambda:us-west-2:{ACCOUNT_ID}:function:helloworld",
SourceAccount="123412341234",
EventSourceToken="blah",
Qualifier="2",
) )
response = conn.get_policy(FunctionName=name_or_arn) response = conn.get_policy(FunctionName=name_or_arn)
@ -84,11 +75,125 @@ def test_get_function_policy(key):
assert "Policy" in response assert "Policy" in response
res = json.loads(response["Policy"]) res = json.loads(response["Policy"])
assert res["Statement"][0]["Action"] == "lambda:InvokeFunction" assert res["Statement"][0]["Action"] == "lambda:InvokeFunction"
assert res["Statement"][0]["Principal"] == {"Service": "lambda.amazonaws.com"}
assert (
res["Statement"][0]["Resource"]
== f"arn:aws:lambda:us-west-2:123456789012:function:{function_name}"
)
@mock_lambda
def test_get_policy_with_qualifier():
# assert that the resource within the statement ends with :qualifier
conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1()
function_name = str(uuid4())[0:6]
conn.create_function(
FunctionName=function_name,
Runtime="python3.7",
Role=get_role_name(),
Handler="lambda_function.handler",
Code={"ZipFile": zip_content},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
zip_content_two = get_test_zip_file1()
conn.update_function_code(
FunctionName=function_name, ZipFile=zip_content_two, Publish=True
)
conn.add_permission(
FunctionName=function_name,
StatementId="1",
Action="lambda:InvokeFunction",
Principal="lambda.amazonaws.com",
SourceArn=f"arn:aws:lambda:us-west-2:{ACCOUNT_ID}:function:helloworld",
Qualifier="2",
)
response = conn.get_policy(FunctionName=function_name, Qualifier="2")
assert "Policy" in response
res = json.loads(response["Policy"])
assert res["Statement"][0]["Action"] == "lambda:InvokeFunction"
assert res["Statement"][0]["Principal"] == {"Service": "lambda.amazonaws.com"}
assert (
res["Statement"][0]["Resource"]
== f"arn:aws:lambda:us-west-2:123456789012:function:{function_name}:2"
)
@mock_lambda
def test_add_permission_with_unknown_qualifier():
# assert that the resource within the statement ends with :qualifier
conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1()
function_name = str(uuid4())[0:6]
conn.create_function(
FunctionName=function_name,
Runtime="python3.7",
Role=get_role_name(),
Handler="lambda_function.handler",
Code={"ZipFile": zip_content},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
with pytest.raises(ClientError) as exc:
conn.add_permission(
FunctionName=function_name,
StatementId="2",
Action="lambda:InvokeFunction",
Principal="lambda.amazonaws.com",
SourceArn=f"arn:aws:lambda:us-west-2:{ACCOUNT_ID}:function:helloworld",
Qualifier="5",
)
err = exc.value.response["Error"]
err["Code"].should.equal("ResourceNotFoundException")
err["Message"].should.equal(
f"Function not found: arn:aws:lambda:us-west-2:{ACCOUNT_ID}:function:{function_name}:5"
)
@pytest.mark.parametrize("key", ["FunctionName", "FunctionArn"]) @pytest.mark.parametrize("key", ["FunctionName", "FunctionArn"])
@mock_lambda @mock_lambda
def test_remove_function_permission(key): def test_remove_function_permission(key):
conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1()
function_name = str(uuid4())[0:6]
f = conn.create_function(
FunctionName=function_name,
Runtime="python2.7",
Role=(get_role_name()),
Handler="lambda_function.handler",
Code={"ZipFile": zip_content},
)
name_or_arn = f[key]
conn.add_permission(
FunctionName=name_or_arn,
StatementId="1",
Action="lambda:InvokeFunction",
Principal="432143214321",
SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld",
)
remove = conn.remove_permission(FunctionName=name_or_arn, StatementId="1")
remove["ResponseMetadata"]["HTTPStatusCode"].should.equal(204)
policy = conn.get_policy(FunctionName=name_or_arn)["Policy"]
policy = json.loads(policy)
policy["Statement"].should.equal([])
@pytest.mark.parametrize("key", ["FunctionName", "FunctionArn"])
@mock_lambda
def test_remove_function_permission__with_qualifier(key):
conn = boto3.client("lambda", _lambda_region) conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
function_name = str(uuid4())[0:6] function_name = str(uuid4())[0:6]
@ -105,6 +210,12 @@ def test_remove_function_permission(key):
) )
name_or_arn = f[key] name_or_arn = f[key]
# Ensure Qualifier=2 exists
zip_content_two = get_test_zip_file1()
conn.update_function_code(
FunctionName=function_name, ZipFile=zip_content_two, Publish=True
)
conn.add_permission( conn.add_permission(
FunctionName=name_or_arn, FunctionName=name_or_arn,
StatementId="1", StatementId="1",
@ -134,4 +245,6 @@ def test_get_unknown_policy():
conn.get_policy(FunctionName="unknown") conn.get_policy(FunctionName="unknown")
err = exc.value.response["Error"] err = exc.value.response["Error"]
err["Code"].should.equal("ResourceNotFoundException") err["Code"].should.equal("ResourceNotFoundException")
err["Message"].should.equal("Function not found: unknown") err["Message"].should.equal(
"Function not found: arn:aws:lambda:us-west-2:123456789012:function:unknown"
)

View File

@ -365,9 +365,18 @@ def test_delete_subscription_filter_errors():
) )
@mock_lambda
@mock_logs @mock_logs
def test_put_subscription_filter_errors(): def test_put_subscription_filter_errors():
# given # given
client_lambda = boto3.client("lambda", "us-east-1")
function_arn = client_lambda.create_function(
FunctionName="test",
Runtime="python3.8",
Role=_get_role_name("us-east-1"),
Handler="lambda_function.lambda_handler",
Code={"ZipFile": _get_test_zip_file()},
)["FunctionArn"]
client = boto3.client("logs", "us-east-1") client = boto3.client("logs", "us-east-1")
log_group_name = "/test" log_group_name = "/test"
client.create_log_group(logGroupName=log_group_name) client.create_log_group(logGroupName=log_group_name)
@ -378,7 +387,7 @@ def test_put_subscription_filter_errors():
logGroupName="not-existing-log-group", logGroupName="not-existing-log-group",
filterName="test", filterName="test",
filterPattern="", filterPattern="",
destinationArn="arn:aws:lambda:us-east-1:123456789012:function:test", destinationArn=function_arn,
) )
# then # then