diff --git a/moto/awslambda/exceptions.py b/moto/awslambda/exceptions.py index 1a82977c3..08d13dce5 100644 --- a/moto/awslambda/exceptions.py +++ b/moto/awslambda/exceptions.py @@ -1,4 +1,5 @@ from botocore.client import ClientError +from moto.core.exceptions import JsonRESTError class LambdaClientError(ClientError): @@ -29,3 +30,12 @@ class InvalidRoleFormat(LambdaClientError): role, InvalidRoleFormat.pattern ) super(InvalidRoleFormat, self).__init__("ValidationException", message) + + +class PreconditionFailedException(JsonRESTError): + code = 412 + + def __init__(self, message): + super(PreconditionFailedException, self).__init__( + "PreconditionFailedException", message + ) diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py index 95a5c4ad5..ec327534a 100644 --- a/moto/awslambda/models.py +++ b/moto/awslambda/models.py @@ -25,6 +25,7 @@ import requests.adapters from boto3 import Session +from moto.awslambda.policy import Policy from moto.core import BaseBackend, BaseModel from moto.core.exceptions import RESTError from moto.iam.models import iam_backend @@ -47,7 +48,6 @@ from moto.core import ACCOUNT_ID logger = logging.getLogger(__name__) - try: from tempfile import TemporaryDirectory except ImportError: @@ -164,7 +164,8 @@ class LambdaFunction(BaseModel): self.logs_backend = logs_backends[self.region] self.environment_vars = spec.get("Environment", {}).get("Variables", {}) self.docker_client = docker.from_env() - self.policy = "" + self.policy = None + self.state = "Active" # Unfortunately mocking replaces this method w/o fallback enabled, so we # need to replace it if we detect it's been mocked @@ -274,11 +275,11 @@ class LambdaFunction(BaseModel): "MemorySize": self.memory_size, "Role": self.role, "Runtime": self.run_time, + "State": self.state, "Timeout": self.timeout, "Version": str(self.version), "VpcConfig": self.vpc_config, } - if self.environment_vars: config["Environment"] = {"Variables": self.environment_vars} @@ -709,7 +710,8 @@ class LambdaStorage(object): "versions": [], "alias": weakref.WeakValueDictionary(), } - + # instantiate a new policy for this version of the lambda + fn.policy = Policy(fn) self._arns[fn.function_arn] = fn def publish_function(self, name): @@ -1010,8 +1012,21 @@ class LambdaBackend(BaseBackend): return True return False - def add_policy(self, function_name, policy): - self.get_function(function_name).policy = policy + def add_policy_statement(self, function_name, raw): + fn = self.get_function(function_name) + fn.policy.add_statement(raw) + + def del_policy_statement(self, function_name, sid, revision=""): + fn = self.get_function(function_name) + fn.policy.del_statement(sid, revision) + + def get_policy(self, function_name): + fn = self.get_function(function_name) + return fn.policy.get_policy() + + def get_policy_wire_format(self, function_name): + fn = self.get_function(function_name) + return fn.policy.wire_format() def update_function_code(self, function_name, qualifier, body): fn = self.get_function(function_name, qualifier) diff --git a/moto/awslambda/policy.py b/moto/awslambda/policy.py new file mode 100644 index 000000000..66ec728f2 --- /dev/null +++ b/moto/awslambda/policy.py @@ -0,0 +1,145 @@ +from __future__ import unicode_literals + +import json +import uuid + +from six import string_types + +from moto.awslambda.exceptions import PreconditionFailedException + + +class Policy: + def __init__(self, parent): + self.revision = str(uuid.uuid4()) + self.statements = [] + self.parent = parent + + def __repr__(self): + return json.dumps(self.get_policy()) + + def wire_format(self): + return json.dumps( + { + "Policy": json.dumps( + { + "Version": "2012-10-17", + "Id": "default", + "Statement": self.statements, + } + ), + "RevisionId": self.revision, + } + ) + + def get_policy(self): + return { + "Policy": { + "Version": "2012-10-17", + "Id": "default", + "Statement": self.statements, + }, + "RevisionId": self.revision, + } + + # adds the raw JSON statement to the policy + def add_statement(self, raw): + policy = json.loads(raw, object_hook=self.decode_policy) + if len(policy.revision) > 0 and self.revision != policy.revision: + raise PreconditionFailedException( + "The RevisionId provided does not match the latest RevisionId" + " for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve" + " the latest RevisionId for your resource." + ) + self.statements.append(policy.statements[0]) + self.revision = str(uuid.uuid4()) + + # removes the statement that matches 'sid' from the policy + def del_statement(self, sid, revision=""): + if len(revision) > 0 and self.revision != revision: + raise PreconditionFailedException( + "The RevisionId provided does not match the latest RevisionId" + " for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve" + " the latest RevisionId for your resource." + ) + for statement in self.statements: + if "Sid" in statement and statement["Sid"] == sid: + self.statements.remove(statement) + + # converts AddPermission request to PolicyStatement + # https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html + def decode_policy(self, obj): + # import pydevd + # pydevd.settrace("localhost", port=5678) + policy = Policy(self.parent) + policy.revision = obj.get("RevisionId", "") + + # set some default values if these keys are not set + self.ensure_set(obj, "Effect", "Allow") + self.ensure_set(obj, "Resource", self.parent.function_arn + ":$LATEST") + self.ensure_set(obj, "StatementId", str(uuid.uuid4())) + + # transform field names and values + self.transform_property(obj, "StatementId", "Sid", self.nop_formatter) + self.transform_property(obj, "Principal", "Principal", self.principal_formatter) + self.transform_property( + obj, "SourceArn", "SourceArn", self.source_arn_formatter + ) + self.transform_property( + obj, "SourceAccount", "SourceAccount", self.source_account_formatter + ) + + # remove RevisionId and EventSourceToken if they are set + self.remove_if_set(obj, ["RevisionId", "EventSourceToken"]) + + # merge conditional statements into a single map under the Condition key + self.condition_merge(obj) + + # append resulting statement to policy.statements + policy.statements.append(obj) + + return policy + + def nop_formatter(self, obj): + return obj + + def ensure_set(self, obj, key, value): + if key not in obj: + obj[key] = value + + def principal_formatter(self, obj): + if isinstance(obj, string_types): + if obj.endswith(".amazonaws.com"): + return {"Service": obj} + if obj.endswith(":root"): + return {"AWS": obj} + return obj + + def source_account_formatter(self, obj): + return {"StringEquals": {"AWS:SourceAccount": obj}} + + def source_arn_formatter(self, obj): + return {"ArnLike": {"AWS:SourceArn": obj}} + + def transform_property(self, obj, old_name, new_name, formatter): + if old_name in obj: + obj[new_name] = formatter(obj[old_name]) + if new_name != old_name: + del obj[old_name] + + def remove_if_set(self, obj, keys): + for key in keys: + if key in obj: + del obj[key] + + def condition_merge(self, obj): + if "SourceArn" in obj: + if "Condition" not in obj: + obj["Condition"] = {} + obj["Condition"].update(obj["SourceArn"]) + del obj["SourceArn"] + + if "SourceAccount" in obj: + if "Condition" not in obj: + obj["Condition"] = {} + obj["Condition"].update(obj["SourceAccount"]) + del obj["SourceAccount"] diff --git a/moto/awslambda/responses.py b/moto/awslambda/responses.py index 46203c10d..e1713ce52 100644 --- a/moto/awslambda/responses.py +++ b/moto/awslambda/responses.py @@ -120,8 +120,12 @@ class LambdaResponse(BaseResponse): self.setup_class(request, full_url, headers) if request.method == "GET": return self._get_policy(request, full_url, headers) - if request.method == "POST": + elif request.method == "POST": return self._add_policy(request, full_url, headers) + elif request.method == "DELETE": + return self._del_policy(request, full_url, headers, self.querystring) + else: + raise ValueError("Cannot handle {0} request".format(request.method)) def configuration(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -141,9 +145,9 @@ class LambdaResponse(BaseResponse): path = request.path if hasattr(request, "path") else path_url(request.url) function_name = path.split("/")[-2] if self.lambda_backend.get_function(function_name): - policy = self.body - self.lambda_backend.add_policy(function_name, policy) - return 200, {}, json.dumps(dict(Statement=policy)) + statement = self.body + self.lambda_backend.add_policy_statement(function_name, statement) + return 200, {}, json.dumps({"Statement": statement}) else: return 404, {}, "{}" @@ -151,14 +155,21 @@ class LambdaResponse(BaseResponse): path = request.path if hasattr(request, "path") else path_url(request.url) function_name = path.split("/")[-2] if self.lambda_backend.get_function(function_name): - lambda_function = self.lambda_backend.get_function(function_name) - return ( - 200, - {}, - json.dumps( - dict(Policy='{"Statement":[' + lambda_function.policy + "]}") - ), + out = self.lambda_backend.get_policy_wire_format(function_name) + return 200, {}, out + else: + return 404, {}, "{}" + + def _del_policy(self, request, full_url, headers, querystring): + path = request.path if hasattr(request, "path") else path_url(request.url) + function_name = path.split("/")[-3] + statement_id = path.split("/")[-1].split("?")[0] + revision = querystring.get("RevisionId", "") + if self.lambda_backend.get_function(function_name): + self.lambda_backend.del_policy_statement( + function_name, statement_id, revision ) + return 204, {}, "{}" else: return 404, {}, "{}" diff --git a/moto/awslambda/urls.py b/moto/awslambda/urls.py index da7346817..6c9b736a6 100644 --- a/moto/awslambda/urls.py +++ b/moto/awslambda/urls.py @@ -6,7 +6,7 @@ url_bases = ["https?://lambda.(.+).amazonaws.com"] response = LambdaResponse() url_paths = { - "{0}/(?P[^/]+)/functions/?$": response.root, + r"{0}/(?P[^/]+)/functions/?$": response.root, r"{0}/(?P[^/]+)/functions/(?P[\w_:%-]+)/?$": response.function, r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/versions/?$": response.versions, r"{0}/(?P[^/]+)/event-source-mappings/?$": response.event_source_mappings, @@ -14,6 +14,7 @@ url_paths = { r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invocations/?$": response.invoke, r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invoke-async/?$": response.invoke_async, r"{0}/(?P[^/]+)/tags/(?P.+)": response.tag, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/policy/(?P[\w_-]+)$": response.policy, r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/policy/?$": response.policy, r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/configuration/?$": response.configuration, r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/code/?$": response.code, diff --git a/tests/test_awslambda/test_lambda.py b/tests/test_awslambda/test_lambda.py index 6fd97e325..5f7eb0cfc 100644 --- a/tests/test_awslambda/test_lambda.py +++ b/tests/test_awslambda/test_lambda.py @@ -324,6 +324,7 @@ def test_create_function_from_aws_bucket(): "VpcId": "vpc-123abc", }, "ResponseMetadata": {"HTTPStatusCode": 201}, + "State": "Active", } ) @@ -367,6 +368,7 @@ def test_create_function_from_zipfile(): "Version": "1", "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "ResponseMetadata": {"HTTPStatusCode": 201}, + "State": "Active", } ) @@ -631,6 +633,7 @@ def test_list_create_list_get_delete_list(): "Timeout": 3, "Version": "$LATEST", "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + "State": "Active", }, "ResponseMetadata": {"HTTPStatusCode": 200}, } @@ -827,6 +830,7 @@ def test_get_function_created_with_zipfile(): "Timeout": 3, "Version": "$LATEST", "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + "State": "Active", } ) @@ -1436,6 +1440,7 @@ def test_update_function_zip(): "Timeout": 3, "Version": "2", "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + "State": "Active", } ) @@ -1498,6 +1503,7 @@ def test_update_function_s3(): "Timeout": 3, "Version": "2", "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + "State": "Active", } ) diff --git a/tests/test_awslambda/test_policy.py b/tests/test_awslambda/test_policy.py new file mode 100644 index 000000000..2571926eb --- /dev/null +++ b/tests/test_awslambda/test_policy.py @@ -0,0 +1,104 @@ +from __future__ import unicode_literals + +import unittest +import json + +from moto.awslambda.policy import Policy + + +class MockLambdaFunction: + def __init__(self, arn): + self.function_arn = arn + self.policy = None + + +class TC: + def __init__(self, lambda_arn, statement, expected): + self.statement = statement + self.expected = expected + self.fn = MockLambdaFunction(lambda_arn) + self.policy = Policy(self.fn) + + def Run(self, parent): + self.policy.add_statement(json.dumps(self.statement)) + parent.assertDictEqual(self.expected, self.policy.statements[0]) + + sid = self.statement.get("StatementId", None) + if sid == None: + raise "TestCase.statement does not contain StatementId" + + self.policy.del_statement(sid) + parent.assertEqual([], self.policy.statements) + + +class TestPolicy(unittest.TestCase): + def test(self): + tt = [ + TC( + # lambda_arn + "arn", + { # statement + "StatementId": "statement0", + "Action": "lambda:InvokeFunction", + "FunctionName": "function_name", + "Principal": "events.amazonaws.com", + }, + { # expected + "Action": "lambda:InvokeFunction", + "FunctionName": "function_name", + "Principal": {"Service": "events.amazonaws.com"}, + "Effect": "Allow", + "Resource": "arn:$LATEST", + "Sid": "statement0", + }, + ), + TC( + # lambda_arn + "arn", + { # statement + "StatementId": "statement1", + "Action": "lambda:InvokeFunction", + "FunctionName": "function_name", + "Principal": "events.amazonaws.com", + "SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name", + }, + { + "Action": "lambda:InvokeFunction", + "FunctionName": "function_name", + "Principal": {"Service": "events.amazonaws.com"}, + "Effect": "Allow", + "Resource": "arn:$LATEST", + "Sid": "statement1", + "Condition": { + "ArnLike": { + "AWS:SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name" + } + }, + }, + ), + TC( + # lambda_arn + "arn", + { # statement + "StatementId": "statement2", + "Action": "lambda:InvokeFunction", + "FunctionName": "function_name", + "Principal": "events.amazonaws.com", + "SourceAccount": "111111111111", + }, + { # expected + "Action": "lambda:InvokeFunction", + "FunctionName": "function_name", + "Principal": {"Service": "events.amazonaws.com"}, + "Effect": "Allow", + "Resource": "arn:$LATEST", + "Sid": "statement2", + "Condition": { + "StringEquals": {"AWS:SourceAccount": "111111111111"} + }, + }, + ), + ] + + for tc in tt: + tc.Run(self)