Merge pull request #27 from spulec/master

Merge upstream
This commit is contained in:
Bert Blommers 2020-02-01 09:53:15 +00:00 committed by GitHub
commit b971aee9d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 627 additions and 48 deletions

View File

@ -39,7 +39,7 @@ class InvalidResourcePathException(BadRequestException):
def __init__(self): def __init__(self):
super(InvalidResourcePathException, self).__init__( super(InvalidResourcePathException, self).__init__(
"BadRequestException", "BadRequestException",
"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end.", "Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace.",
) )

View File

@ -556,7 +556,7 @@ class APIGatewayBackend(BaseBackend):
return resource return resource
def create_resource(self, function_id, parent_resource_id, path_part): def create_resource(self, function_id, parent_resource_id, path_part):
if not re.match("^\\{?[a-zA-Z0-9._-]+\\}?$", path_part): if not re.match("^\\{?[a-zA-Z0-9._-]+\\+?\\}?$", path_part):
raise InvalidResourcePathException() raise InvalidResourcePathException()
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
child = api.add_child(path=path_part, parent_id=parent_resource_id) child = api.add_child(path=path_part, parent_id=parent_resource_id)

View File

@ -1,4 +1,5 @@
from botocore.client import ClientError from botocore.client import ClientError
from moto.core.exceptions import JsonRESTError
class LambdaClientError(ClientError): class LambdaClientError(ClientError):
@ -29,3 +30,12 @@ class InvalidRoleFormat(LambdaClientError):
role, InvalidRoleFormat.pattern role, InvalidRoleFormat.pattern
) )
super(InvalidRoleFormat, self).__init__("ValidationException", message) super(InvalidRoleFormat, self).__init__("ValidationException", message)
class PreconditionFailedException(JsonRESTError):
code = 412
def __init__(self, message):
super(PreconditionFailedException, self).__init__(
"PreconditionFailedException", message
)

View File

@ -25,6 +25,7 @@ import requests.adapters
from boto3 import Session from boto3 import Session
from moto.awslambda.policy import Policy
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend from moto.iam.models import iam_backend
@ -47,7 +48,6 @@ from moto.core import ACCOUNT_ID
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
except ImportError: except ImportError:
@ -161,7 +161,8 @@ class LambdaFunction(BaseModel):
self.logs_backend = logs_backends[self.region] self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get("Environment", {}).get("Variables", {}) self.environment_vars = spec.get("Environment", {}).get("Variables", {})
self.docker_client = docker.from_env() self.docker_client = docker.from_env()
self.policy = "" self.policy = None
self.state = "Active"
# Unfortunately mocking replaces this method w/o fallback enabled, so we # Unfortunately mocking replaces this method w/o fallback enabled, so we
# need to replace it if we detect it's been mocked # need to replace it if we detect it's been mocked
@ -271,11 +272,11 @@ class LambdaFunction(BaseModel):
"MemorySize": self.memory_size, "MemorySize": self.memory_size,
"Role": self.role, "Role": self.role,
"Runtime": self.run_time, "Runtime": self.run_time,
"State": self.state,
"Timeout": self.timeout, "Timeout": self.timeout,
"Version": str(self.version), "Version": str(self.version),
"VpcConfig": self.vpc_config, "VpcConfig": self.vpc_config,
} }
if self.environment_vars: if self.environment_vars:
config["Environment"] = {"Variables": self.environment_vars} config["Environment"] = {"Variables": self.environment_vars}
@ -394,6 +395,7 @@ class LambdaFunction(BaseModel):
env_vars.update(self.environment_vars) env_vars.update(self.environment_vars)
container = output = exit_code = None container = output = exit_code = None
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
with _DockerDataVolumeContext(self) as data_vol: with _DockerDataVolumeContext(self) as data_vol:
try: try:
run_kwargs = ( run_kwargs = (
@ -409,6 +411,7 @@ class LambdaFunction(BaseModel):
volumes=["{}:/var/task".format(data_vol.name)], volumes=["{}:/var/task".format(data_vol.name)],
environment=env_vars, environment=env_vars,
detach=True, detach=True,
log_config=log_config,
**run_kwargs **run_kwargs
) )
finally: finally:
@ -472,7 +475,7 @@ class LambdaFunction(BaseModel):
payload["result"] = response_headers["x-amz-log-result"] payload["result"] = response_headers["x-amz-log-result"]
result = res.encode("utf-8") result = res.encode("utf-8")
else: else:
result = json.dumps(payload) result = res
if errored: if errored:
response_headers["x-amz-function-error"] = "Handled" response_headers["x-amz-function-error"] = "Handled"
@ -701,7 +704,8 @@ class LambdaStorage(object):
"versions": [], "versions": [],
"alias": weakref.WeakValueDictionary(), "alias": weakref.WeakValueDictionary(),
} }
# instantiate a new policy for this version of the lambda
fn.policy = Policy(fn)
self._arns[fn.function_arn] = fn self._arns[fn.function_arn] = fn
def publish_function(self, name): def publish_function(self, name):
@ -1002,8 +1006,21 @@ class LambdaBackend(BaseBackend):
return True return True
return False return False
def add_policy(self, function_name, policy): def add_policy_statement(self, function_name, raw):
self.get_function(function_name).policy = policy fn = self.get_function(function_name)
fn.policy.add_statement(raw)
def del_policy_statement(self, function_name, sid, revision=""):
fn = self.get_function(function_name)
fn.policy.del_statement(sid, revision)
def get_policy(self, function_name):
fn = self.get_function(function_name)
return fn.policy.get_policy()
def get_policy_wire_format(self, function_name):
fn = self.get_function(function_name)
return fn.policy.wire_format()
def update_function_code(self, function_name, qualifier, body): def update_function_code(self, function_name, qualifier, body):
fn = self.get_function(function_name, qualifier) fn = self.get_function(function_name, qualifier)

134
moto/awslambda/policy.py Normal file
View File

@ -0,0 +1,134 @@
from __future__ import unicode_literals
import json
import uuid
from six import string_types
from moto.awslambda.exceptions import PreconditionFailedException
class Policy:
def __init__(self, parent):
self.revision = str(uuid.uuid4())
self.statements = []
self.parent = parent
def wire_format(self):
p = self.get_policy()
p["Policy"] = json.dumps(p["Policy"])
return json.dumps(p)
def get_policy(self):
return {
"Policy": {
"Version": "2012-10-17",
"Id": "default",
"Statement": self.statements,
},
"RevisionId": self.revision,
}
# adds the raw JSON statement to the policy
def add_statement(self, raw):
policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
self.statements.append(policy.statements[0])
self.revision = str(uuid.uuid4())
# removes the statement that matches 'sid' from the policy
def del_statement(self, sid, revision=""):
if len(revision) > 0 and self.revision != revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
for statement in self.statements:
if "Sid" in statement and statement["Sid"] == sid:
self.statements.remove(statement)
# converts AddPermission request to PolicyStatement
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
def decode_policy(self, obj):
# import pydevd
# pydevd.settrace("localhost", port=5678)
policy = Policy(self.parent)
policy.revision = obj.get("RevisionId", "")
# set some default values if these keys are not set
self.ensure_set(obj, "Effect", "Allow")
self.ensure_set(obj, "Resource", self.parent.function_arn + ":$LATEST")
self.ensure_set(obj, "StatementId", str(uuid.uuid4()))
# transform field names and values
self.transform_property(obj, "StatementId", "Sid", self.nop_formatter)
self.transform_property(obj, "Principal", "Principal", self.principal_formatter)
self.transform_property(
obj, "SourceArn", "SourceArn", self.source_arn_formatter
)
self.transform_property(
obj, "SourceAccount", "SourceAccount", self.source_account_formatter
)
# remove RevisionId and EventSourceToken if they are set
self.remove_if_set(obj, ["RevisionId", "EventSourceToken"])
# merge conditional statements into a single map under the Condition key
self.condition_merge(obj)
# append resulting statement to policy.statements
policy.statements.append(obj)
return policy
def nop_formatter(self, obj):
return obj
def ensure_set(self, obj, key, value):
if key not in obj:
obj[key] = value
def principal_formatter(self, obj):
if isinstance(obj, string_types):
if obj.endswith(".amazonaws.com"):
return {"Service": obj}
if obj.endswith(":root"):
return {"AWS": obj}
return obj
def source_account_formatter(self, obj):
return {"StringEquals": {"AWS:SourceAccount": obj}}
def source_arn_formatter(self, obj):
return {"ArnLike": {"AWS:SourceArn": obj}}
def transform_property(self, obj, old_name, new_name, formatter):
if old_name in obj:
obj[new_name] = formatter(obj[old_name])
if new_name != old_name:
del obj[old_name]
def remove_if_set(self, obj, keys):
for key in keys:
if key in obj:
del obj[key]
def condition_merge(self, obj):
if "SourceArn" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceArn"])
del obj["SourceArn"]
if "SourceAccount" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceAccount"])
del obj["SourceAccount"]

View File

@ -120,8 +120,12 @@ class LambdaResponse(BaseResponse):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self._get_policy(request, full_url, headers) return self._get_policy(request, full_url, headers)
if request.method == "POST": elif request.method == "POST":
return self._add_policy(request, full_url, headers) return self._add_policy(request, full_url, headers)
elif request.method == "DELETE":
return self._del_policy(request, full_url, headers, self.querystring)
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def configuration(self, request, full_url, headers): def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -141,9 +145,9 @@ class LambdaResponse(BaseResponse):
path = request.path if hasattr(request, "path") else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
policy = self.body statement = self.body
self.lambda_backend.add_policy(function_name, policy) self.lambda_backend.add_policy_statement(function_name, statement)
return 200, {}, json.dumps(dict(Statement=policy)) return 200, {}, json.dumps({"Statement": statement})
else: else:
return 404, {}, "{}" return 404, {}, "{}"
@ -151,14 +155,21 @@ class LambdaResponse(BaseResponse):
path = request.path if hasattr(request, "path") else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
lambda_function = self.lambda_backend.get_function(function_name) out = self.lambda_backend.get_policy_wire_format(function_name)
return ( return 200, {}, out
200, else:
{}, return 404, {}, "{}"
json.dumps(
dict(Policy='{"Statement":[' + lambda_function.policy + "]}") def _del_policy(self, request, full_url, headers, querystring):
), path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-3]
statement_id = path.split("/")[-1].split("?")[0]
revision = querystring.get("RevisionId", "")
if self.lambda_backend.get_function(function_name):
self.lambda_backend.del_policy_statement(
function_name, statement_id, revision
) )
return 204, {}, "{}"
else: else:
return 404, {}, "{}" return 404, {}, "{}"

View File

@ -6,7 +6,7 @@ url_bases = ["https?://lambda.(.+).amazonaws.com"]
response = LambdaResponse() response = LambdaResponse()
url_paths = { url_paths = {
"{0}/(?P<api_version>[^/]+)/functions/?$": response.root, r"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings, r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings,
@ -14,6 +14,7 @@ url_paths = {
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async,
r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag, r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/(?P<statement_id>[\w_-]+)$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,

View File

@ -563,6 +563,10 @@ class IamResponse(BaseResponse):
def create_access_key(self): def create_access_key(self):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
if not user_name:
access_key_id = self.get_current_user()
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
key = iam_backend.create_access_key(user_name) key = iam_backend.create_access_key(user_name)
template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE) template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE)
@ -572,6 +576,10 @@ class IamResponse(BaseResponse):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
access_key_id = self._get_param("AccessKeyId") access_key_id = self._get_param("AccessKeyId")
status = self._get_param("Status") status = self._get_param("Status")
if not user_name:
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
iam_backend.update_access_key(user_name, access_key_id, status) iam_backend.update_access_key(user_name, access_key_id, status)
template = self.response_template(GENERIC_EMPTY_TEMPLATE) template = self.response_template(GENERIC_EMPTY_TEMPLATE)
return template.render(name="UpdateAccessKey") return template.render(name="UpdateAccessKey")
@ -587,6 +595,11 @@ class IamResponse(BaseResponse):
def list_access_keys(self): def list_access_keys(self):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
if not user_name:
access_key_id = self.get_current_user()
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
keys = iam_backend.get_all_access_keys(user_name) keys = iam_backend.get_all_access_keys(user_name)
template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE) template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE)
return template.render(user_name=user_name, keys=keys) return template.render(user_name=user_name, keys=keys)
@ -594,6 +607,9 @@ class IamResponse(BaseResponse):
def delete_access_key(self): def delete_access_key(self):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
access_key_id = self._get_param("AccessKeyId") access_key_id = self._get_param("AccessKeyId")
if not user_name:
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
iam_backend.delete_access_key(access_key_id, user_name) iam_backend.delete_access_key(access_key_id, user_name)
template = self.response_template(GENERIC_EMPTY_TEMPLATE) template = self.response_template(GENERIC_EMPTY_TEMPLATE)

View File

@ -99,3 +99,28 @@ class InvalidAttributeName(RESTError):
super(InvalidAttributeName, self).__init__( super(InvalidAttributeName, self).__init__(
"InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name) "InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name)
) )
class InvalidParameterValue(RESTError):
code = 400
def __init__(self, message):
super(InvalidParameterValue, self).__init__("InvalidParameterValue", message)
class MissingParameter(RESTError):
code = 400
def __init__(self):
super(MissingParameter, self).__init__(
"MissingParameter", "The request must contain the parameter Actions."
)
class OverLimit(RESTError):
code = 403
def __init__(self, count):
super(OverLimit, self).__init__(
"OverLimit", "{} Actions were found, maximum allowed is 7.".format(count)
)

View File

@ -30,6 +30,9 @@ from .exceptions import (
BatchEntryIdsNotDistinct, BatchEntryIdsNotDistinct,
TooManyEntriesInBatchRequest, TooManyEntriesInBatchRequest,
InvalidAttributeName, InvalidAttributeName,
InvalidParameterValue,
MissingParameter,
OverLimit,
) )
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
@ -183,6 +186,7 @@ class Queue(BaseModel):
"MaximumMessageSize", "MaximumMessageSize",
"MessageRetentionPeriod", "MessageRetentionPeriod",
"QueueArn", "QueueArn",
"Policy",
"RedrivePolicy", "RedrivePolicy",
"ReceiveMessageWaitTimeSeconds", "ReceiveMessageWaitTimeSeconds",
"VisibilityTimeout", "VisibilityTimeout",
@ -195,6 +199,8 @@ class Queue(BaseModel):
"DeleteMessage", "DeleteMessage",
"GetQueueAttributes", "GetQueueAttributes",
"GetQueueUrl", "GetQueueUrl",
"ListDeadLetterSourceQueues",
"PurgeQueue",
"ReceiveMessage", "ReceiveMessage",
"SendMessage", "SendMessage",
) )
@ -273,7 +279,7 @@ class Queue(BaseModel):
if key in bool_fields: if key in bool_fields:
value = value == "true" value = value == "true"
if key == "RedrivePolicy" and value is not None: if key in ["Policy", "RedrivePolicy"] and value is not None:
continue continue
setattr(self, camelcase_to_underscores(key), value) setattr(self, camelcase_to_underscores(key), value)
@ -281,6 +287,9 @@ class Queue(BaseModel):
if attributes.get("RedrivePolicy", None): if attributes.get("RedrivePolicy", None):
self._setup_dlq(attributes["RedrivePolicy"]) self._setup_dlq(attributes["RedrivePolicy"])
if attributes.get("Policy"):
self.policy = attributes["Policy"]
self.last_modified_timestamp = now self.last_modified_timestamp = now
def _setup_dlq(self, policy): def _setup_dlq(self, policy):
@ -472,6 +481,24 @@ class Queue(BaseModel):
return self.name return self.name
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@property
def policy(self):
if self._policy_json.get("Statement"):
return json.dumps(self._policy_json)
else:
return None
@policy.setter
def policy(self, policy):
if policy:
self._policy_json = json.loads(policy)
else:
self._policy_json = {
"Version": "2012-10-17",
"Id": "{}/SQSDefaultPolicy".format(self.queue_arn),
"Statement": [],
}
class SQSBackend(BaseBackend): class SQSBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name):
@ -802,25 +829,75 @@ class SQSBackend(BaseBackend):
def add_permission(self, queue_name, actions, account_ids, label): def add_permission(self, queue_name, actions, account_ids, label):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if actions is None or len(actions) == 0: if not actions:
raise RESTError("InvalidParameterValue", "Need at least one Action") raise MissingParameter()
if account_ids is None or len(account_ids) == 0:
raise RESTError("InvalidParameterValue", "Need at least one Account ID")
if not all([item in Queue.ALLOWED_PERMISSIONS for item in actions]): if not account_ids:
raise RESTError("InvalidParameterValue", "Invalid permissions") raise InvalidParameterValue(
"Value [] for parameter PrincipalId is invalid. Reason: Unable to verify."
)
queue.permissions[label] = (account_ids, actions) count = len(actions)
if count > 7:
raise OverLimit(count)
invalid_action = next(
(action for action in actions if action not in Queue.ALLOWED_PERMISSIONS),
None,
)
if invalid_action:
raise InvalidParameterValue(
"Value SQS:{} for parameter ActionName is invalid. "
"Reason: Only the queue owner is allowed to invoke this action.".format(
invalid_action
)
)
policy = queue._policy_json
statement = next(
(
statement
for statement in policy["Statement"]
if statement["Sid"] == label
),
None,
)
if statement:
raise InvalidParameterValue(
"Value {} for parameter Label is invalid. "
"Reason: Already exists.".format(label)
)
principals = [
"arn:aws:iam::{}:root".format(account_id) for account_id in account_ids
]
actions = ["SQS:{}".format(action) for action in actions]
statement = {
"Sid": label,
"Effect": "Allow",
"Principal": {"AWS": principals[0] if len(principals) == 1 else principals},
"Action": actions[0] if len(actions) == 1 else actions,
"Resource": queue.queue_arn,
}
queue._policy_json["Statement"].append(statement)
def remove_permission(self, queue_name, label): def remove_permission(self, queue_name, label):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if label not in queue.permissions: statements = queue._policy_json["Statement"]
raise RESTError( statements_new = [
"InvalidParameterValue", "Permission doesnt exist for the given label" statement for statement in statements if statement["Sid"] != label
]
if len(statements) == len(statements_new):
raise InvalidParameterValue(
"Value {} for parameter Label is invalid. "
"Reason: can't find label on existing policy.".format(label)
) )
del queue.permissions[label] queue._policy_json["Statement"] = statements_new
def tag_queue(self, queue_name, tags): def tag_queue(self, queue_name, tags):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)

View File

@ -58,15 +58,15 @@ def test_create_resource__validate_name():
0 0
]["id"] ]["id"]
invalid_names = ["/users", "users/", "users/{user_id}", "us{er"] invalid_names = ["/users", "users/", "users/{user_id}", "us{er", "us+er"]
valid_names = ["users", "{user_id}", "user_09", "good-dog"] valid_names = ["users", "{user_id}", "{proxy+}", "user_09", "good-dog"]
# All invalid names should throw an exception # All invalid names should throw an exception
for name in invalid_names: for name in invalid_names:
with assert_raises(ClientError) as ex: with assert_raises(ClientError) as ex:
client.create_resource(restApiId=api_id, parentId=root_id, pathPart=name) client.create_resource(restApiId=api_id, parentId=root_id, pathPart=name)
ex.exception.response["Error"]["Code"].should.equal("BadRequestException") ex.exception.response["Error"]["Code"].should.equal("BadRequestException")
ex.exception.response["Error"]["Message"].should.equal( ex.exception.response["Error"]["Message"].should.equal(
"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end." "Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace."
) )
# All valid names should go through # All valid names should go through
for name in valid_names: for name in valid_names:

View File

@ -148,7 +148,7 @@ def test_invoke_event_function():
FunctionName="testFunction", InvocationType="Event", Payload=json.dumps(in_data) FunctionName="testFunction", InvocationType="Event", Payload=json.dumps(in_data)
) )
success_result["StatusCode"].should.equal(202) success_result["StatusCode"].should.equal(202)
json.loads(success_result["Payload"].read().decode("utf-8")).should.equal({}) json.loads(success_result["Payload"].read().decode("utf-8")).should.equal(in_data)
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
@ -305,6 +305,7 @@ def test_create_function_from_aws_bucket():
"VpcId": "vpc-123abc", "VpcId": "vpc-123abc",
}, },
"ResponseMetadata": {"HTTPStatusCode": 201}, "ResponseMetadata": {"HTTPStatusCode": 201},
"State": "Active",
} }
) )
@ -348,6 +349,7 @@ def test_create_function_from_zipfile():
"Version": "1", "Version": "1",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"ResponseMetadata": {"HTTPStatusCode": 201}, "ResponseMetadata": {"HTTPStatusCode": 201},
"State": "Active",
} }
) )
@ -612,6 +614,7 @@ def test_list_create_list_get_delete_list():
"Timeout": 3, "Timeout": 3,
"Version": "$LATEST", "Version": "$LATEST",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
}, },
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
} }
@ -808,6 +811,7 @@ def test_get_function_created_with_zipfile():
"Timeout": 3, "Timeout": 3,
"Version": "$LATEST", "Version": "$LATEST",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@ -1417,6 +1421,7 @@ def test_update_function_zip():
"Timeout": 3, "Timeout": 3,
"Version": "2", "Version": "2",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@ -1479,6 +1484,7 @@ def test_update_function_s3():
"Timeout": 3, "Timeout": 3,
"Version": "2", "Version": "2",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )

View File

@ -0,0 +1,49 @@
from __future__ import unicode_literals
import json
import sure
from moto.awslambda.policy import Policy
class MockLambdaFunction:
def __init__(self, arn):
self.function_arn = arn
self.policy = None
def test_policy():
policy = Policy(MockLambdaFunction("arn"))
statement = {
"StatementId": "statement0",
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": "events.amazonaws.com",
"SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name",
"SourceAccount": "111111111111",
}
expected = {
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": {"Service": "events.amazonaws.com"},
"Effect": "Allow",
"Resource": "arn:$LATEST",
"Sid": "statement0",
"Condition": {
"ArnLike": {
"AWS:SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name",
},
"StringEquals": {"AWS:SourceAccount": "111111111111"},
},
}
policy.add_statement(json.dumps(statement))
expected.should.be.equal(policy.statements[0])
sid = statement.get("StatementId", None)
if sid == None:
raise "TestCase.statement does not contain StatementId"
policy.del_statement(sid)
[].should.be.equal(policy.statements)

View File

@ -785,7 +785,7 @@ def test_delete_login_profile():
conn.delete_login_profile("my-user") conn.delete_login_profile("my-user")
@mock_iam() @mock_iam
def test_create_access_key(): def test_create_access_key():
conn = boto3.client("iam", region_name="us-east-1") conn = boto3.client("iam", region_name="us-east-1")
with assert_raises(ClientError): with assert_raises(ClientError):
@ -798,6 +798,19 @@ def test_create_access_key():
access_key["AccessKeyId"].should.have.length_of(20) access_key["AccessKeyId"].should.have.length_of(20)
access_key["SecretAccessKey"].should.have.length_of(40) access_key["SecretAccessKey"].should.have.length_of(40)
assert access_key["AccessKeyId"].startswith("AKIA") assert access_key["AccessKeyId"].startswith("AKIA")
conn = boto3.client(
"iam",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
)
access_key = conn.create_access_key()["AccessKey"]
(
datetime.utcnow() - access_key["CreateDate"].replace(tzinfo=None)
).seconds.should.be.within(0, 10)
access_key["AccessKeyId"].should.have.length_of(20)
access_key["SecretAccessKey"].should.have.length_of(40)
assert access_key["AccessKeyId"].startswith("AKIA")
@mock_iam_deprecated() @mock_iam_deprecated()
@ -825,8 +838,35 @@ def test_get_all_access_keys():
) )
@mock_iam
def test_list_access_keys():
conn = boto3.client("iam", region_name="us-east-1")
conn.create_user(UserName="my-user")
response = conn.list_access_keys(UserName="my-user")
assert_equals(
response["AccessKeyMetadata"], [],
)
access_key = conn.create_access_key(UserName="my-user")["AccessKey"]
response = conn.list_access_keys(UserName="my-user")
assert_equals(
sorted(response["AccessKeyMetadata"][0].keys()),
sorted(["Status", "CreateDate", "UserName", "AccessKeyId"]),
)
conn = boto3.client(
"iam",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
)
response = conn.list_access_keys()
assert_equals(
sorted(response["AccessKeyMetadata"][0].keys()),
sorted(["Status", "CreateDate", "UserName", "AccessKeyId"]),
)
@mock_iam_deprecated() @mock_iam_deprecated()
def test_delete_access_key(): def test_delete_access_key_deprecated():
conn = boto.connect_iam() conn = boto.connect_iam()
conn.create_user("my-user") conn.create_user("my-user")
access_key_id = conn.create_access_key("my-user")["create_access_key_response"][ access_key_id = conn.create_access_key("my-user")["create_access_key_response"][
@ -835,6 +875,16 @@ def test_delete_access_key():
conn.delete_access_key(access_key_id, "my-user") conn.delete_access_key(access_key_id, "my-user")
@mock_iam
def test_delete_access_key():
conn = boto3.client("iam", region_name="us-east-1")
conn.create_user(UserName="my-user")
key = conn.create_access_key(UserName="my-user")["AccessKey"]
conn.delete_access_key(AccessKeyId=key["AccessKeyId"], UserName="my-user")
key = conn.create_access_key(UserName="my-user")["AccessKey"]
conn.delete_access_key(AccessKeyId=key["AccessKeyId"])
@mock_iam() @mock_iam()
def test_mfa_devices(): def test_mfa_devices():
# Test enable device # Test enable device
@ -1326,6 +1376,9 @@ def test_update_access_key():
) )
resp = client.list_access_keys(UserName=username) resp = client.list_access_keys(UserName=username)
resp["AccessKeyMetadata"][0]["Status"].should.equal("Inactive") resp["AccessKeyMetadata"][0]["Status"].should.equal("Inactive")
client.update_access_key(AccessKeyId=key["AccessKeyId"], Status="Active")
resp = client.list_access_keys(UserName=username)
resp["AccessKeyMetadata"][0]["Status"].should.equal("Active")
@mock_iam @mock_iam

View File

@ -132,6 +132,35 @@ def test_create_queue_with_tags():
) )
@mock_sqs
def test_create_queue_with_policy():
client = boto3.client("sqs", region_name="us-east-1")
response = client.create_queue(
QueueName="test-queue",
Attributes={
"Policy": json.dumps(
{
"Version": "2012-10-17",
"Id": "test",
"Statement": [{"Effect": "Allow", "Principal": "*", "Action": "*"}],
}
)
},
)
queue_url = response["QueueUrl"]
response = client.get_queue_attributes(
QueueUrl=queue_url, AttributeNames=["Policy"]
)
json.loads(response["Attributes"]["Policy"]).should.equal(
{
"Version": "2012-10-17",
"Id": "test",
"Statement": [{"Effect": "Allow", "Principal": "*", "Action": "*"}],
}
)
@mock_sqs @mock_sqs
def test_get_queue_url(): def test_get_queue_url():
client = boto3.client("sqs", region_name="us-east-1") client = boto3.client("sqs", region_name="us-east-1")
@ -1186,18 +1215,169 @@ def test_permissions():
Actions=["SendMessage"], Actions=["SendMessage"],
) )
with assert_raises(ClientError): response = client.get_queue_attributes(
client.add_permission( QueueUrl=queue_url, AttributeNames=["Policy"]
QueueUrl=queue_url, )
Label="account2", policy = json.loads(response["Attributes"]["Policy"])
AWSAccountIds=["222211111111"], policy["Version"].should.equal("2012-10-17")
Actions=["SomeRubbish"], policy["Id"].should.equal(
"arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo/SQSDefaultPolicy"
)
sorted(policy["Statement"], key=lambda x: x["Sid"]).should.equal(
[
{
"Sid": "account1",
"Effect": "Allow",
"Principal": {"AWS": "arn:aws:iam::111111111111:root"},
"Action": "SQS:*",
"Resource": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo",
},
{
"Sid": "account2",
"Effect": "Allow",
"Principal": {"AWS": "arn:aws:iam::222211111111:root"},
"Action": "SQS:SendMessage",
"Resource": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo",
},
]
) )
client.remove_permission(QueueUrl=queue_url, Label="account2") client.remove_permission(QueueUrl=queue_url, Label="account2")
with assert_raises(ClientError): response = client.get_queue_attributes(
client.remove_permission(QueueUrl=queue_url, Label="non_existent") QueueUrl=queue_url, AttributeNames=["Policy"]
)
json.loads(response["Attributes"]["Policy"]).should.equal(
{
"Version": "2012-10-17",
"Id": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo/SQSDefaultPolicy",
"Statement": [
{
"Sid": "account1",
"Effect": "Allow",
"Principal": {"AWS": "arn:aws:iam::111111111111:root"},
"Action": "SQS:*",
"Resource": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo",
},
],
}
)
@mock_sqs
def test_add_permission_errors():
client = boto3.client("sqs", region_name="us-east-1")
response = client.create_queue(QueueName="test-queue")
queue_url = response["QueueUrl"]
client.add_permission(
QueueUrl=queue_url,
Label="test",
AWSAccountIds=["111111111111"],
Actions=["ReceiveMessage"],
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test",
AWSAccountIds=["111111111111"],
Actions=["ReceiveMessage", "SendMessage"],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidParameterValue")
ex.response["Error"]["Message"].should.equal(
"Value test for parameter Label is invalid. " "Reason: Already exists."
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test-2",
AWSAccountIds=["111111111111"],
Actions=["RemovePermission"],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidParameterValue")
ex.response["Error"]["Message"].should.equal(
"Value SQS:RemovePermission for parameter ActionName is invalid. "
"Reason: Only the queue owner is allowed to invoke this action."
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test-2",
AWSAccountIds=["111111111111"],
Actions=[],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("MissingParameter")
ex.response["Error"]["Message"].should.equal(
"The request must contain the parameter Actions."
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test-2",
AWSAccountIds=[],
Actions=["ReceiveMessage"],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidParameterValue")
ex.response["Error"]["Message"].should.equal(
"Value [] for parameter PrincipalId is invalid. Reason: Unable to verify."
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test-2",
AWSAccountIds=["111111111111"],
Actions=[
"ChangeMessageVisibility",
"DeleteMessage",
"GetQueueAttributes",
"GetQueueUrl",
"ListDeadLetterSourceQueues",
"PurgeQueue",
"ReceiveMessage",
"SendMessage",
],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403)
ex.response["Error"]["Code"].should.contain("OverLimit")
ex.response["Error"]["Message"].should.equal(
"8 Actions were found, maximum allowed is 7."
)
@mock_sqs
def test_remove_permission_errors():
client = boto3.client("sqs", region_name="us-east-1")
response = client.create_queue(QueueName="test-queue")
queue_url = response["QueueUrl"]
with assert_raises(ClientError) as e:
client.remove_permission(QueueUrl=queue_url, Label="test")
ex = e.exception
ex.operation_name.should.equal("RemovePermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidParameterValue")
ex.response["Error"]["Message"].should.equal(
"Value test for parameter Label is invalid. "
"Reason: can't find label on existing policy."
)
@mock_sqs @mock_sqs