improves support for AWS lambda policy management

This commit is contained in:
Brady 2020-01-23 12:46:24 -06:00
parent d596560971
commit 2a2ff32dec
7 changed files with 310 additions and 18 deletions

View File

@ -1,4 +1,5 @@
from botocore.client import ClientError from botocore.client import ClientError
from moto.core.exceptions import JsonRESTError
class LambdaClientError(ClientError): class LambdaClientError(ClientError):
@ -29,3 +30,12 @@ class InvalidRoleFormat(LambdaClientError):
role, InvalidRoleFormat.pattern role, InvalidRoleFormat.pattern
) )
super(InvalidRoleFormat, self).__init__("ValidationException", message) super(InvalidRoleFormat, self).__init__("ValidationException", message)
class PreconditionFailedException(JsonRESTError):
code = 412
def __init__(self, message):
super(PreconditionFailedException, self).__init__(
"PreconditionFailedException", message
)

View File

@ -25,6 +25,7 @@ import requests.adapters
from boto3 import Session from boto3 import Session
from moto.awslambda.policy import Policy
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend from moto.iam.models import iam_backend
@ -47,7 +48,6 @@ from moto.core import ACCOUNT_ID
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
except ImportError: except ImportError:
@ -164,7 +164,8 @@ class LambdaFunction(BaseModel):
self.logs_backend = logs_backends[self.region] self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get("Environment", {}).get("Variables", {}) self.environment_vars = spec.get("Environment", {}).get("Variables", {})
self.docker_client = docker.from_env() self.docker_client = docker.from_env()
self.policy = "" self.policy = None
self.state = "Active"
# Unfortunately mocking replaces this method w/o fallback enabled, so we # Unfortunately mocking replaces this method w/o fallback enabled, so we
# need to replace it if we detect it's been mocked # need to replace it if we detect it's been mocked
@ -274,11 +275,11 @@ class LambdaFunction(BaseModel):
"MemorySize": self.memory_size, "MemorySize": self.memory_size,
"Role": self.role, "Role": self.role,
"Runtime": self.run_time, "Runtime": self.run_time,
"State": self.state,
"Timeout": self.timeout, "Timeout": self.timeout,
"Version": str(self.version), "Version": str(self.version),
"VpcConfig": self.vpc_config, "VpcConfig": self.vpc_config,
} }
if self.environment_vars: if self.environment_vars:
config["Environment"] = {"Variables": self.environment_vars} config["Environment"] = {"Variables": self.environment_vars}
@ -709,7 +710,8 @@ class LambdaStorage(object):
"versions": [], "versions": [],
"alias": weakref.WeakValueDictionary(), "alias": weakref.WeakValueDictionary(),
} }
# instantiate a new policy for this version of the lambda
fn.policy = Policy(fn)
self._arns[fn.function_arn] = fn self._arns[fn.function_arn] = fn
def publish_function(self, name): def publish_function(self, name):
@ -1010,8 +1012,21 @@ class LambdaBackend(BaseBackend):
return True return True
return False return False
def add_policy(self, function_name, policy): def add_policy_statement(self, function_name, raw):
self.get_function(function_name).policy = policy fn = self.get_function(function_name)
fn.policy.add_statement(raw)
def del_policy_statement(self, function_name, sid, revision=""):
fn = self.get_function(function_name)
fn.policy.del_statement(sid, revision)
def get_policy(self, function_name):
fn = self.get_function(function_name)
return fn.policy.get_policy()
def get_policy_wire_format(self, function_name):
fn = self.get_function(function_name)
return fn.policy.wire_format()
def update_function_code(self, function_name, qualifier, body): def update_function_code(self, function_name, qualifier, body):
fn = self.get_function(function_name, qualifier) fn = self.get_function(function_name, qualifier)

145
moto/awslambda/policy.py Normal file
View File

@ -0,0 +1,145 @@
from __future__ import unicode_literals
import json
import uuid
from six import string_types
from moto.awslambda.exceptions import PreconditionFailedException
class Policy:
def __init__(self, parent):
self.revision = str(uuid.uuid4())
self.statements = []
self.parent = parent
def __repr__(self):
return json.dumps(self.get_policy())
def wire_format(self):
return json.dumps(
{
"Policy": json.dumps(
{
"Version": "2012-10-17",
"Id": "default",
"Statement": self.statements,
}
),
"RevisionId": self.revision,
}
)
def get_policy(self):
return {
"Policy": {
"Version": "2012-10-17",
"Id": "default",
"Statement": self.statements,
},
"RevisionId": self.revision,
}
# adds the raw JSON statement to the policy
def add_statement(self, raw):
policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
self.statements.append(policy.statements[0])
self.revision = str(uuid.uuid4())
# removes the statement that matches 'sid' from the policy
def del_statement(self, sid, revision=""):
if len(revision) > 0 and self.revision != revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
for statement in self.statements:
if "Sid" in statement and statement["Sid"] == sid:
self.statements.remove(statement)
# converts AddPermission request to PolicyStatement
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
def decode_policy(self, obj):
# import pydevd
# pydevd.settrace("localhost", port=5678)
policy = Policy(self.parent)
policy.revision = obj.get("RevisionId", "")
# set some default values if these keys are not set
self.ensure_set(obj, "Effect", "Allow")
self.ensure_set(obj, "Resource", self.parent.function_arn + ":$LATEST")
self.ensure_set(obj, "StatementId", str(uuid.uuid4()))
# transform field names and values
self.transform_property(obj, "StatementId", "Sid", self.nop_formatter)
self.transform_property(obj, "Principal", "Principal", self.principal_formatter)
self.transform_property(
obj, "SourceArn", "SourceArn", self.source_arn_formatter
)
self.transform_property(
obj, "SourceAccount", "SourceAccount", self.source_account_formatter
)
# remove RevisionId and EventSourceToken if they are set
self.remove_if_set(obj, ["RevisionId", "EventSourceToken"])
# merge conditional statements into a single map under the Condition key
self.condition_merge(obj)
# append resulting statement to policy.statements
policy.statements.append(obj)
return policy
def nop_formatter(self, obj):
return obj
def ensure_set(self, obj, key, value):
if key not in obj:
obj[key] = value
def principal_formatter(self, obj):
if isinstance(obj, string_types):
if obj.endswith(".amazonaws.com"):
return {"Service": obj}
if obj.endswith(":root"):
return {"AWS": obj}
return obj
def source_account_formatter(self, obj):
return {"StringEquals": {"AWS:SourceAccount": obj}}
def source_arn_formatter(self, obj):
return {"ArnLike": {"AWS:SourceArn": obj}}
def transform_property(self, obj, old_name, new_name, formatter):
if old_name in obj:
obj[new_name] = formatter(obj[old_name])
if new_name != old_name:
del obj[old_name]
def remove_if_set(self, obj, keys):
for key in keys:
if key in obj:
del obj[key]
def condition_merge(self, obj):
if "SourceArn" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceArn"])
del obj["SourceArn"]
if "SourceAccount" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceAccount"])
del obj["SourceAccount"]

View File

@ -120,8 +120,12 @@ class LambdaResponse(BaseResponse):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self._get_policy(request, full_url, headers) return self._get_policy(request, full_url, headers)
if request.method == "POST": elif request.method == "POST":
return self._add_policy(request, full_url, headers) return self._add_policy(request, full_url, headers)
elif request.method == "DELETE":
return self._del_policy(request, full_url, headers, self.querystring)
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def configuration(self, request, full_url, headers): def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -141,9 +145,9 @@ class LambdaResponse(BaseResponse):
path = request.path if hasattr(request, "path") else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
policy = self.body statement = self.body
self.lambda_backend.add_policy(function_name, policy) self.lambda_backend.add_policy_statement(function_name, statement)
return 200, {}, json.dumps(dict(Statement=policy)) return 200, {}, json.dumps({"Statement": statement})
else: else:
return 404, {}, "{}" return 404, {}, "{}"
@ -151,14 +155,21 @@ class LambdaResponse(BaseResponse):
path = request.path if hasattr(request, "path") else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
lambda_function = self.lambda_backend.get_function(function_name) out = self.lambda_backend.get_policy_wire_format(function_name)
return ( return 200, {}, out
200, else:
{}, return 404, {}, "{}"
json.dumps(
dict(Policy='{"Statement":[' + lambda_function.policy + "]}") def _del_policy(self, request, full_url, headers, querystring):
), path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-3]
statement_id = path.split("/")[-1].split("?")[0]
revision = querystring.get("RevisionId", "")
if self.lambda_backend.get_function(function_name):
self.lambda_backend.del_policy_statement(
function_name, statement_id, revision
) )
return 204, {}, "{}"
else: else:
return 404, {}, "{}" return 404, {}, "{}"

View File

@ -6,7 +6,7 @@ url_bases = ["https?://lambda.(.+).amazonaws.com"]
response = LambdaResponse() response = LambdaResponse()
url_paths = { url_paths = {
"{0}/(?P<api_version>[^/]+)/functions/?$": response.root, r"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings, r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings,
@ -14,6 +14,7 @@ url_paths = {
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async,
r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag, r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/(?P<statement_id>[\w_-]+)$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,

View File

@ -324,6 +324,7 @@ def test_create_function_from_aws_bucket():
"VpcId": "vpc-123abc", "VpcId": "vpc-123abc",
}, },
"ResponseMetadata": {"HTTPStatusCode": 201}, "ResponseMetadata": {"HTTPStatusCode": 201},
"State": "Active",
} }
) )
@ -367,6 +368,7 @@ def test_create_function_from_zipfile():
"Version": "1", "Version": "1",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"ResponseMetadata": {"HTTPStatusCode": 201}, "ResponseMetadata": {"HTTPStatusCode": 201},
"State": "Active",
} }
) )
@ -631,6 +633,7 @@ def test_list_create_list_get_delete_list():
"Timeout": 3, "Timeout": 3,
"Version": "$LATEST", "Version": "$LATEST",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
}, },
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
} }
@ -827,6 +830,7 @@ def test_get_function_created_with_zipfile():
"Timeout": 3, "Timeout": 3,
"Version": "$LATEST", "Version": "$LATEST",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@ -1436,6 +1440,7 @@ def test_update_function_zip():
"Timeout": 3, "Timeout": 3,
"Version": "2", "Version": "2",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@ -1498,6 +1503,7 @@ def test_update_function_s3():
"Timeout": 3, "Timeout": 3,
"Version": "2", "Version": "2",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )

View File

@ -0,0 +1,104 @@
from __future__ import unicode_literals
import unittest
import json
from moto.awslambda.policy import Policy
class MockLambdaFunction:
def __init__(self, arn):
self.function_arn = arn
self.policy = None
class TC:
def __init__(self, lambda_arn, statement, expected):
self.statement = statement
self.expected = expected
self.fn = MockLambdaFunction(lambda_arn)
self.policy = Policy(self.fn)
def Run(self, parent):
self.policy.add_statement(json.dumps(self.statement))
parent.assertDictEqual(self.expected, self.policy.statements[0])
sid = self.statement.get("StatementId", None)
if sid == None:
raise "TestCase.statement does not contain StatementId"
self.policy.del_statement(sid)
parent.assertEqual([], self.policy.statements)
class TestPolicy(unittest.TestCase):
def test(self):
tt = [
TC(
# lambda_arn
"arn",
{ # statement
"StatementId": "statement0",
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": "events.amazonaws.com",
},
{ # expected
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": {"Service": "events.amazonaws.com"},
"Effect": "Allow",
"Resource": "arn:$LATEST",
"Sid": "statement0",
},
),
TC(
# lambda_arn
"arn",
{ # statement
"StatementId": "statement1",
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": "events.amazonaws.com",
"SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name",
},
{
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": {"Service": "events.amazonaws.com"},
"Effect": "Allow",
"Resource": "arn:$LATEST",
"Sid": "statement1",
"Condition": {
"ArnLike": {
"AWS:SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name"
}
},
},
),
TC(
# lambda_arn
"arn",
{ # statement
"StatementId": "statement2",
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": "events.amazonaws.com",
"SourceAccount": "111111111111",
},
{ # expected
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": {"Service": "events.amazonaws.com"},
"Effect": "Allow",
"Resource": "arn:$LATEST",
"Sid": "statement2",
"Condition": {
"StringEquals": {"AWS:SourceAccount": "111111111111"}
},
},
),
]
for tc in tt:
tc.Run(self)