improves support for AWS lambda policy management
This commit is contained in:
parent
d596560971
commit
2a2ff32dec
@ -1,4 +1,5 @@
|
||||
from botocore.client import ClientError
|
||||
from moto.core.exceptions import JsonRESTError
|
||||
|
||||
|
||||
class LambdaClientError(ClientError):
|
||||
@ -29,3 +30,12 @@ class InvalidRoleFormat(LambdaClientError):
|
||||
role, InvalidRoleFormat.pattern
|
||||
)
|
||||
super(InvalidRoleFormat, self).__init__("ValidationException", message)
|
||||
|
||||
|
||||
class PreconditionFailedException(JsonRESTError):
|
||||
code = 412
|
||||
|
||||
def __init__(self, message):
|
||||
super(PreconditionFailedException, self).__init__(
|
||||
"PreconditionFailedException", message
|
||||
)
|
||||
|
@ -25,6 +25,7 @@ import requests.adapters
|
||||
|
||||
from boto3 import Session
|
||||
|
||||
from moto.awslambda.policy import Policy
|
||||
from moto.core import BaseBackend, BaseModel
|
||||
from moto.core.exceptions import RESTError
|
||||
from moto.iam.models import iam_backend
|
||||
@ -47,7 +48,6 @@ from moto.core import ACCOUNT_ID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
from tempfile import TemporaryDirectory
|
||||
except ImportError:
|
||||
@ -164,7 +164,8 @@ class LambdaFunction(BaseModel):
|
||||
self.logs_backend = logs_backends[self.region]
|
||||
self.environment_vars = spec.get("Environment", {}).get("Variables", {})
|
||||
self.docker_client = docker.from_env()
|
||||
self.policy = ""
|
||||
self.policy = None
|
||||
self.state = "Active"
|
||||
|
||||
# Unfortunately mocking replaces this method w/o fallback enabled, so we
|
||||
# need to replace it if we detect it's been mocked
|
||||
@ -274,11 +275,11 @@ class LambdaFunction(BaseModel):
|
||||
"MemorySize": self.memory_size,
|
||||
"Role": self.role,
|
||||
"Runtime": self.run_time,
|
||||
"State": self.state,
|
||||
"Timeout": self.timeout,
|
||||
"Version": str(self.version),
|
||||
"VpcConfig": self.vpc_config,
|
||||
}
|
||||
|
||||
if self.environment_vars:
|
||||
config["Environment"] = {"Variables": self.environment_vars}
|
||||
|
||||
@ -709,7 +710,8 @@ class LambdaStorage(object):
|
||||
"versions": [],
|
||||
"alias": weakref.WeakValueDictionary(),
|
||||
}
|
||||
|
||||
# instantiate a new policy for this version of the lambda
|
||||
fn.policy = Policy(fn)
|
||||
self._arns[fn.function_arn] = fn
|
||||
|
||||
def publish_function(self, name):
|
||||
@ -1010,8 +1012,21 @@ class LambdaBackend(BaseBackend):
|
||||
return True
|
||||
return False
|
||||
|
||||
def add_policy(self, function_name, policy):
|
||||
self.get_function(function_name).policy = policy
|
||||
def add_policy_statement(self, function_name, raw):
|
||||
fn = self.get_function(function_name)
|
||||
fn.policy.add_statement(raw)
|
||||
|
||||
def del_policy_statement(self, function_name, sid, revision=""):
|
||||
fn = self.get_function(function_name)
|
||||
fn.policy.del_statement(sid, revision)
|
||||
|
||||
def get_policy(self, function_name):
|
||||
fn = self.get_function(function_name)
|
||||
return fn.policy.get_policy()
|
||||
|
||||
def get_policy_wire_format(self, function_name):
|
||||
fn = self.get_function(function_name)
|
||||
return fn.policy.wire_format()
|
||||
|
||||
def update_function_code(self, function_name, qualifier, body):
|
||||
fn = self.get_function(function_name, qualifier)
|
||||
|
145
moto/awslambda/policy.py
Normal file
145
moto/awslambda/policy.py
Normal file
@ -0,0 +1,145 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from six import string_types
|
||||
|
||||
from moto.awslambda.exceptions import PreconditionFailedException
|
||||
|
||||
|
||||
class Policy:
|
||||
def __init__(self, parent):
|
||||
self.revision = str(uuid.uuid4())
|
||||
self.statements = []
|
||||
self.parent = parent
|
||||
|
||||
def __repr__(self):
|
||||
return json.dumps(self.get_policy())
|
||||
|
||||
def wire_format(self):
|
||||
return json.dumps(
|
||||
{
|
||||
"Policy": json.dumps(
|
||||
{
|
||||
"Version": "2012-10-17",
|
||||
"Id": "default",
|
||||
"Statement": self.statements,
|
||||
}
|
||||
),
|
||||
"RevisionId": self.revision,
|
||||
}
|
||||
)
|
||||
|
||||
def get_policy(self):
|
||||
return {
|
||||
"Policy": {
|
||||
"Version": "2012-10-17",
|
||||
"Id": "default",
|
||||
"Statement": self.statements,
|
||||
},
|
||||
"RevisionId": self.revision,
|
||||
}
|
||||
|
||||
# adds the raw JSON statement to the policy
|
||||
def add_statement(self, raw):
|
||||
policy = json.loads(raw, object_hook=self.decode_policy)
|
||||
if len(policy.revision) > 0 and self.revision != policy.revision:
|
||||
raise PreconditionFailedException(
|
||||
"The RevisionId provided does not match the latest RevisionId"
|
||||
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
|
||||
" the latest RevisionId for your resource."
|
||||
)
|
||||
self.statements.append(policy.statements[0])
|
||||
self.revision = str(uuid.uuid4())
|
||||
|
||||
# removes the statement that matches 'sid' from the policy
|
||||
def del_statement(self, sid, revision=""):
|
||||
if len(revision) > 0 and self.revision != revision:
|
||||
raise PreconditionFailedException(
|
||||
"The RevisionId provided does not match the latest RevisionId"
|
||||
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
|
||||
" the latest RevisionId for your resource."
|
||||
)
|
||||
for statement in self.statements:
|
||||
if "Sid" in statement and statement["Sid"] == sid:
|
||||
self.statements.remove(statement)
|
||||
|
||||
# converts AddPermission request to PolicyStatement
|
||||
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
|
||||
def decode_policy(self, obj):
|
||||
# import pydevd
|
||||
# pydevd.settrace("localhost", port=5678)
|
||||
policy = Policy(self.parent)
|
||||
policy.revision = obj.get("RevisionId", "")
|
||||
|
||||
# set some default values if these keys are not set
|
||||
self.ensure_set(obj, "Effect", "Allow")
|
||||
self.ensure_set(obj, "Resource", self.parent.function_arn + ":$LATEST")
|
||||
self.ensure_set(obj, "StatementId", str(uuid.uuid4()))
|
||||
|
||||
# transform field names and values
|
||||
self.transform_property(obj, "StatementId", "Sid", self.nop_formatter)
|
||||
self.transform_property(obj, "Principal", "Principal", self.principal_formatter)
|
||||
self.transform_property(
|
||||
obj, "SourceArn", "SourceArn", self.source_arn_formatter
|
||||
)
|
||||
self.transform_property(
|
||||
obj, "SourceAccount", "SourceAccount", self.source_account_formatter
|
||||
)
|
||||
|
||||
# remove RevisionId and EventSourceToken if they are set
|
||||
self.remove_if_set(obj, ["RevisionId", "EventSourceToken"])
|
||||
|
||||
# merge conditional statements into a single map under the Condition key
|
||||
self.condition_merge(obj)
|
||||
|
||||
# append resulting statement to policy.statements
|
||||
policy.statements.append(obj)
|
||||
|
||||
return policy
|
||||
|
||||
def nop_formatter(self, obj):
|
||||
return obj
|
||||
|
||||
def ensure_set(self, obj, key, value):
|
||||
if key not in obj:
|
||||
obj[key] = value
|
||||
|
||||
def principal_formatter(self, obj):
|
||||
if isinstance(obj, string_types):
|
||||
if obj.endswith(".amazonaws.com"):
|
||||
return {"Service": obj}
|
||||
if obj.endswith(":root"):
|
||||
return {"AWS": obj}
|
||||
return obj
|
||||
|
||||
def source_account_formatter(self, obj):
|
||||
return {"StringEquals": {"AWS:SourceAccount": obj}}
|
||||
|
||||
def source_arn_formatter(self, obj):
|
||||
return {"ArnLike": {"AWS:SourceArn": obj}}
|
||||
|
||||
def transform_property(self, obj, old_name, new_name, formatter):
|
||||
if old_name in obj:
|
||||
obj[new_name] = formatter(obj[old_name])
|
||||
if new_name != old_name:
|
||||
del obj[old_name]
|
||||
|
||||
def remove_if_set(self, obj, keys):
|
||||
for key in keys:
|
||||
if key in obj:
|
||||
del obj[key]
|
||||
|
||||
def condition_merge(self, obj):
|
||||
if "SourceArn" in obj:
|
||||
if "Condition" not in obj:
|
||||
obj["Condition"] = {}
|
||||
obj["Condition"].update(obj["SourceArn"])
|
||||
del obj["SourceArn"]
|
||||
|
||||
if "SourceAccount" in obj:
|
||||
if "Condition" not in obj:
|
||||
obj["Condition"] = {}
|
||||
obj["Condition"].update(obj["SourceAccount"])
|
||||
del obj["SourceAccount"]
|
@ -120,8 +120,12 @@ class LambdaResponse(BaseResponse):
|
||||
self.setup_class(request, full_url, headers)
|
||||
if request.method == "GET":
|
||||
return self._get_policy(request, full_url, headers)
|
||||
if request.method == "POST":
|
||||
elif request.method == "POST":
|
||||
return self._add_policy(request, full_url, headers)
|
||||
elif request.method == "DELETE":
|
||||
return self._del_policy(request, full_url, headers, self.querystring)
|
||||
else:
|
||||
raise ValueError("Cannot handle {0} request".format(request.method))
|
||||
|
||||
def configuration(self, request, full_url, headers):
|
||||
self.setup_class(request, full_url, headers)
|
||||
@ -141,9 +145,9 @@ class LambdaResponse(BaseResponse):
|
||||
path = request.path if hasattr(request, "path") else path_url(request.url)
|
||||
function_name = path.split("/")[-2]
|
||||
if self.lambda_backend.get_function(function_name):
|
||||
policy = self.body
|
||||
self.lambda_backend.add_policy(function_name, policy)
|
||||
return 200, {}, json.dumps(dict(Statement=policy))
|
||||
statement = self.body
|
||||
self.lambda_backend.add_policy_statement(function_name, statement)
|
||||
return 200, {}, json.dumps({"Statement": statement})
|
||||
else:
|
||||
return 404, {}, "{}"
|
||||
|
||||
@ -151,14 +155,21 @@ class LambdaResponse(BaseResponse):
|
||||
path = request.path if hasattr(request, "path") else path_url(request.url)
|
||||
function_name = path.split("/")[-2]
|
||||
if self.lambda_backend.get_function(function_name):
|
||||
lambda_function = self.lambda_backend.get_function(function_name)
|
||||
return (
|
||||
200,
|
||||
{},
|
||||
json.dumps(
|
||||
dict(Policy='{"Statement":[' + lambda_function.policy + "]}")
|
||||
),
|
||||
out = self.lambda_backend.get_policy_wire_format(function_name)
|
||||
return 200, {}, out
|
||||
else:
|
||||
return 404, {}, "{}"
|
||||
|
||||
def _del_policy(self, request, full_url, headers, querystring):
|
||||
path = request.path if hasattr(request, "path") else path_url(request.url)
|
||||
function_name = path.split("/")[-3]
|
||||
statement_id = path.split("/")[-1].split("?")[0]
|
||||
revision = querystring.get("RevisionId", "")
|
||||
if self.lambda_backend.get_function(function_name):
|
||||
self.lambda_backend.del_policy_statement(
|
||||
function_name, statement_id, revision
|
||||
)
|
||||
return 204, {}, "{}"
|
||||
else:
|
||||
return 404, {}, "{}"
|
||||
|
||||
|
@ -6,7 +6,7 @@ url_bases = ["https?://lambda.(.+).amazonaws.com"]
|
||||
response = LambdaResponse()
|
||||
|
||||
url_paths = {
|
||||
"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
|
||||
r"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
|
||||
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function,
|
||||
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions,
|
||||
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings,
|
||||
@ -14,6 +14,7 @@ url_paths = {
|
||||
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke,
|
||||
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async,
|
||||
r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag,
|
||||
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/(?P<statement_id>[\w_-]+)$": response.policy,
|
||||
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
|
||||
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
|
||||
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,
|
||||
|
@ -324,6 +324,7 @@ def test_create_function_from_aws_bucket():
|
||||
"VpcId": "vpc-123abc",
|
||||
},
|
||||
"ResponseMetadata": {"HTTPStatusCode": 201},
|
||||
"State": "Active",
|
||||
}
|
||||
)
|
||||
|
||||
@ -367,6 +368,7 @@ def test_create_function_from_zipfile():
|
||||
"Version": "1",
|
||||
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
|
||||
"ResponseMetadata": {"HTTPStatusCode": 201},
|
||||
"State": "Active",
|
||||
}
|
||||
)
|
||||
|
||||
@ -631,6 +633,7 @@ def test_list_create_list_get_delete_list():
|
||||
"Timeout": 3,
|
||||
"Version": "$LATEST",
|
||||
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
|
||||
"State": "Active",
|
||||
},
|
||||
"ResponseMetadata": {"HTTPStatusCode": 200},
|
||||
}
|
||||
@ -827,6 +830,7 @@ def test_get_function_created_with_zipfile():
|
||||
"Timeout": 3,
|
||||
"Version": "$LATEST",
|
||||
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
|
||||
"State": "Active",
|
||||
}
|
||||
)
|
||||
|
||||
@ -1436,6 +1440,7 @@ def test_update_function_zip():
|
||||
"Timeout": 3,
|
||||
"Version": "2",
|
||||
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
|
||||
"State": "Active",
|
||||
}
|
||||
)
|
||||
|
||||
@ -1498,6 +1503,7 @@ def test_update_function_s3():
|
||||
"Timeout": 3,
|
||||
"Version": "2",
|
||||
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
|
||||
"State": "Active",
|
||||
}
|
||||
)
|
||||
|
||||
|
104
tests/test_awslambda/test_policy.py
Normal file
104
tests/test_awslambda/test_policy.py
Normal file
@ -0,0 +1,104 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
import json
|
||||
|
||||
from moto.awslambda.policy import Policy
|
||||
|
||||
|
||||
class MockLambdaFunction:
|
||||
def __init__(self, arn):
|
||||
self.function_arn = arn
|
||||
self.policy = None
|
||||
|
||||
|
||||
class TC:
|
||||
def __init__(self, lambda_arn, statement, expected):
|
||||
self.statement = statement
|
||||
self.expected = expected
|
||||
self.fn = MockLambdaFunction(lambda_arn)
|
||||
self.policy = Policy(self.fn)
|
||||
|
||||
def Run(self, parent):
|
||||
self.policy.add_statement(json.dumps(self.statement))
|
||||
parent.assertDictEqual(self.expected, self.policy.statements[0])
|
||||
|
||||
sid = self.statement.get("StatementId", None)
|
||||
if sid == None:
|
||||
raise "TestCase.statement does not contain StatementId"
|
||||
|
||||
self.policy.del_statement(sid)
|
||||
parent.assertEqual([], self.policy.statements)
|
||||
|
||||
|
||||
class TestPolicy(unittest.TestCase):
|
||||
def test(self):
|
||||
tt = [
|
||||
TC(
|
||||
# lambda_arn
|
||||
"arn",
|
||||
{ # statement
|
||||
"StatementId": "statement0",
|
||||
"Action": "lambda:InvokeFunction",
|
||||
"FunctionName": "function_name",
|
||||
"Principal": "events.amazonaws.com",
|
||||
},
|
||||
{ # expected
|
||||
"Action": "lambda:InvokeFunction",
|
||||
"FunctionName": "function_name",
|
||||
"Principal": {"Service": "events.amazonaws.com"},
|
||||
"Effect": "Allow",
|
||||
"Resource": "arn:$LATEST",
|
||||
"Sid": "statement0",
|
||||
},
|
||||
),
|
||||
TC(
|
||||
# lambda_arn
|
||||
"arn",
|
||||
{ # statement
|
||||
"StatementId": "statement1",
|
||||
"Action": "lambda:InvokeFunction",
|
||||
"FunctionName": "function_name",
|
||||
"Principal": "events.amazonaws.com",
|
||||
"SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name",
|
||||
},
|
||||
{
|
||||
"Action": "lambda:InvokeFunction",
|
||||
"FunctionName": "function_name",
|
||||
"Principal": {"Service": "events.amazonaws.com"},
|
||||
"Effect": "Allow",
|
||||
"Resource": "arn:$LATEST",
|
||||
"Sid": "statement1",
|
||||
"Condition": {
|
||||
"ArnLike": {
|
||||
"AWS:SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name"
|
||||
}
|
||||
},
|
||||
},
|
||||
),
|
||||
TC(
|
||||
# lambda_arn
|
||||
"arn",
|
||||
{ # statement
|
||||
"StatementId": "statement2",
|
||||
"Action": "lambda:InvokeFunction",
|
||||
"FunctionName": "function_name",
|
||||
"Principal": "events.amazonaws.com",
|
||||
"SourceAccount": "111111111111",
|
||||
},
|
||||
{ # expected
|
||||
"Action": "lambda:InvokeFunction",
|
||||
"FunctionName": "function_name",
|
||||
"Principal": {"Service": "events.amazonaws.com"},
|
||||
"Effect": "Allow",
|
||||
"Resource": "arn:$LATEST",
|
||||
"Sid": "statement2",
|
||||
"Condition": {
|
||||
"StringEquals": {"AWS:SourceAccount": "111111111111"}
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
for tc in tt:
|
||||
tc.Run(self)
|
Loading…
Reference in New Issue
Block a user