merging from master

Bryan Alexander 2020-02-18 10:47:05 -06:00
commit 445f474534
125 changed files with 7407 additions and 3848 deletions

.gitignore

@@ -20,3 +20,5 @@ env/
.vscode/
tests/file.tmp
.eggs/
+.mypy_cache/
+*.tmp

@@ -26,11 +26,12 @@ install:
fi
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${PYTHON_DOCKER_TAG} /moto/travis_moto_server.sh &
fi
+travis_retry pip install -r requirements-dev.txt
travis_retry pip install boto==2.45.0
travis_retry pip install boto3
travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1
-travis_retry pip install -r requirements-dev.txt
+travis_retry pip install coverage==4.5.4
if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py

@@ -381,7 +381,7 @@ def test_something(s3):
from some.package.that.does.something.with.s3 import some_func # <-- Local import for unit test
# ^^ Importing here ensures that the mock has been established.
-sume_func() # The mock has been established from the "s3" pytest fixture, so this function that uses
+some_func() # The mock has been established from the "s3" pytest fixture, so this function that uses
# a package-level S3 client will properly use the mock and not reach out to AWS.
```
@@ -450,6 +450,16 @@ boto3.resource(
)
```
+### Caveats
+The standalone server has some caveats with some services. The following services
+require that you update your hosts file for your code to work properly:
+1. `s3-control`
+For the above services, this is required because the hostname is in the form of `AWS_ACCOUNT_ID.localhost`.
+As a result, you need to add that entry to your hosts file for your tests to function properly.
## Install
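To make the caveat above concrete, here is a rough, illustrative sketch (not taken from this commit) of the hosts entry and a client pointed at the standalone server; the account ID `123456789012` and port `5000` are assumptions.

```python
# Assumed /etc/hosts entry for the s3-control caveat described above
# (123456789012 stands in for whatever account ID your code uses):
#   127.0.0.1   123456789012.localhost

import boto3

# Point the client at a locally running `moto_server` (5000 is its default port).
s3control = boto3.client(
    "s3-control",
    region_name="us-east-1",
    endpoint_url="http://localhost:5000",
    aws_access_key_id="testing",
    aws_secret_access_key="testing",
)

s3control.put_public_access_block(
    AccountId="123456789012",
    PublicAccessBlockConfiguration={"BlockPublicAcls": True, "IgnorePublicAcls": True},
)
print(s3control.get_public_access_block(AccountId="123456789012"))
```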

@@ -56,9 +56,10 @@ author = 'Steve Pulec'
# built documents.
#
# The short X.Y version.
-version = '0.4.10'
+import moto
+version = moto.__version__
# The full version, including alpha/beta/rc tags.
-release = '0.4.10'
+release = moto.__version__
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.

@@ -24,8 +24,7 @@ For example, we have the following code we want to test:
.. sourcecode:: python
-import boto
-from boto.s3.key import Key
+import boto3
class MyModel(object):
def __init__(self, name, value):
@@ -33,11 +32,8 @@ For example, we have the following code we want to test:
self.value = value
def save(self):
-conn = boto.connect_s3()
-bucket = conn.get_bucket('mybucket')
-k = Key(bucket)
-k.key = self.name
-k.set_contents_from_string(self.value)
+s3 = boto3.client('s3', region_name='us-east-1')
+s3.put_object(Bucket='mybucket', Key=self.name, Body=self.value)
There are several ways to do this, but you should keep in mind that Moto creates a full, blank environment.
@@ -48,20 +44,23 @@ With a decorator wrapping, all the calls to S3 are automatically mocked out.
.. sourcecode:: python
-import boto
+import boto3
from moto import mock_s3
from mymodule import MyModel
@mock_s3
def test_my_model_save():
-conn = boto.connect_s3()
+conn = boto3.resource('s3', region_name='us-east-1')
# We need to create the bucket since this is all in Moto's 'virtual' AWS account
-conn.create_bucket('mybucket')
+conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome')
model_instance.save()
-assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome'
+body = conn.Object('mybucket', 'steve').get()[
+'Body'].read().decode("utf-8")
+assert body == 'is awesome'
Context manager
~~~~~~~~~~~~~~~
@@ -72,13 +71,16 @@ Same as the Decorator, every call inside the ``with`` statement is mocked out.
def test_my_model_save():
with mock_s3():
-conn = boto.connect_s3()
+conn = boto3.resource('s3', region_name='us-east-1')
-conn.create_bucket('mybucket')
+conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome')
model_instance.save()
-assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome'
+body = conn.Object('mybucket', 'steve').get()[
+'Body'].read().decode("utf-8")
+assert body == 'is awesome'
Raw
~~~
@@ -91,13 +93,16 @@ You can also start and stop the mocking manually.
mock = mock_s3()
mock.start()
-conn = boto.connect_s3()
+conn = boto3.resource('s3', region_name='us-east-1')
-conn.create_bucket('mybucket')
+conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome')
model_instance.save()
-assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome'
+body = conn.Object('mybucket', 'steve').get()[
+'Body'].read().decode("utf-8")
+assert body == 'is awesome'
mock.stop()

@@ -76,7 +76,7 @@ Currently implemented Services:
+---------------------------+-----------------------+------------------------------------+
| Logs | @mock_logs | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
-| Organizations | @mock_organizations | some core edpoints done |
+| Organizations | @mock_organizations | some core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Polly | @mock_polly | all endpoints done |
+---------------------------+-----------------------+------------------------------------+

@@ -39,7 +39,7 @@ class InvalidResourcePathException(BadRequestException):
def __init__(self):
super(InvalidResourcePathException, self).__init__(
"BadRequestException",
-"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end.",
+"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace.",
)

@@ -83,14 +83,14 @@ class MethodResponse(BaseModel, dict):
class Method(BaseModel, dict):
-def __init__(self, method_type, authorization_type):
+def __init__(self, method_type, authorization_type, **kwargs):
super(Method, self).__init__()
self.update(
dict(
httpMethod=method_type,
authorizationType=authorization_type,
authorizerId=None,
-apiKeyRequired=None,
+apiKeyRequired=kwargs.get("api_key_required") or False,
requestParameters=None,
requestModels=None,
methodIntegration=None,
@@ -117,14 +117,15 @@ class Resource(BaseModel):
self.api_id = api_id
self.path_part = path_part
self.parent_id = parent_id
-self.resource_methods = {"GET": {}}
+self.resource_methods = {}
def to_dict(self):
response = {
"path": self.get_path(),
"id": self.id,
-"resourceMethods": self.resource_methods,
}
+if self.resource_methods:
+response["resourceMethods"] = self.resource_methods
if self.parent_id:
response["parentId"] = self.parent_id
response["pathPart"] = self.path_part
@@ -158,8 +159,12 @@ class Resource(BaseModel):
)
return response.status_code, response.text
-def add_method(self, method_type, authorization_type):
-method = Method(method_type=method_type, authorization_type=authorization_type)
+def add_method(self, method_type, authorization_type, api_key_required):
+method = Method(
+method_type=method_type,
+authorization_type=authorization_type,
+api_key_required=api_key_required,
+)
self.resource_methods[method_type] = method
return method
@@ -394,12 +399,17 @@ class UsagePlanKey(BaseModel, dict):
class RestAPI(BaseModel):
-def __init__(self, id, region_name, name, description):
+def __init__(self, id, region_name, name, description, **kwargs):
self.id = id
self.region_name = region_name
self.name = name
self.description = description
self.create_date = int(time.time())
+self.api_key_source = kwargs.get("api_key_source") or "HEADER"
+self.endpoint_configuration = kwargs.get("endpoint_configuration") or {
+"types": ["EDGE"]
+}
+self.tags = kwargs.get("tags") or {}
self.deployments = {}
self.stages = {}
@@ -416,6 +426,9 @@ class RestAPI(BaseModel):
"name": self.name,
"description": self.description,
"createdDate": int(time.time()),
+"apiKeySource": self.api_key_source,
+"endpointConfiguration": self.endpoint_configuration,
+"tags": self.tags,
}
def add_child(self, path, parent_id=None):
@@ -529,9 +542,24 @@ class APIGatewayBackend(BaseBackend):
self.__dict__ = {}
self.__init__(region_name)
-def create_rest_api(self, name, description):
+def create_rest_api(
+self,
+name,
+description,
+api_key_source=None,
+endpoint_configuration=None,
+tags=None,
+):
api_id = create_id()
-rest_api = RestAPI(api_id, self.region_name, name, description)
+rest_api = RestAPI(
+api_id,
+self.region_name,
+name,
+description,
+api_key_source=api_key_source,
+endpoint_configuration=endpoint_configuration,
+tags=tags,
+)
self.apis[api_id] = rest_api
return rest_api
@@ -556,7 +584,7 @@ class APIGatewayBackend(BaseBackend):
return resource
def create_resource(self, function_id, parent_resource_id, path_part):
-if not re.match("^\\{?[a-zA-Z0-9._-]+\\}?$", path_part):
+if not re.match("^\\{?[a-zA-Z0-9._-]+\\+?\\}?$", path_part):
raise InvalidResourcePathException()
api = self.get_rest_api(function_id)
child = api.add_child(path=path_part, parent_id=parent_resource_id)
@@ -571,9 +599,18 @@ class APIGatewayBackend(BaseBackend):
resource = self.get_resource(function_id, resource_id)
return resource.get_method(method_type)
-def create_method(self, function_id, resource_id, method_type, authorization_type):
+def create_method(
+self,
+function_id,
+resource_id,
+method_type,
+authorization_type,
+api_key_required=None,
+):
resource = self.get_resource(function_id, resource_id)
-method = resource.add_method(method_type, authorization_type)
+method = resource.add_method(
+method_type, authorization_type, api_key_required=api_key_required
+)
return method
def get_stage(self, function_id, stage_name):

@@ -12,6 +12,9 @@ from .exceptions import (
ApiKeyAlreadyExists,
)
+API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
+ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"]
class APIGatewayResponse(BaseResponse):
def error(self, type_, message, status=400):
@@ -45,7 +48,45 @@ class APIGatewayResponse(BaseResponse):
elif self.method == "POST":
name = self._get_param("name")
description = self._get_param("description")
-rest_api = self.backend.create_rest_api(name, description)
+api_key_source = self._get_param("apiKeySource")
+endpoint_configuration = self._get_param("endpointConfiguration")
+tags = self._get_param("tags")
+# Param validation
+if api_key_source and api_key_source not in API_KEY_SOURCES:
+return self.error(
+"ValidationException",
+(
+"1 validation error detected: "
+"Value '{api_key_source}' at 'createRestApiInput.apiKeySource' failed "
+"to satisfy constraint: Member must satisfy enum value set: "
+"[AUTHORIZER, HEADER]"
+).format(api_key_source=api_key_source),
+)
+if endpoint_configuration and "types" in endpoint_configuration:
+invalid_types = list(
+set(endpoint_configuration["types"])
+- set(ENDPOINT_CONFIGURATION_TYPES)
+)
+if invalid_types:
+return self.error(
+"ValidationException",
+(
+"1 validation error detected: Value '{endpoint_type}' "
+"at 'createRestApiInput.endpointConfiguration.types' failed "
+"to satisfy constraint: Member must satisfy enum value set: "
+"[PRIVATE, EDGE, REGIONAL]"
+).format(endpoint_type=invalid_types[0]),
+)
+rest_api = self.backend.create_rest_api(
+name,
+description,
+api_key_source=api_key_source,
+endpoint_configuration=endpoint_configuration,
+tags=tags,
+)
return 200, {}, json.dumps(rest_api.to_dict())
def restapis_individual(self, request, full_url, headers):
@@ -104,8 +145,13 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(method)
elif self.method == "PUT":
authorization_type = self._get_param("authorizationType")
+api_key_required = self._get_param("apiKeyRequired")
method = self.backend.create_method(
-function_id, resource_id, method_type, authorization_type
+function_id,
+resource_id,
+method_type,
+authorization_type,
+api_key_required,
)
return 200, {}, json.dumps(method)

@@ -1,4 +1,5 @@
from botocore.client import ClientError
+from moto.core.exceptions import JsonRESTError
class LambdaClientError(ClientError):
@@ -29,3 +30,12 @@ class InvalidRoleFormat(LambdaClientError):
role, InvalidRoleFormat.pattern
)
super(InvalidRoleFormat, self).__init__("ValidationException", message)
+class PreconditionFailedException(JsonRESTError):
+code = 412
+def __init__(self, message):
+super(PreconditionFailedException, self).__init__(
+"PreconditionFailedException", message
+)

@@ -25,6 +25,7 @@ import requests.adapters
from boto3 import Session
+from moto.awslambda.policy import Policy
from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend
@@ -47,15 +48,11 @@ from moto.core import ACCOUNT_ID
logger = logging.getLogger(__name__)
try:
from tempfile import TemporaryDirectory
except ImportError:
from backports.tempfile import TemporaryDirectory
-# The lambci container is returning a special escape character for the "RequestID" fields. Unicode 033:
-# _stderr_regex = re.compile(r"START|END|REPORT RequestId: .*")
-_stderr_regex = re.compile(r"\033\[\d+.*")
_orig_adapter_send = requests.adapters.HTTPAdapter.send
docker_3 = docker.__version__[0] >= "3"
@@ -164,7 +161,8 @@ class LambdaFunction(BaseModel):
self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get("Environment", {}).get("Variables", {})
self.docker_client = docker.from_env()
-self.policy = ""
+self.policy = None
+self.state = "Active"
# Unfortunately mocking replaces this method w/o fallback enabled, so we
# need to replace it if we detect it's been mocked
@@ -274,11 +272,11 @@
"MemorySize": self.memory_size,
"Role": self.role,
"Runtime": self.run_time,
+"State": self.state,
"Timeout": self.timeout,
"Version": str(self.version),
"VpcConfig": self.vpc_config,
}
if self.environment_vars:
config["Environment"] = {"Variables": self.environment_vars}
@@ -385,7 +383,7 @@
try:
# TODO: I believe we can keep the container running and feed events as needed
# also need to hook it up to the other services so it can make kws/s3 etc calls
-# Should get invoke_id /RequestId from invovation
+# Should get invoke_id /RequestId from invocation
env_vars = {
"AWS_LAMBDA_FUNCTION_TIMEOUT": self.timeout,
"AWS_LAMBDA_FUNCTION_NAME": self.function_name,
@@ -397,6 +395,7 @@
env_vars.update(self.environment_vars)
container = output = exit_code = None
+log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
with _DockerDataVolumeContext(self) as data_vol:
try:
run_kwargs = (
@@ -412,6 +411,7 @@
volumes=["{}:/var/task".format(data_vol.name)],
environment=env_vars,
detach=True,
+log_config=log_config,
**run_kwargs
)
finally:
@@ -453,14 +453,9 @@
if exit_code != 0:
raise Exception("lambda invoke failed output: {}".format(output))
-# strip out RequestId lines (TODO: This will return an additional '\n' in the response)
-output = os.linesep.join(
-[
-line
-for line in self.convert(output).splitlines()
-if not _stderr_regex.match(line)
-]
-)
+# We only care about the response from the lambda
+# Which is the last line of the output, according to https://github.com/lambci/docker-lambda/issues/25
+output = output.splitlines()[-1]
return output, False
except BaseException as e:
traceback.print_exc()
@@ -480,7 +475,7 @@
payload["result"] = response_headers["x-amz-log-result"]
result = res.encode("utf-8")
else:
-result = json.dumps(payload)
+result = res
if errored:
response_headers["x-amz-function-error"] = "Handled"
@@ -709,7 +704,8 @@ class LambdaStorage(object):
"versions": [],
"alias": weakref.WeakValueDictionary(),
}
+# instantiate a new policy for this version of the lambda
+fn.policy = Policy(fn)
self._arns[fn.function_arn] = fn
def publish_function(self, name):
@@ -1010,8 +1006,21 @@ class LambdaBackend(BaseBackend):
return True
return False
-def add_policy(self, function_name, policy):
-self.get_function(function_name).policy = policy
+def add_policy_statement(self, function_name, raw):
+fn = self.get_function(function_name)
+fn.policy.add_statement(raw)
+def del_policy_statement(self, function_name, sid, revision=""):
+fn = self.get_function(function_name)
+fn.policy.del_statement(sid, revision)
+def get_policy(self, function_name):
+fn = self.get_function(function_name)
+return fn.policy.get_policy()
+def get_policy_wire_format(self, function_name):
+fn = self.get_function(function_name)
+return fn.policy.wire_format()
def update_function_code(self, function_name, qualifier, body):
fn = self.get_function(function_name, qualifier)

moto/awslambda/policy.py (new file)

@@ -0,0 +1,134 @@
from __future__ import unicode_literals
import json
import uuid
from six import string_types
from moto.awslambda.exceptions import PreconditionFailedException
class Policy:
def __init__(self, parent):
self.revision = str(uuid.uuid4())
self.statements = []
self.parent = parent
def wire_format(self):
p = self.get_policy()
p["Policy"] = json.dumps(p["Policy"])
return json.dumps(p)
def get_policy(self):
return {
"Policy": {
"Version": "2012-10-17",
"Id": "default",
"Statement": self.statements,
},
"RevisionId": self.revision,
}
# adds the raw JSON statement to the policy
def add_statement(self, raw):
policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
self.statements.append(policy.statements[0])
self.revision = str(uuid.uuid4())
# removes the statement that matches 'sid' from the policy
def del_statement(self, sid, revision=""):
if len(revision) > 0 and self.revision != revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
for statement in self.statements:
if "Sid" in statement and statement["Sid"] == sid:
self.statements.remove(statement)
# converts AddPermission request to PolicyStatement
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
def decode_policy(self, obj):
# import pydevd
# pydevd.settrace("localhost", port=5678)
policy = Policy(self.parent)
policy.revision = obj.get("RevisionId", "")
# set some default values if these keys are not set
self.ensure_set(obj, "Effect", "Allow")
self.ensure_set(obj, "Resource", self.parent.function_arn + ":$LATEST")
self.ensure_set(obj, "StatementId", str(uuid.uuid4()))
# transform field names and values
self.transform_property(obj, "StatementId", "Sid", self.nop_formatter)
self.transform_property(obj, "Principal", "Principal", self.principal_formatter)
self.transform_property(
obj, "SourceArn", "SourceArn", self.source_arn_formatter
)
self.transform_property(
obj, "SourceAccount", "SourceAccount", self.source_account_formatter
)
# remove RevisionId and EventSourceToken if they are set
self.remove_if_set(obj, ["RevisionId", "EventSourceToken"])
# merge conditional statements into a single map under the Condition key
self.condition_merge(obj)
# append resulting statement to policy.statements
policy.statements.append(obj)
return policy
def nop_formatter(self, obj):
return obj
def ensure_set(self, obj, key, value):
if key not in obj:
obj[key] = value
def principal_formatter(self, obj):
if isinstance(obj, string_types):
if obj.endswith(".amazonaws.com"):
return {"Service": obj}
if obj.endswith(":root"):
return {"AWS": obj}
return obj
def source_account_formatter(self, obj):
return {"StringEquals": {"AWS:SourceAccount": obj}}
def source_arn_formatter(self, obj):
return {"ArnLike": {"AWS:SourceArn": obj}}
def transform_property(self, obj, old_name, new_name, formatter):
if old_name in obj:
obj[new_name] = formatter(obj[old_name])
if new_name != old_name:
del obj[old_name]
def remove_if_set(self, obj, keys):
for key in keys:
if key in obj:
del obj[key]
def condition_merge(self, obj):
if "SourceArn" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceArn"])
del obj["SourceArn"]
if "SourceAccount" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceAccount"])
del obj["SourceAccount"]
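For orientation only (this block is not part of the diff): the new `Policy` plumbing above is what sits behind Lambda's `add_permission` / `get_policy` / `remove_permission` calls. A hedged sketch of exercising it through moto's decorator might look like the following; the function name, role ARN, and handler are invented for illustration.

```python
import io
import json
import zipfile

import boto3
from moto import mock_lambda


def zipped_stub():
    # Minimal deployment package: a handler that simply echoes its event.
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as zf:
        zf.writestr("handler.py", "def handler(event, context):\n    return event\n")
    return buf.getvalue()


@mock_lambda
def exercise_policy():
    client = boto3.client("lambda", region_name="us-east-1")
    client.create_function(
        FunctionName="echo",
        Runtime="python3.7",
        Handler="handler.handler",
        # The role is never assumed under moto; it only has to look like a role ARN.
        Role="arn:aws:iam::123456789012:role/lambda-test-role",
        Code={"ZipFile": zipped_stub()},
    )

    # AddPermission appends a statement; GetPolicy returns the accumulated document.
    client.add_permission(
        FunctionName="echo",
        StatementId="sns-invoke",
        Action="lambda:InvokeFunction",
        Principal="sns.amazonaws.com",
    )
    policy = json.loads(client.get_policy(FunctionName="echo")["Policy"])
    print([s["Sid"] for s in policy["Statement"]])  # e.g. ['sns-invoke']

    # RemovePermission is what the new DELETE route and del_policy_statement serve.
    client.remove_permission(FunctionName="echo", StatementId="sns-invoke")


exercise_policy()
```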

@@ -120,8 +120,12 @@ class LambdaResponse(BaseResponse):
self.setup_class(request, full_url, headers)
if request.method == "GET":
return self._get_policy(request, full_url, headers)
-if request.method == "POST":
+elif request.method == "POST":
return self._add_policy(request, full_url, headers)
+elif request.method == "DELETE":
+return self._del_policy(request, full_url, headers, self.querystring)
+else:
+raise ValueError("Cannot handle {0} request".format(request.method))
def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@@ -141,9 +145,9 @@ class LambdaResponse(BaseResponse):
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name):
-policy = self.body
-self.lambda_backend.add_policy(function_name, policy)
-return 200, {}, json.dumps(dict(Statement=policy))
+statement = self.body
+self.lambda_backend.add_policy_statement(function_name, statement)
+return 200, {}, json.dumps({"Statement": statement})
else:
return 404, {}, "{}"
@@ -151,28 +155,42 @@ class LambdaResponse(BaseResponse):
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name):
-lambda_function = self.lambda_backend.get_function(function_name)
-return (
-200,
-{},
-json.dumps(
-dict(Policy='{"Statement":[' + lambda_function.policy + "]}")
-),
-)
+out = self.lambda_backend.get_policy_wire_format(function_name)
+return 200, {}, out
+else:
+return 404, {}, "{}"
+def _del_policy(self, request, full_url, headers, querystring):
+path = request.path if hasattr(request, "path") else path_url(request.url)
+function_name = path.split("/")[-3]
+statement_id = path.split("/")[-1].split("?")[0]
+revision = querystring.get("RevisionId", "")
+if self.lambda_backend.get_function(function_name):
+self.lambda_backend.del_policy_statement(
+function_name, statement_id, revision
+)
+return 204, {}, "{}"
else:
return 404, {}, "{}"
def _invoke(self, request, full_url):
response_headers = {}
-function_name = self.path.rsplit("/", 2)[-2]
+# URL Decode in case it's a ARN:
+function_name = unquote(self.path.rsplit("/", 2)[-2])
qualifier = self._get_param("qualifier")
response_header, payload = self.lambda_backend.invoke(
function_name, qualifier, self.body, self.headers, response_headers
)
if payload:
-return 202, response_headers, payload
+if request.headers["X-Amz-Invocation-Type"] == "Event":
+status_code = 202
+elif request.headers["X-Amz-Invocation-Type"] == "DryRun":
+status_code = 204
+else:
+status_code = 200
+return status_code, response_headers, payload
else:
return 404, response_headers, "{}"
@@ -283,7 +301,7 @@ class LambdaResponse(BaseResponse):
code["Configuration"]["FunctionArn"] += ":$LATEST"
return 200, {}, json.dumps(code)
else:
-return 404, {}, "{}"
+return 404, {"x-amzn-ErrorType": "ResourceNotFoundException"}, "{}"
def _get_aws_region(self, full_url):
region = self.region_regex.search(full_url)

@@ -6,14 +6,16 @@ url_bases = ["https?://lambda.(.+).amazonaws.com"]
response = LambdaResponse()
url_paths = {
-"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
+r"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$": response.event_source_mapping,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke,
+r"{0}/(?P<api_version>[^/]+)/functions/(?P<resource_arn>.+)/invocations/?$": response.invoke,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async,
r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag,
+r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/(?P<statement_id>[\w_-]+)$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,

@@ -677,6 +677,8 @@ class CloudFormationBackend(BaseBackend):
def list_stack_resources(self, stack_name_or_id):
stack = self.get_stack(stack_name_or_id)
+if stack is None:
+return None
return stack.stack_resources
def delete_stack(self, name_or_stack_id):

@@ -229,6 +229,9 @@ class CloudFormationResponse(BaseResponse):
stack_name_or_id = self._get_param("StackName")
resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id)
+if resources is None:
+raise ValidationError(stack_name_or_id)
template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE)
return template.render(resources=resources)

@@ -14,6 +14,7 @@ from jose import jws
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
+from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
from .exceptions import (
GroupExistsException,
NotAuthorizedError,
@@ -69,6 +70,9 @@ class CognitoIdpUserPool(BaseModel):
def __init__(self, region, name, extended_config):
self.region = region
self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
+self.arn = "arn:aws:cognito-idp:{}:{}:userpool/{}".format(
+self.region, DEFAULT_ACCOUNT_ID, self.id
+)
self.name = name
self.status = None
self.extended_config = extended_config or {}
@@ -91,6 +95,7 @@ class CognitoIdpUserPool(BaseModel):
def _base_json(self):
return {
"Id": self.id,
+"Arn": self.arn,
"Name": self.name,
"Status": self.status,
"CreationDate": time.mktime(self.creation_date.timetuple()),
@@ -108,7 +113,9 @@
return user_pool_json
-def create_jwt(self, client_id, username, expires_in=60 * 60, extra_data={}):
+def create_jwt(
+self, client_id, username, token_use, expires_in=60 * 60, extra_data={}
+):
now = int(time.time())
payload = {
"iss": "https://cognito-idp.{}.amazonaws.com/{}".format(
@@ -116,7 +123,7 @@
),
"sub": self.users[username].id,
"aud": client_id,
-"token_use": "id",
+"token_use": token_use,
"auth_time": now,
"exp": now + expires_in,
}
@@ -125,7 +132,10 @@
return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in
def create_id_token(self, client_id, username):
-id_token, expires_in = self.create_jwt(client_id, username)
+extra_data = self.get_user_extra_data_by_client_id(client_id, username)
+id_token, expires_in = self.create_jwt(
+client_id, username, "id", extra_data=extra_data
+)
self.id_tokens[id_token] = (client_id, username)
return id_token, expires_in
@@ -135,10 +145,7 @@
return refresh_token
def create_access_token(self, client_id, username):
-extra_data = self.get_user_extra_data_by_client_id(client_id, username)
-access_token, expires_in = self.create_jwt(
-client_id, username, extra_data=extra_data
-)
+access_token, expires_in = self.create_jwt(client_id, username, "access")
self.access_tokens[access_token] = (client_id, username)
return access_token, expires_in
@@ -562,12 +569,17 @@ class CognitoIdpBackend(BaseBackend):
user.groups.discard(group)
# User
-def admin_create_user(self, user_pool_id, username, temporary_password, attributes):
+def admin_create_user(
+self, user_pool_id, username, message_action, temporary_password, attributes
+):
user_pool = self.user_pools.get(user_pool_id)
if not user_pool:
raise ResourceNotFoundError(user_pool_id)
-if username in user_pool.users:
+if message_action and message_action == "RESEND":
+if username not in user_pool.users:
+raise UserNotFoundError(username)
+elif username in user_pool.users:
raise UsernameExistsException(username)
user = CognitoIdpUser(

@@ -259,10 +259,12 @@ class CognitoIdpResponse(BaseResponse):
def admin_create_user(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
+message_action = self._get_param("MessageAction")
temporary_password = self._get_param("TemporaryPassword")
user = cognitoidp_backends[self.region].admin_create_user(
user_pool_id,
username,
+message_action,
temporary_password,
self._get_param("UserAttributes", []),
)
@@ -279,9 +281,18 @@ class CognitoIdpResponse(BaseResponse):
user_pool_id = self._get_param("UserPoolId")
limit = self._get_param("Limit")
token = self._get_param("PaginationToken")
+filt = self._get_param("Filter")
users, token = cognitoidp_backends[self.region].list_users(
user_pool_id, limit=limit, pagination_token=token
)
+if filt:
+name, value = filt.replace('"', "").split("=")
+users = [
+user
+for user in users
+for attribute in user.attributes
+if attribute["Name"] == name and attribute["Value"] == value
+]
response = {"Users": [user.to_json(extended=True) for user in users]}
if token:
response["PaginationToken"] = str(token)
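A rough usage sketch of the `Filter` support added above (not part of the commit; the pool name, username, and email are invented):

```python
import boto3
from moto import mock_cognitoidp


@mock_cognitoidp
def filter_users_by_email():
    client = boto3.client("cognito-idp", region_name="us-east-1")
    pool_id = client.create_user_pool(PoolName="example-pool")["UserPool"]["Id"]
    client.admin_create_user(
        UserPoolId=pool_id,
        Username="alice",
        UserAttributes=[{"Name": "email", "Value": "alice@example.com"}],
    )

    # The backend strips quotes and splits on '=', so pass a single
    # name=value pair without spaces around the '='.
    users = client.list_users(
        UserPoolId=pool_id, Filter='email="alice@example.com"'
    )["Users"]
    print([u["Username"] for u in users])  # -> ['alice']


filter_users_by_email()
```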

@@ -43,7 +43,7 @@ from moto.config.exceptions import (
)
from moto.core import BaseBackend, BaseModel
-from moto.s3.config import s3_config_query
+from moto.s3.config import s3_account_public_access_block_query, s3_config_query
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
@@ -58,7 +58,10 @@ POP_STRINGS = [
DEFAULT_PAGE_SIZE = 100
# Map the Config resource type to a backend:
-RESOURCE_MAP = {"AWS::S3::Bucket": s3_config_query}
+RESOURCE_MAP = {
+"AWS::S3::Bucket": s3_config_query,
+"AWS::S3::AccountPublicAccessBlock": s3_account_public_access_block_query,
+}
def datetime2int(date):
@@ -867,16 +870,17 @@
backend_region=backend_query_region,
)
-result = {
-"resourceIdentifiers": [
-{
-"resourceType": identifier["type"],
-"resourceId": identifier["id"],
-"resourceName": identifier["name"],
-}
-for identifier in identifiers
-]
-}
+resource_identifiers = []
+for identifier in identifiers:
+item = {"resourceType": identifier["type"], "resourceId": identifier["id"]}
+# Some resource types lack names:
+if identifier.get("name"):
+item["resourceName"] = identifier["name"]
+resource_identifiers.append(item)
+result = {"resourceIdentifiers": resource_identifiers}
if new_token:
result["nextToken"] = new_token
@@ -927,19 +931,22 @@
resource_region=resource_region,
)
-result = {
-"ResourceIdentifiers": [
-{
+resource_identifiers = []
+for identifier in identifiers:
+item = {
"SourceAccountId": DEFAULT_ACCOUNT_ID,
"SourceRegion": identifier["region"],
"ResourceType": identifier["type"],
"ResourceId": identifier["id"],
-"ResourceName": identifier["name"],
}
-for identifier in identifiers
-]
-}
+if identifier.get("name"):
+item["ResourceName"] = identifier["name"]
+resource_identifiers.append(item)
+result = {"ResourceIdentifiers": resource_identifiers}
if new_token:
result["NextToken"] = new_token

@@ -606,12 +606,13 @@ class ConfigQueryModel(object):
As such, the proper way to implement is to first obtain a full list of results from all the region backends, and then filter
from there. It may be valuable to make this a concatenation of the region and resource name.
-:param resource_region:
-:param resource_ids:
-:param resource_name:
-:param limit:
-:param next_token:
+:param resource_ids: A list of resource IDs
+:param resource_name: The individual name of a resource
+:param limit: How many per page
+:param next_token: The item that will page on
:param backend_region: The region for the backend to pull results from. Set to `None` if this is an aggregated query.
+:param resource_region: The region for where the resources reside to pull results from. Set to `None` if this is a
+non-aggregated query.
:return: This should return a list of Dicts that have the following fields:
[
{

@@ -977,10 +977,8 @@ class OpLessThan(Op):
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially
-if lhs and rhs:
+if lhs is not None and rhs is not None:
return lhs < rhs
-elif lhs is None and rhs:
-return True
else:
return False
@@ -992,10 +990,8 @@ class OpGreaterThan(Op):
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially
-if lhs and rhs:
+if lhs is not None and rhs is not None:
return lhs > rhs
-elif lhs and rhs is None:
-return True
else:
return False
@@ -1025,10 +1021,8 @@ class OpLessThanOrEqual(Op):
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially
-if lhs and rhs:
+if lhs is not None and rhs is not None:
return lhs <= rhs
-elif lhs is None and rhs or lhs is None and rhs is None:
-return True
else:
return False
@@ -1040,10 +1034,8 @@ class OpGreaterThanOrEqual(Op):
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially
-if lhs and rhs:
+if lhs is not None and rhs is not None:
return lhs >= rhs
-elif lhs and rhs is None or lhs is None and rhs is None:
-return True
else:
return False

@@ -448,7 +448,12 @@ class Item(BaseModel):
if list_append_re:
new_value = expression_attribute_values[list_append_re.group(2).strip()]
old_list_key = list_append_re.group(1)
-# Get the existing value
+# old_key could be a function itself (if_not_exists)
+if old_list_key.startswith("if_not_exists"):
+old_list = DynamoType(
+expression_attribute_values[self._get_default(old_list_key)]
+)
+else:
old_list = self.attrs[old_list_key.split(".")[0]]
if "." in old_list_key:
# Value is nested inside a map - find the appropriate child attr
@@ -457,7 +462,7 @@
)
if not old_list.is_list():
raise ParamValidationError
-old_list.value.extend(new_value["L"])
+old_list.value.extend([DynamoType(v) for v in new_value["L"]])
value = old_list
return value

@@ -508,6 +508,13 @@ class DynamoHandler(BaseResponse):
# 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}}
key_conditions = self.body.get("KeyConditions")
query_filters = self.body.get("QueryFilter")
+if not (key_conditions or query_filters):
+return self.error(
+"com.amazonaws.dynamodb.v20111205#ValidationException",
+"Either KeyConditions or QueryFilter should be present",
+)
if key_conditions:
(
hash_key_name,

@@ -27,6 +27,7 @@ from moto.core.utils import (
iso_8601_datetime_with_milliseconds,
camelcase_to_underscores,
)
+from moto.iam.models import ACCOUNT_ID
from .exceptions import (
CidrLimitExceeded,
DependencyViolationError,
@@ -139,18 +140,23 @@ from .utils import (
rsa_public_key_fingerprint,
)
-INSTANCE_TYPES = json.load(
-open(resource_filename(__name__, "resources/instance_types.json"), "r")
-)
-AMIS = json.load(
-open(
-os.environ.get("MOTO_AMIS_PATH")
-or resource_filename(__name__, "resources/amis.json"),
-"r",
-)
-)
-OWNER_ID = "111122223333"
+def _load_resource(filename):
+with open(filename, "r") as f:
+return json.load(f)
+INSTANCE_TYPES = _load_resource(
+resource_filename(__name__, "resources/instance_types.json")
+)
+AMIS = _load_resource(
+os.environ.get("MOTO_AMIS_PATH")
+or resource_filename(__name__, "resources/amis.json"),
+)
+OWNER_ID = ACCOUNT_ID
def utc_date_and_time():
@@ -1336,7 +1342,7 @@ class AmiBackend(object):
source_ami=None,
name=name,
description=description,
-owner_id=context.get_current_user() if context else OWNER_ID,
+owner_id=OWNER_ID,
)
self.amis[ami_id] = ami
return ami
@@ -1387,14 +1393,7 @@ class AmiBackend(object):
# Limit by owner ids
if owners:
# support filtering by Owners=['self']
-owners = list(
-map(
-lambda o: context.get_current_user()
-if context and o == "self"
-else o,
-owners,
-)
-)
+owners = list(map(lambda o: OWNER_ID if o == "self" else o, owners,))
images = [ami for ami in images if ami.owner_id in owners]
# Generic filters

@@ -104,7 +104,7 @@ class SecurityGroups(BaseResponse):
if self.is_not_dryrun("GrantSecurityGroupIngress"):
for args in self._process_rules_from_querystring():
self.ec2_backend.authorize_security_group_ingress(*args)
-return AUTHORIZE_SECURITY_GROUP_INGRESS_REPONSE
+return AUTHORIZE_SECURITY_GROUP_INGRESS_RESPONSE
def create_security_group(self):
name = self._get_param("GroupName")
@@ -158,7 +158,7 @@ class SecurityGroups(BaseResponse):
if self.is_not_dryrun("RevokeSecurityGroupIngress"):
for args in self._process_rules_from_querystring():
self.ec2_backend.revoke_security_group_ingress(*args)
-return REVOKE_SECURITY_GROUP_INGRESS_REPONSE
+return REVOKE_SECURITY_GROUP_INGRESS_RESPONSE
CREATE_SECURITY_GROUP_RESPONSE = """<CreateSecurityGroupResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
@@ -265,12 +265,12 @@ DESCRIBE_SECURITY_GROUPS_RESPONSE = (
</DescribeSecurityGroupsResponse>"""
)
-AUTHORIZE_SECURITY_GROUP_INGRESS_REPONSE = """<AuthorizeSecurityGroupIngressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
+AUTHORIZE_SECURITY_GROUP_INGRESS_RESPONSE = """<AuthorizeSecurityGroupIngressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<return>true</return>
</AuthorizeSecurityGroupIngressResponse>"""
-REVOKE_SECURITY_GROUP_INGRESS_REPONSE = """<RevokeSecurityGroupIngressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
+REVOKE_SECURITY_GROUP_INGRESS_RESPONSE = """<RevokeSecurityGroupIngressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<return>true</return>
</RevokeSecurityGroupIngressResponse>"""

@@ -118,6 +118,7 @@ class TaskDefinition(BaseObject):
revision,
container_definitions,
region_name,
+network_mode=None,
volumes=None,
tags=None,
):
@@ -132,6 +133,10 @@ class TaskDefinition(BaseObject):
self.volumes = []
else:
self.volumes = volumes
+if network_mode is None:
+self.network_mode = "bridge"
+else:
+self.network_mode = network_mode
@property
def response_object(self):
@@ -553,7 +558,7 @@ class EC2ContainerServiceBackend(BaseBackend):
raise Exception("{0} is not a cluster".format(cluster_name))
def register_task_definition(
-self, family, container_definitions, volumes, tags=None
+self, family, container_definitions, volumes=None, network_mode=None, tags=None
):
if family in self.task_definitions:
last_id = self._get_last_task_definition_revision_id(family)
@@ -562,7 +567,13 @@
self.task_definitions[family] = {}
revision = 1
task_definition = TaskDefinition(
-family, revision, container_definitions, self.region_name, volumes, tags
+family,
+revision,
+container_definitions,
+self.region_name,
+volumes=volumes,
+network_mode=network_mode,
+tags=tags,
)
self.task_definitions[family][revision] = task_definition

@@ -62,8 +62,13 @@ class EC2ContainerServiceResponse(BaseResponse):
container_definitions = self._get_param("containerDefinitions")
volumes = self._get_param("volumes")
tags = self._get_param("tags")
+network_mode = self._get_param("networkMode")
task_definition = self.ecs_backend.register_task_definition(
-family, container_definitions, volumes, tags
+family,
+container_definitions,
+volumes=volumes,
+network_mode=network_mode,
+tags=tags,
)
return json.dumps({"taskDefinition": task_definition.response_object})
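A brief, illustrative sketch of the new `networkMode` pass-through from the caller's side (not part of the commit; the family and container values are made up):

```python
import boto3
from moto import mock_ecs


@mock_ecs
def register_with_network_mode():
    client = boto3.client("ecs", region_name="us-east-1")
    response = client.register_task_definition(
        family="web",
        networkMode="awsvpc",  # forwarded to TaskDefinition instead of the "bridge" default
        containerDefinitions=[
            {"name": "nginx", "image": "nginx:alpine", "memory": 128, "essential": True}
        ],
    )
    print(response["taskDefinition"].get("networkMode"))


register_with_network_mode()
```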

@ -143,6 +143,9 @@ class EventsBackend(BaseBackend):
def delete_rule(self, name): def delete_rule(self, name):
self.rules_order.pop(self.rules_order.index(name)) self.rules_order.pop(self.rules_order.index(name))
arn = self.rules.get(name).arn
if self.tagger.has_tags(arn):
self.tagger.delete_all_tags_for_resource(arn)
return self.rules.pop(name) is not None return self.rules.pop(name) is not None
def describe_rule(self, name): def describe_rule(self, name):
@ -371,8 +374,16 @@ class EventsBackend(BaseBackend):
"ResourceNotFoundException", "An entity that you specified does not exist." "ResourceNotFoundException", "An entity that you specified does not exist."
) )
def list_tags_for_resource(self, arn):
name = arn.split("/")[-1]
if name in self.rules:
return self.tagger.list_tags_for_resource(self.rules[name].arn)
raise JsonRESTError(
"ResourceNotFoundException", "An entity that you specified does not exist."
)
def tag_resource(self, arn, tags): def tag_resource(self, arn, tags):
name = arn.split('/')[-1] name = arn.split("/")[-1]
if name in self.rules: if name in self.rules:
self.tagger.tag_resource(self.rules[name].arn, tags) self.tagger.tag_resource(self.rules[name].arn, tags)
return {} return {}
@ -381,7 +392,7 @@ class EventsBackend(BaseBackend):
) )
def untag_resource(self, arn, tag_names): def untag_resource(self, arn, tag_names):
name = arn.split('/')[-1] name = arn.split("/")[-1]
if name in self.rules: if name in self.rules:
self.tagger.untag_resource_using_names(self.rules[name].arn, tag_names) self.tagger.untag_resource_using_names(self.rules[name].arn, tag_names)
return {} return {}
@ -389,6 +400,7 @@ class EventsBackend(BaseBackend):
"ResourceNotFoundException", "An entity that you specified does not exist." "ResourceNotFoundException", "An entity that you specified does not exist."
) )
events_backends = {} events_backends = {}
for region in Session().get_available_regions("events"): for region in Session().get_available_regions("events"):
events_backends[region] = EventsBackend(region) events_backends[region] = EventsBackend(region)
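With `delete_rule` now clearing tags and `list_tags_for_resource` added, a round-trip check might look roughly like this (rule name and tags are arbitrary):

```python
import boto3
from moto import mock_events


@mock_events
def test_rule_tags_are_listed_and_cleaned_up():
    client = boto3.client("events", region_name="us-east-1")

    rule_arn = client.put_rule(Name="nightly", ScheduleExpression="rate(1 day)")["RuleArn"]
    client.tag_resource(ResourceARN=rule_arn, Tags=[{"Key": "team", "Value": "infra"}])

    tags = client.list_tags_for_resource(ResourceARN=rule_arn)["Tags"]
    assert tags == [{"Key": "team", "Value": "infra"}]

    # delete_rule now also drops any tags attached to the rule's ARN.
    client.delete_rule(Name="nightly")
```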


@ -563,6 +563,10 @@ class IamResponse(BaseResponse):
def create_access_key(self): def create_access_key(self):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
if not user_name:
access_key_id = self.get_current_user()
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
key = iam_backend.create_access_key(user_name) key = iam_backend.create_access_key(user_name)
template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE) template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE)
@ -572,6 +576,10 @@ class IamResponse(BaseResponse):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
access_key_id = self._get_param("AccessKeyId") access_key_id = self._get_param("AccessKeyId")
status = self._get_param("Status") status = self._get_param("Status")
if not user_name:
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
iam_backend.update_access_key(user_name, access_key_id, status) iam_backend.update_access_key(user_name, access_key_id, status)
template = self.response_template(GENERIC_EMPTY_TEMPLATE) template = self.response_template(GENERIC_EMPTY_TEMPLATE)
return template.render(name="UpdateAccessKey") return template.render(name="UpdateAccessKey")
@ -587,6 +595,11 @@ class IamResponse(BaseResponse):
def list_access_keys(self): def list_access_keys(self):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
if not user_name:
access_key_id = self.get_current_user()
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
keys = iam_backend.get_all_access_keys(user_name) keys = iam_backend.get_all_access_keys(user_name)
template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE) template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE)
return template.render(user_name=user_name, keys=keys) return template.render(user_name=user_name, keys=keys)
@ -594,6 +607,9 @@ class IamResponse(BaseResponse):
def delete_access_key(self): def delete_access_key(self):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
access_key_id = self._get_param("AccessKeyId") access_key_id = self._get_param("AccessKeyId")
if not user_name:
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
iam_backend.delete_access_key(access_key_id, user_name) iam_backend.delete_access_key(access_key_id, user_name)
template = self.response_template(GENERIC_EMPTY_TEMPLATE) template = self.response_template(GENERIC_EMPTY_TEMPLATE)
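Since `UserName` is optional on these IAM access-key calls, the handlers above now fall back to resolving the owner from the key itself. A hedged sketch of a test that relies on that fallback (user and key names are invented):

```python
import boto3
from moto import mock_iam


@mock_iam
def test_access_key_calls_without_user_name():
    client = boto3.client("iam", region_name="us-east-1")
    client.create_user(UserName="alice")
    key_id = client.create_access_key(UserName="alice")["AccessKey"]["AccessKeyId"]

    # No UserName here: the mock resolves the owner from the access key id.
    client.update_access_key(AccessKeyId=key_id, Status="Inactive")
    client.delete_access_key(AccessKeyId=key_id)

    assert client.list_access_keys(UserName="alice")["AccessKeyMetadata"] == []
```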


@ -22,6 +22,15 @@ class InvalidRequestException(IoTClientError):
) )
class InvalidStateTransitionException(IoTClientError):
def __init__(self, msg=None):
self.code = 409
super(InvalidStateTransitionException, self).__init__(
"InvalidStateTransitionException",
msg or "An attempt was made to change to an invalid state.",
)
class VersionConflictException(IoTClientError): class VersionConflictException(IoTClientError):
def __init__(self, name): def __init__(self, name):
self.code = 409 self.code = 409


@ -17,6 +17,7 @@ from .exceptions import (
DeleteConflictException, DeleteConflictException,
ResourceNotFoundException, ResourceNotFoundException,
InvalidRequestException, InvalidRequestException,
InvalidStateTransitionException,
VersionConflictException, VersionConflictException,
) )
@ -29,7 +30,7 @@ class FakeThing(BaseModel):
self.attributes = attributes self.attributes = attributes
self.arn = "arn:aws:iot:%s:1:thing/%s" % (self.region_name, thing_name) self.arn = "arn:aws:iot:%s:1:thing/%s" % (self.region_name, thing_name)
self.version = 1 self.version = 1
# TODO: we need to handle 'version'? # TODO: we need to handle "version"?
# for iot-data # for iot-data
self.thing_shadow = None self.thing_shadow = None
@ -174,18 +175,19 @@ class FakeCertificate(BaseModel):
class FakePolicy(BaseModel): class FakePolicy(BaseModel):
def __init__(self, name, document, region_name): def __init__(self, name, document, region_name, default_version_id="1"):
self.name = name self.name = name
self.document = document self.document = document
self.arn = "arn:aws:iot:%s:1:policy/%s" % (region_name, name) self.arn = "arn:aws:iot:%s:1:policy/%s" % (region_name, name)
self.version = "1" # TODO: handle version self.default_version_id = default_version_id
self.versions = [FakePolicyVersion(self.name, document, True, region_name)]
def to_get_dict(self): def to_get_dict(self):
return { return {
"policyName": self.name, "policyName": self.name,
"policyArn": self.arn, "policyArn": self.arn,
"policyDocument": self.document, "policyDocument": self.document,
"defaultVersionId": self.version, "defaultVersionId": self.default_version_id,
} }
def to_dict_at_creation(self): def to_dict_at_creation(self):
@ -193,13 +195,52 @@ class FakePolicy(BaseModel):
"policyName": self.name, "policyName": self.name,
"policyArn": self.arn, "policyArn": self.arn,
"policyDocument": self.document, "policyDocument": self.document,
"policyVersionId": self.version, "policyVersionId": self.default_version_id,
} }
def to_dict(self): def to_dict(self):
return {"policyName": self.name, "policyArn": self.arn} return {"policyName": self.name, "policyArn": self.arn}
class FakePolicyVersion(object):
def __init__(self, policy_name, document, is_default, region_name):
self.name = policy_name
self.arn = "arn:aws:iot:%s:1:policy/%s" % (region_name, policy_name)
self.document = document or {}
self.is_default = is_default
self.version_id = "1"
self.create_datetime = time.mktime(datetime(2015, 1, 1).timetuple())
self.last_modified_datetime = time.mktime(datetime(2015, 1, 2).timetuple())
def to_get_dict(self):
return {
"policyName": self.name,
"policyArn": self.arn,
"policyDocument": self.document,
"policyVersionId": self.version_id,
"isDefaultVersion": self.is_default,
"creationDate": self.create_datetime,
"lastModifiedDate": self.last_modified_datetime,
"generationId": self.version_id,
}
def to_dict_at_creation(self):
return {
"policyArn": self.arn,
"policyDocument": self.document,
"policyVersionId": self.version_id,
"isDefaultVersion": self.is_default,
}
def to_dict(self):
return {
"versionId": self.version_id,
"isDefaultVersion": self.is_default,
"createDate": self.create_datetime,
}
class FakeJob(BaseModel): class FakeJob(BaseModel):
JOB_ID_REGEX_PATTERN = "[a-zA-Z0-9_-]" JOB_ID_REGEX_PATTERN = "[a-zA-Z0-9_-]"
JOB_ID_REGEX = re.compile(JOB_ID_REGEX_PATTERN) JOB_ID_REGEX = re.compile(JOB_ID_REGEX_PATTERN)
@ -226,12 +267,14 @@ class FakeJob(BaseModel):
self.targets = targets self.targets = targets
self.document_source = document_source self.document_source = document_source
self.document = document self.document = document
self.force = False
self.description = description self.description = description
self.presigned_url_config = presigned_url_config self.presigned_url_config = presigned_url_config
self.target_selection = target_selection self.target_selection = target_selection
self.job_executions_rollout_config = job_executions_rollout_config self.job_executions_rollout_config = job_executions_rollout_config
self.status = None # IN_PROGRESS | CANCELED | COMPLETED self.status = "QUEUED" # IN_PROGRESS | CANCELED | COMPLETED
self.comment = None self.comment = None
self.reason_code = None
self.created_at = time.mktime(datetime(2015, 1, 1).timetuple()) self.created_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple()) self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.completed_at = None self.completed_at = None
@ -258,9 +301,11 @@ class FakeJob(BaseModel):
"jobExecutionsRolloutConfig": self.job_executions_rollout_config, "jobExecutionsRolloutConfig": self.job_executions_rollout_config,
"status": self.status, "status": self.status,
"comment": self.comment, "comment": self.comment,
"forceCanceled": self.force,
"reasonCode": self.reason_code,
"createdAt": self.created_at, "createdAt": self.created_at,
"lastUpdatedAt": self.last_updated_at, "lastUpdatedAt": self.last_updated_at,
"completedAt": self.completedAt, "completedAt": self.completed_at,
"jobProcessDetails": self.job_process_details, "jobProcessDetails": self.job_process_details,
"documentParameters": self.document_parameters, "documentParameters": self.document_parameters,
"document": self.document, "document": self.document,
@ -275,12 +320,67 @@ class FakeJob(BaseModel):
return regex_match and length_match return regex_match and length_match
class FakeJobExecution(BaseModel):
def __init__(
self,
job_id,
thing_arn,
status="QUEUED",
force_canceled=False,
status_details_map=None,
):
self.job_id = job_id
self.status = status # IN_PROGRESS | CANCELED | COMPLETED
self.force_canceled = force_canceled
self.status_details_map = status_details_map or {}
self.thing_arn = thing_arn
self.queued_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.started_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.execution_number = 123
self.version_number = 123
self.approximate_seconds_before_time_out = 123
def to_get_dict(self):
obj = {
"jobId": self.job_id,
"status": self.status,
"forceCanceled": self.force_canceled,
"statusDetails": {"detailsMap": self.status_details_map},
"thingArn": self.thing_arn,
"queuedAt": self.queued_at,
"startedAt": self.started_at,
"lastUpdatedAt": self.last_updated_at,
"executionNumber": self.execution_number,
"versionNumber": self.version_number,
"approximateSecondsBeforeTimedOut": self.approximate_seconds_before_time_out,
}
return obj
def to_dict(self):
obj = {
"jobId": self.job_id,
"thingArn": self.thing_arn,
"jobExecutionSummary": {
"status": self.status,
"queuedAt": self.queued_at,
"startedAt": self.started_at,
"lastUpdatedAt": self.last_updated_at,
"executionNumber": self.execution_number,
},
}
return obj
class IoTBackend(BaseBackend): class IoTBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name=None):
super(IoTBackend, self).__init__() super(IoTBackend, self).__init__()
self.region_name = region_name self.region_name = region_name
self.things = OrderedDict() self.things = OrderedDict()
self.jobs = OrderedDict() self.jobs = OrderedDict()
self.job_executions = OrderedDict()
self.thing_types = OrderedDict() self.thing_types = OrderedDict()
self.thing_groups = OrderedDict() self.thing_groups = OrderedDict()
self.certificates = OrderedDict() self.certificates = OrderedDict()
@ -535,6 +635,28 @@ class IoTBackend(BaseBackend):
self.policies[policy.name] = policy self.policies[policy.name] = policy
return policy return policy
def attach_policy(self, policy_name, target):
principal = self._get_principal(target)
policy = self.get_policy(policy_name)
k = (target, policy_name)
if k in self.principal_policies:
return
self.principal_policies[k] = (principal, policy)
def detach_policy(self, policy_name, target):
# this may raise ResourceNotFoundException
self._get_principal(target)
self.get_policy(policy_name)
k = (target, policy_name)
if k not in self.principal_policies:
raise ResourceNotFoundException()
del self.principal_policies[k]
def list_attached_policies(self, target):
policies = [v[1] for k, v in self.principal_policies.items() if k[0] == target]
return policies
def list_policies(self): def list_policies(self):
policies = self.policies.values() policies = self.policies.values()
return policies return policies
@ -559,6 +681,60 @@ class IoTBackend(BaseBackend):
policy = self.get_policy(policy_name) policy = self.get_policy(policy_name)
del self.policies[policy.name] del self.policies[policy.name]
def create_policy_version(self, policy_name, policy_document, set_as_default):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
version = FakePolicyVersion(
policy_name, policy_document, set_as_default, self.region_name
)
policy.versions.append(version)
version.version_id = "{0}".format(len(policy.versions))
if set_as_default:
self.set_default_policy_version(policy_name, version.version_id)
return version
def set_default_policy_version(self, policy_name, version_id):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
for version in policy.versions:
if version.version_id == version_id:
version.is_default = True
policy.default_version_id = version.version_id
policy.document = version.document
else:
version.is_default = False
def get_policy_version(self, policy_name, version_id):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
for version in policy.versions:
if version.version_id == version_id:
return version
raise ResourceNotFoundException()
def list_policy_versions(self, policy_name):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
return policy.versions
def delete_policy_version(self, policy_name, version_id):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
if version_id == policy.default_version_id:
raise InvalidRequestException(
"Cannot delete the default version of a policy"
)
for i, v in enumerate(policy.versions):
if v.version_id == version_id:
del policy.versions[i]
return
raise ResourceNotFoundException()
def _get_principal(self, principal_arn): def _get_principal(self, principal_arn):
""" """
raise ResourceNotFoundException raise ResourceNotFoundException
@ -574,14 +750,6 @@ class IoTBackend(BaseBackend):
pass pass
raise ResourceNotFoundException() raise ResourceNotFoundException()
def attach_policy(self, policy_name, target):
principal = self._get_principal(target)
policy = self.get_policy(policy_name)
k = (target, policy_name)
if k in self.principal_policies:
return
self.principal_policies[k] = (principal, policy)
def attach_principal_policy(self, policy_name, principal_arn): def attach_principal_policy(self, policy_name, principal_arn):
principal = self._get_principal(principal_arn) principal = self._get_principal(principal_arn)
policy = self.get_policy(policy_name) policy = self.get_policy(policy_name)
@ -590,15 +758,6 @@ class IoTBackend(BaseBackend):
return return
self.principal_policies[k] = (principal, policy) self.principal_policies[k] = (principal, policy)
def detach_policy(self, policy_name, target):
# this may raises ResourceNotFoundException
self._get_principal(target)
self.get_policy(policy_name)
k = (target, policy_name)
if k not in self.principal_policies:
raise ResourceNotFoundException()
del self.principal_policies[k]
def detach_principal_policy(self, policy_name, principal_arn): def detach_principal_policy(self, policy_name, principal_arn):
# this may raises ResourceNotFoundException # this may raises ResourceNotFoundException
self._get_principal(principal_arn) self._get_principal(principal_arn)
@ -819,11 +978,187 @@ class IoTBackend(BaseBackend):
self.region_name, self.region_name,
) )
self.jobs[job_id] = job self.jobs[job_id] = job
for thing_arn in targets:
thing_name = thing_arn.split(":")[-1].split("/")[-1]
job_execution = FakeJobExecution(job_id, thing_arn)
self.job_executions[(job_id, thing_name)] = job_execution
return job.job_arn, job_id, description return job.job_arn, job_id, description
def describe_job(self, job_id): def describe_job(self, job_id):
jobs = [_ for _ in self.jobs.values() if _.job_id == job_id]
if len(jobs) == 0:
raise ResourceNotFoundException()
return jobs[0]
def delete_job(self, job_id, force):
job = self.jobs[job_id]
if job.status == "IN_PROGRESS" and force:
del self.jobs[job_id]
elif job.status != "IN_PROGRESS":
del self.jobs[job_id]
else:
raise InvalidStateTransitionException()
def cancel_job(self, job_id, reason_code, comment, force):
job = self.jobs[job_id]
job.reason_code = reason_code if reason_code is not None else job.reason_code
job.comment = comment if comment is not None else job.comment
job.force = force if force is not None and force != job.force else job.force
job.status = "CANCELED"
if job.status == "IN_PROGRESS" and force:
self.jobs[job_id] = job
elif job.status != "IN_PROGRESS":
self.jobs[job_id] = job
else:
raise InvalidStateTransitionException()
return job
def get_job_document(self, job_id):
return self.jobs[job_id] return self.jobs[job_id]
def list_jobs(
self,
status,
target_selection,
max_results,
token,
thing_group_name,
thing_group_id,
):
# TODO: implement filters
all_jobs = [_.to_dict() for _ in self.jobs.values()]
filtered_jobs = all_jobs
if token is None:
jobs = filtered_jobs[0:max_results]
next_token = str(max_results) if len(filtered_jobs) > max_results else None
else:
token = int(token)
jobs = filtered_jobs[token : token + max_results]
next_token = (
str(token + max_results)
if len(filtered_jobs) > token + max_results
else None
)
return jobs, next_token
def describe_job_execution(self, job_id, thing_name, execution_number):
try:
job_execution = self.job_executions[(job_id, thing_name)]
except KeyError:
raise ResourceNotFoundException()
if job_execution is None or (
execution_number is not None
and job_execution.execution_number != execution_number
):
raise ResourceNotFoundException()
return job_execution
def cancel_job_execution(
self, job_id, thing_name, force, expected_version, status_details
):
job_execution = self.job_executions[(job_id, thing_name)]
if job_execution is None:
raise ResourceNotFoundException()
job_execution.force_canceled = (
force if force is not None else job_execution.force_canceled
)
# TODO: implement expected_version and status_details (at most 10 can be specified)
if job_execution.status == "IN_PROGRESS" and force:
job_execution.status = "CANCELED"
self.job_executions[(job_id, thing_name)] = job_execution
elif job_execution.status != "IN_PROGRESS":
job_execution.status = "CANCELED"
self.job_executions[(job_id, thing_name)] = job_execution
else:
raise InvalidStateTransitionException()
def delete_job_execution(self, job_id, thing_name, execution_number, force):
job_execution = self.job_executions[(job_id, thing_name)]
if job_execution.execution_number != execution_number:
raise ResourceNotFoundException()
if job_execution.status == "IN_PROGRESS" and force:
del self.job_executions[(job_id, thing_name)]
elif job_execution.status != "IN_PROGRESS":
del self.job_executions[(job_id, thing_name)]
else:
raise InvalidStateTransitionException()
def list_job_executions_for_job(self, job_id, status, max_results, next_token):
job_executions = [
self.job_executions[je].to_dict()
for je in self.job_executions
if je[0] == job_id
]
if status is not None:
job_executions = list(
filter(
lambda elem: status in elem["status"] and elem["status"] == status,
job_executions,
)
)
token = next_token
if token is None:
job_executions = job_executions[0:max_results]
next_token = str(max_results) if len(job_executions) > max_results else None
else:
token = int(token)
job_executions = job_executions[token : token + max_results]
next_token = (
str(token + max_results)
if len(job_executions) > token + max_results
else None
)
return job_executions, next_token
def list_job_executions_for_thing(
self, thing_name, status, max_results, next_token
):
job_executions = [
self.job_executions[je].to_dict()
for je in self.job_executions
if je[1] == thing_name
]
if status is not None:
job_executions = list(
filter(
lambda elem: status in elem["status"] and elem["status"] == status,
job_executions,
)
)
token = next_token
if token is None:
job_executions = job_executions[0:max_results]
next_token = str(max_results) if len(job_executions) > max_results else None
else:
token = int(token)
job_executions = job_executions[token : token + max_results]
next_token = (
str(token + max_results)
if len(job_executions) > token + max_results
else None
)
return job_executions, next_token
iot_backends = {} iot_backends = {}
for region in Session().get_available_regions("iot"): for region in Session().get_available_regions("iot"):

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import json import json
from six.moves.urllib.parse import unquote
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import iot_backends from .models import iot_backends
@ -141,6 +142,8 @@ class IoTResponse(BaseResponse):
createdAt=job.created_at, createdAt=job.created_at,
description=job.description, description=job.description,
documentParameters=job.document_parameters, documentParameters=job.document_parameters,
forceCanceled=job.force,
reasonCode=job.reason_code,
jobArn=job.job_arn, jobArn=job.job_arn,
jobExecutionsRolloutConfig=job.job_executions_rollout_config, jobExecutionsRolloutConfig=job.job_executions_rollout_config,
jobId=job.job_id, jobId=job.job_id,
@ -154,6 +157,127 @@ class IoTResponse(BaseResponse):
) )
) )
def delete_job(self):
job_id = self._get_param("jobId")
force = self._get_bool_param("force")
self.iot_backend.delete_job(job_id=job_id, force=force)
return json.dumps(dict())
def cancel_job(self):
job_id = self._get_param("jobId")
reason_code = self._get_param("reasonCode")
comment = self._get_param("comment")
force = self._get_bool_param("force")
job = self.iot_backend.cancel_job(
job_id=job_id, reason_code=reason_code, comment=comment, force=force
)
return json.dumps(job.to_dict())
def get_job_document(self):
job = self.iot_backend.get_job_document(job_id=self._get_param("jobId"))
if job.document is not None:
return json.dumps({"document": job.document})
else:
# job.document_source is set instead
# TODO: fetch the document contents from the S3 location in document_source
return json.dumps({"document": ""})
def list_jobs(self):
status = self._get_param("status")
target_selection = self._get_param("targetSelection")
max_results = self._get_int_param(
"maxResults", 50
) # not the default, but makes testing easier
previous_next_token = self._get_param("nextToken")
thing_group_name = self._get_param("thingGroupName")
thing_group_id = self._get_param("thingGroupId")
jobs, next_token = self.iot_backend.list_jobs(
status=status,
target_selection=target_selection,
max_results=max_results,
token=previous_next_token,
thing_group_name=thing_group_name,
thing_group_id=thing_group_id,
)
return json.dumps(dict(jobs=jobs, nextToken=next_token))
def describe_job_execution(self):
job_id = self._get_param("jobId")
thing_name = self._get_param("thingName")
execution_number = self._get_int_param("executionNumber")
job_execution = self.iot_backend.describe_job_execution(
job_id=job_id, thing_name=thing_name, execution_number=execution_number
)
return json.dumps(dict(execution=job_execution.to_get_dict()))
def cancel_job_execution(self):
job_id = self._get_param("jobId")
thing_name = self._get_param("thingName")
force = self._get_bool_param("force")
expected_version = self._get_int_param("expectedVersion")
status_details = self._get_param("statusDetails")
self.iot_backend.cancel_job_execution(
job_id=job_id,
thing_name=thing_name,
force=force,
expected_version=expected_version,
status_details=status_details,
)
return json.dumps(dict())
def delete_job_execution(self):
job_id = self._get_param("jobId")
thing_name = self._get_param("thingName")
execution_number = self._get_int_param("executionNumber")
force = self._get_bool_param("force")
self.iot_backend.delete_job_execution(
job_id=job_id,
thing_name=thing_name,
execution_number=execution_number,
force=force,
)
return json.dumps(dict())
def list_job_executions_for_job(self):
job_id = self._get_param("jobId")
status = self._get_param("status")
max_results = self._get_int_param(
"maxResults", 50
) # not the default, but makes testing easier
next_token = self._get_param("nextToken")
job_executions, next_token = self.iot_backend.list_job_executions_for_job(
job_id=job_id, status=status, max_results=max_results, next_token=next_token
)
return json.dumps(dict(executionSummaries=job_executions, nextToken=next_token))
def list_job_executions_for_thing(self):
thing_name = self._get_param("thingName")
status = self._get_param("status")
max_results = self._get_int_param(
"maxResults", 50
) # not the default, but makes testing easier
next_token = self._get_param("nextToken")
job_executions, next_token = self.iot_backend.list_job_executions_for_thing(
thing_name=thing_name,
status=status,
max_results=max_results,
next_token=next_token,
)
return json.dumps(dict(executionSummaries=job_executions, nextToken=next_token))
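A possible end-to-end sketch of the new job / job-execution plumbing (job id, thing name and document are made up; assumes `@mock_iot`):

```python
import json

import boto3
from moto import mock_iot


@mock_iot
def test_job_and_job_execution_lifecycle():
    client = boto3.client("iot", region_name="eu-west-1")
    thing_arn = client.create_thing(thingName="device-1")["thingArn"]

    client.create_job(
        jobId="job-1",
        targets=[thing_arn],
        document=json.dumps({"operation": "reboot"}),
        description="reboot the fleet",
    )

    # One FakeJobExecution is created per targeted thing, starting out QUEUED.
    execution = client.describe_job_execution(jobId="job-1", thingName="device-1")["execution"]
    assert execution["status"] == "QUEUED"

    client.cancel_job(jobId="job-1", comment="no longer needed")
    assert client.describe_job(jobId="job-1")["job"]["status"] == "CANCELED"
```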
def create_keys_and_certificate(self): def create_keys_and_certificate(self):
set_as_active = self._get_bool_param("setAsActive") set_as_active = self._get_bool_param("setAsActive")
cert, key_pair = self.iot_backend.create_keys_and_certificate( cert, key_pair = self.iot_backend.create_keys_and_certificate(
@ -241,12 +365,61 @@ class IoTResponse(BaseResponse):
self.iot_backend.delete_policy(policy_name=policy_name) self.iot_backend.delete_policy(policy_name=policy_name)
return json.dumps(dict()) return json.dumps(dict())
def create_policy_version(self):
policy_name = self._get_param("policyName")
policy_document = self._get_param("policyDocument")
set_as_default = self._get_bool_param("setAsDefault")
policy_version = self.iot_backend.create_policy_version(
policy_name, policy_document, set_as_default
)
return json.dumps(dict(policy_version.to_dict_at_creation()))
def set_default_policy_version(self):
policy_name = self._get_param("policyName")
version_id = self._get_param("policyVersionId")
self.iot_backend.set_default_policy_version(policy_name, version_id)
return json.dumps(dict())
def get_policy_version(self):
policy_name = self._get_param("policyName")
version_id = self._get_param("policyVersionId")
policy_version = self.iot_backend.get_policy_version(policy_name, version_id)
return json.dumps(dict(policy_version.to_get_dict()))
def list_policy_versions(self):
policy_name = self._get_param("policyName")
policy_versions = self.iot_backend.list_policy_versions(
policy_name=policy_name
)
return json.dumps(dict(policyVersions=[_.to_dict() for _ in policy_versions]))
def delete_policy_version(self):
policy_name = self._get_param("policyName")
version_id = self._get_param("policyVersionId")
self.iot_backend.delete_policy_version(policy_name, version_id)
return json.dumps(dict())
def attach_policy(self): def attach_policy(self):
policy_name = self._get_param("policyName") policy_name = self._get_param("policyName")
target = self._get_param("target") target = self._get_param("target")
self.iot_backend.attach_policy(policy_name=policy_name, target=target) self.iot_backend.attach_policy(policy_name=policy_name, target=target)
return json.dumps(dict()) return json.dumps(dict())
def list_attached_policies(self):
principal = unquote(self._get_param("target"))
# marker = self._get_param("marker")
# page_size = self._get_int_param("pageSize")
policies = self.iot_backend.list_attached_policies(target=principal)
# TODO: implement pagination in the future
next_marker = None
return json.dumps(
dict(policies=[_.to_dict() for _ in policies], nextMarker=next_marker)
)
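The policy-version and attach/list additions could be driven from boto3 along these lines (policy names and documents are invented):

```python
import json

import boto3
from moto import mock_iot

DOC_V1 = {"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "Action": "iot:Connect", "Resource": "*"}]}
DOC_V2 = {"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "Action": "iot:*", "Resource": "*"}]}


@mock_iot
def test_policy_versions_and_attachment():
    client = boto3.client("iot", region_name="ap-northeast-1")
    client.create_policy(policyName="device-policy", policyDocument=json.dumps(DOC_V1))

    # A second version that immediately becomes the default.
    client.create_policy_version(
        policyName="device-policy",
        policyDocument=json.dumps(DOC_V2),
        setAsDefault=True,
    )
    versions = client.list_policy_versions(policyName="device-policy")["policyVersions"]
    assert len(versions) == 2

    cert_arn = client.create_keys_and_certificate(setAsActive=True)["certificateArn"]
    client.attach_policy(policyName="device-policy", target=cert_arn)
    attached = client.list_attached_policies(target=cert_arn)["policies"]
    assert attached[0]["policyName"] == "device-policy"
```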
def attach_principal_policy(self): def attach_principal_policy(self):
policy_name = self._get_param("policyName") policy_name = self._get_param("policyName")
principal = self.headers.get("x-amzn-iot-principal") principal = self.headers.get("x-amzn-iot-principal")


@ -7,25 +7,42 @@ from datetime import datetime, timedelta
from boto3 import Session from boto3 import Session
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
<<<<<<< HEAD
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.core.utils import iso_8601_datetime_without_milliseconds
from moto.utilities.tagging_service import TaggingService from moto.utilities.tagging_service import TaggingService
=======
from moto.core.utils import unix_time
from moto.iam.models import ACCOUNT_ID
>>>>>>> 100dbd529f174f18d579a1dcc066d55409f2e38f
from .utils import decrypt, encrypt, generate_key_id, generate_master_key from .utils import decrypt, encrypt, generate_key_id, generate_master_key
class Key(BaseModel): class Key(BaseModel):
<<<<<<< HEAD
def __init__(self, policy, key_usage, description, region): def __init__(self, policy, key_usage, description, region):
=======
def __init__(
self, policy, key_usage, customer_master_key_spec, description, tags, region
):
>>>>>>> 100dbd529f174f18d579a1dcc066d55409f2e38f
self.id = generate_key_id() self.id = generate_key_id()
self.creation_date = unix_time()
self.policy = policy self.policy = policy
self.key_usage = key_usage self.key_usage = key_usage
self.key_state = "Enabled" self.key_state = "Enabled"
self.description = description self.description = description
self.enabled = True self.enabled = True
self.region = region self.region = region
self.account_id = "012345678912" self.account_id = ACCOUNT_ID
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date = None
self.key_material = generate_master_key() self.key_material = generate_master_key()
self.origin = "AWS_KMS"
self.key_manager = "CUSTOMER"
self.customer_master_key_spec = customer_master_key_spec or "SYMMETRIC_DEFAULT"
@property @property
def physical_resource_id(self): def physical_resource_id(self):
@ -37,23 +54,55 @@ class Key(BaseModel):
self.region, self.account_id, self.id self.region, self.account_id, self.id
) )
@property
def encryption_algorithms(self):
if self.key_usage == "SIGN_VERIFY":
return None
elif self.customer_master_key_spec == "SYMMETRIC_DEFAULT":
return ["SYMMETRIC_DEFAULT"]
else:
return ["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"]
@property
def signing_algorithms(self):
if self.key_usage == "ENCRYPT_DECRYPT":
return None
elif self.customer_master_key_spec in ["ECC_NIST_P256", "ECC_SECG_P256K1"]:
return ["ECDSA_SHA_256"]
elif self.customer_master_key_spec == "ECC_NIST_P384":
return ["ECDSA_SHA_384"]
elif self.customer_master_key_spec == "ECC_NIST_P521":
return ["ECDSA_SHA_512"]
else:
return [
"RSASSA_PKCS1_V1_5_SHA_256",
"RSASSA_PKCS1_V1_5_SHA_384",
"RSASSA_PKCS1_V1_5_SHA_512",
"RSASSA_PSS_SHA_256",
"RSASSA_PSS_SHA_384",
"RSASSA_PSS_SHA_512",
]
def to_dict(self): def to_dict(self):
key_dict = { key_dict = {
"KeyMetadata": { "KeyMetadata": {
"AWSAccountId": self.account_id, "AWSAccountId": self.account_id,
"Arn": self.arn, "Arn": self.arn,
"CreationDate": iso_8601_datetime_without_milliseconds(datetime.now()), "CreationDate": self.creation_date,
"CustomerMasterKeySpec": self.customer_master_key_spec,
"Description": self.description, "Description": self.description,
"Enabled": self.enabled, "Enabled": self.enabled,
"EncryptionAlgorithms": self.encryption_algorithms,
"KeyId": self.id, "KeyId": self.id,
"KeyManager": self.key_manager,
"KeyUsage": self.key_usage, "KeyUsage": self.key_usage,
"KeyState": self.key_state, "KeyState": self.key_state,
"Origin": self.origin,
"SigningAlgorithms": self.signing_algorithms,
} }
} }
if self.key_state == "PendingDeletion": if self.key_state == "PendingDeletion":
key_dict["KeyMetadata"][ key_dict["KeyMetadata"]["DeletionDate"] = unix_time(self.deletion_date)
"DeletionDate"
] = iso_8601_datetime_without_milliseconds(self.deletion_date)
return key_dict return key_dict
def delete(self, region_name): def delete(self, region_name):
@ -69,6 +118,7 @@ class Key(BaseModel):
key = kms_backend.create_key( key = kms_backend.create_key(
policy=properties["KeyPolicy"], policy=properties["KeyPolicy"],
key_usage="ENCRYPT_DECRYPT", key_usage="ENCRYPT_DECRYPT",
customer_master_key_spec="SYMMETRIC_DEFAULT",
description=properties["Description"], description=properties["Description"],
region=region_name, region=region_name,
) )
@ -92,8 +142,17 @@ class KmsBackend(BaseBackend):
self.key_to_aliases = defaultdict(set) self.key_to_aliases = defaultdict(set)
self.tagger = TaggingService(keyName='TagKey', valueName='TagValue') self.tagger = TaggingService(keyName='TagKey', valueName='TagValue')
<<<<<<< HEAD
def create_key(self, policy, key_usage, description, tags, region): def create_key(self, policy, key_usage, description, tags, region):
key = Key(policy, key_usage, description, region) key = Key(policy, key_usage, description, region)
=======
def create_key(
self, policy, key_usage, customer_master_key_spec, description, tags, region
):
key = Key(
policy, key_usage, customer_master_key_spec, description, tags, region
)
>>>>>>> 100dbd529f174f18d579a1dcc066d55409f2e38f
self.keys[key.id] = key self.keys[key.id] = key
if tags != None and len(tags) > 0: if tags != None and len(tags) > 0:
self.tag_resource(key.id, tags) self.tag_resource(key.id, tags)
@ -211,9 +270,7 @@ class KmsBackend(BaseBackend):
self.keys[key_id].deletion_date = datetime.now() + timedelta( self.keys[key_id].deletion_date = datetime.now() + timedelta(
days=pending_window_in_days days=pending_window_in_days
) )
return iso_8601_datetime_without_milliseconds( return unix_time(self.keys[key_id].deletion_date)
self.keys[key_id].deletion_date
)
def encrypt(self, key_id, plaintext, encryption_context): def encrypt(self, key_id, plaintext, encryption_context):
key_id = self.any_id_to_key_id(key_id) key_id = self.any_id_to_key_id(key_id)


@ -118,11 +118,12 @@ class KmsResponse(BaseResponse):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html"""
policy = self.parameters.get("Policy") policy = self.parameters.get("Policy")
key_usage = self.parameters.get("KeyUsage") key_usage = self.parameters.get("KeyUsage")
customer_master_key_spec = self.parameters.get("CustomerMasterKeySpec")
description = self.parameters.get("Description") description = self.parameters.get("Description")
tags = self.parameters.get("Tags") tags = self.parameters.get("Tags")
key = self.kms_backend.create_key( key = self.kms_backend.create_key(
policy, key_usage, description, tags, self.region policy, key_usage, customer_master_key_spec, description, tags, self.region
) )
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
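Assuming the merged `create_key` signature above, the new `CustomerMasterKeySpec` handling might be exercised like this (spec and usage chosen arbitrarily; requires a botocore recent enough to know the parameter):

```python
import boto3
from moto import mock_kms


@mock_kms
def test_create_key_with_customer_master_key_spec():
    client = boto3.client("kms", region_name="us-east-1")

    key_id = client.create_key(
        KeyUsage="SIGN_VERIFY", CustomerMasterKeySpec="ECC_NIST_P384"
    )["KeyMetadata"]["KeyId"]

    metadata = client.describe_key(KeyId=key_id)["KeyMetadata"]
    assert metadata["CustomerMasterKeySpec"] == "ECC_NIST_P384"
    # ECC_NIST_P384 maps to a single signing algorithm in the model above.
    assert metadata["SigningAlgorithms"] == ["ECDSA_SHA_384"]
```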


@ -103,7 +103,7 @@ class LogsResponse(BaseResponse):
( (
events, events,
next_backward_token, next_backward_token,
next_foward_token, next_forward_token,
) = self.logs_backend.get_log_events( ) = self.logs_backend.get_log_events(
log_group_name, log_group_name,
log_stream_name, log_stream_name,
@ -117,7 +117,7 @@ class LogsResponse(BaseResponse):
{ {
"events": events, "events": events,
"nextBackwardToken": next_backward_token, "nextBackwardToken": next_backward_token,
"nextForwardToken": next_foward_token, "nextForwardToken": next_forward_token,
} }
) )


@ -10,3 +10,13 @@ class InvalidInputException(JsonRESTError):
"InvalidInputException", "InvalidInputException",
"You provided a value that does not match the required pattern.", "You provided a value that does not match the required pattern.",
) )
class DuplicateOrganizationalUnitException(JsonRESTError):
code = 400
def __init__(self):
super(DuplicateOrganizationalUnitException, self).__init__(
"DuplicateOrganizationalUnitException",
"An OU with the same name already exists.",
)


@ -8,7 +8,10 @@ from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.core.utils import unix_time from moto.core.utils import unix_time
from moto.organizations import utils from moto.organizations import utils
from moto.organizations.exceptions import InvalidInputException from moto.organizations.exceptions import (
InvalidInputException,
DuplicateOrganizationalUnitException,
)
class FakeOrganization(BaseModel): class FakeOrganization(BaseModel):
@ -222,6 +225,14 @@ class OrganizationsBackend(BaseBackend):
self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_ou.id) self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_ou.id)
return new_ou.describe() return new_ou.describe()
def update_organizational_unit(self, **kwargs):
for ou in self.ou:
if ou.name == kwargs["Name"]:
raise DuplicateOrganizationalUnitException
ou = self.get_organizational_unit_by_id(kwargs["OrganizationalUnitId"])
ou.name = kwargs["Name"]
return ou.describe()
def get_organizational_unit_by_id(self, ou_id): def get_organizational_unit_by_id(self, ou_id):
ou = next((ou for ou in self.ou if ou.id == ou_id), None) ou = next((ou for ou in self.ou if ou.id == ou_id), None)
if ou is None: if ou is None:


@ -36,6 +36,11 @@ class OrganizationsResponse(BaseResponse):
self.organizations_backend.create_organizational_unit(**self.request_params) self.organizations_backend.create_organizational_unit(**self.request_params)
) )
def update_organizational_unit(self):
return json.dumps(
self.organizations_backend.update_organizational_unit(**self.request_params)
)
def describe_organizational_unit(self): def describe_organizational_unit(self):
return json.dumps( return json.dumps(
self.organizations_backend.describe_organizational_unit( self.organizations_backend.describe_organizational_unit(
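A small sketch of the new `update_organizational_unit` path (OU names are invented):

```python
import boto3
from moto import mock_organizations


@mock_organizations
def test_update_organizational_unit_renames_ou():
    client = boto3.client("organizations", region_name="us-east-1")
    client.create_organization(FeatureSet="ALL")
    root_id = client.list_roots()["Roots"][0]["Id"]

    ou_id = client.create_organizational_unit(ParentId=root_id, Name="engineering")[
        "OrganizationalUnit"
    ]["Id"]

    renamed = client.update_organizational_unit(OrganizationalUnitId=ou_id, Name="platform")
    assert renamed["OrganizationalUnit"]["Name"] == "platform"
```

Renaming an OU to a name that is already taken should surface the new `DuplicateOrganizationalUnitException`.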


@ -130,7 +130,9 @@ class Database(BaseModel):
if not self.option_group_name and self.engine in self.default_option_groups: if not self.option_group_name and self.engine in self.default_option_groups:
self.option_group_name = self.default_option_groups[self.engine] self.option_group_name = self.default_option_groups[self.engine]
self.character_set_name = kwargs.get("character_set_name", None) self.character_set_name = kwargs.get("character_set_name", None)
self.iam_database_authentication_enabled = False self.enable_iam_database_authentication = kwargs.get(
"enable_iam_database_authentication", False
)
self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U" self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U"
self.tags = kwargs.get("tags", []) self.tags = kwargs.get("tags", [])
@ -214,7 +216,7 @@ class Database(BaseModel):
<ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier> <ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier>
{% endif %} {% endif %}
<Engine>{{ database.engine }}</Engine> <Engine>{{ database.engine }}</Engine>
<IAMDatabaseAuthenticationEnabled>{{database.iam_database_authentication_enabled }}</IAMDatabaseAuthenticationEnabled> <IAMDatabaseAuthenticationEnabled>{{database.enable_iam_database_authentication|lower }}</IAMDatabaseAuthenticationEnabled>
<LicenseModel>{{ database.license_model }}</LicenseModel> <LicenseModel>{{ database.license_model }}</LicenseModel>
<EngineVersion>{{ database.engine_version }}</EngineVersion> <EngineVersion>{{ database.engine_version }}</EngineVersion>
<OptionGroupMemberships> <OptionGroupMemberships>
@ -542,7 +544,7 @@ class Snapshot(BaseModel):
<KmsKeyId>{{ database.kms_key_id }}</KmsKeyId> <KmsKeyId>{{ database.kms_key_id }}</KmsKeyId>
<DBSnapshotArn>{{ snapshot.snapshot_arn }}</DBSnapshotArn> <DBSnapshotArn>{{ snapshot.snapshot_arn }}</DBSnapshotArn>
<Timezone></Timezone> <Timezone></Timezone>
<IAMDatabaseAuthenticationEnabled>false</IAMDatabaseAuthenticationEnabled> <IAMDatabaseAuthenticationEnabled>{{ database.enable_iam_database_authentication|lower }}</IAMDatabaseAuthenticationEnabled>
</DBSnapshot>""" </DBSnapshot>"""
) )
return template.render(snapshot=self, database=self.database) return template.render(snapshot=self, database=self.database)
@ -986,7 +988,7 @@ class RDS2Backend(BaseBackend):
) )
if option_group_kwargs["engine_name"] not in valid_option_group_engines.keys(): if option_group_kwargs["engine_name"] not in valid_option_group_engines.keys():
raise RDSClientError( raise RDSClientError(
"InvalidParameterValue", "Invalid DB engine: non-existant" "InvalidParameterValue", "Invalid DB engine: non-existent"
) )
if ( if (
option_group_kwargs["major_engine_version"] option_group_kwargs["major_engine_version"]


@ -27,6 +27,9 @@ class RDS2Response(BaseResponse):
"db_subnet_group_name": self._get_param("DBSubnetGroupName"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"),
"engine": self._get_param("Engine"), "engine": self._get_param("Engine"),
"engine_version": self._get_param("EngineVersion"), "engine_version": self._get_param("EngineVersion"),
"enable_iam_database_authentication": self._get_bool_param(
"EnableIAMDatabaseAuthentication"
),
"license_model": self._get_param("LicenseModel"), "license_model": self._get_param("LicenseModel"),
"iops": self._get_int_param("Iops"), "iops": self._get_int_param("Iops"),
"kms_key_id": self._get_param("KmsKeyId"), "kms_key_id": self._get_param("KmsKeyId"),
@ -367,14 +370,14 @@ class RDS2Response(BaseResponse):
def modify_db_parameter_group(self): def modify_db_parameter_group(self):
db_parameter_group_name = self._get_param("DBParameterGroupName") db_parameter_group_name = self._get_param("DBParameterGroupName")
db_parameter_group_parameters = self._get_db_parameter_group_paramters() db_parameter_group_parameters = self._get_db_parameter_group_parameters()
db_parameter_group = self.backend.modify_db_parameter_group( db_parameter_group = self.backend.modify_db_parameter_group(
db_parameter_group_name, db_parameter_group_parameters db_parameter_group_name, db_parameter_group_parameters
) )
template = self.response_template(MODIFY_DB_PARAMETER_GROUP_TEMPLATE) template = self.response_template(MODIFY_DB_PARAMETER_GROUP_TEMPLATE)
return template.render(db_parameter_group=db_parameter_group) return template.render(db_parameter_group=db_parameter_group)
def _get_db_parameter_group_paramters(self): def _get_db_parameter_group_parameters(self):
parameter_group_parameters = defaultdict(dict) parameter_group_parameters = defaultdict(dict)
for param_name, value in self.querystring.items(): for param_name, value in self.querystring.items():
if not param_name.startswith("Parameters.Parameter"): if not param_name.startswith("Parameters.Parameter"):


@ -271,6 +271,7 @@ LIST_RRSET_RESPONSE = """<ListResourceRecordSetsResponse xmlns="https://route53.
{{ record_set.to_xml() }} {{ record_set.to_xml() }}
{% endfor %} {% endfor %}
</ResourceRecordSets> </ResourceRecordSets>
<IsTruncated>false</IsTruncated>
</ListResourceRecordSetsResponse>""" </ListResourceRecordSetsResponse>"""
CHANGE_RRSET_RESPONSE = """<ChangeResourceRecordSetsResponse xmlns="https://route53.amazonaws.com/doc/2012-12-12/"> CHANGE_RRSET_RESPONSE = """<ChangeResourceRecordSetsResponse xmlns="https://route53.amazonaws.com/doc/2012-12-12/">


@ -1,8 +1,13 @@
import datetime
import json import json
import time
from boto3 import Session
from moto.core.exceptions import InvalidNextTokenException from moto.core.exceptions import InvalidNextTokenException
from moto.core.models import ConfigQueryModel from moto.core.models import ConfigQueryModel
from moto.s3 import s3_backends from moto.s3 import s3_backends
from moto.s3.models import get_moto_s3_account_id
class S3ConfigQuery(ConfigQueryModel): class S3ConfigQuery(ConfigQueryModel):
@ -118,4 +123,146 @@ class S3ConfigQuery(ConfigQueryModel):
return config_data return config_data
class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
def list_config_service_resources(
self,
resource_ids,
resource_name,
limit,
next_token,
backend_region=None,
resource_region=None,
):
# The Account Public Access Block is the same for all regions; the resource ID is the AWS account ID.
# There is no resource name -- it should be a blank string "" if provided.
# The resource name can only ever be None or an empty string:
if resource_name is not None and resource_name != "":
return [], None
pab = None
account_id = get_moto_s3_account_id()
regions = [region for region in Session().get_available_regions("config")]
# If a resource ID was passed in, then filter accordingly:
if resource_ids:
for id in resource_ids:
if account_id == id:
pab = self.backends["global"].account_public_access_block
break
# Otherwise, just grab the one from the backend:
if not resource_ids:
pab = self.backends["global"].account_public_access_block
# If it's not present, then return nothing
if not pab:
return [], None
# Filter on regions (and paginate on them as well):
if backend_region:
pab_list = [backend_region]
elif resource_region:
# Invalid region?
if resource_region not in regions:
return [], None
pab_list = [resource_region]
# Aggregated query where no regions were supplied so return them all:
else:
pab_list = regions
# Pagination logic:
sorted_regions = sorted(pab_list)
new_token = None
# Get the start:
if not next_token:
start = 0
else:
# Tokens for this moto feature are just the region name;
# For OTHER non-global resource types, it's the region concatenated with the resource ID.
if next_token not in sorted_regions:
raise InvalidNextTokenException()
start = sorted_regions.index(next_token)
# Get the list of items to collect:
pab_list = sorted_regions[start : (start + limit)]
if len(sorted_regions) > (start + limit):
new_token = sorted_regions[start + limit]
return (
[
{
"type": "AWS::S3::AccountPublicAccessBlock",
"id": account_id,
"region": region,
}
for region in pab_list
],
new_token,
)
def get_config_resource(
self, resource_id, resource_name=None, backend_region=None, resource_region=None
):
# Do we even have this defined?
if not self.backends["global"].account_public_access_block:
return None
# Resource name can only ever be "" if it's supplied:
if resource_name is not None and resource_name != "":
return None
# Are we filtering based on region?
account_id = get_moto_s3_account_id()
regions = [region for region in Session().get_available_regions("config")]
# Is the resource ID correct?:
if account_id == resource_id:
if backend_region:
pab_region = backend_region
# Invalid region?
elif resource_region not in regions:
return None
else:
pab_region = resource_region
else:
return None
# Format the PAB to the AWS Config format:
creation_time = datetime.datetime.utcnow()
config_data = {
"version": "1.3",
"accountId": account_id,
"configurationItemCaptureTime": str(creation_time),
"configurationItemStatus": "OK",
"configurationStateId": str(
int(time.mktime(creation_time.timetuple()))
), # PY2 and 3 compatible
"resourceType": "AWS::S3::AccountPublicAccessBlock",
"resourceId": account_id,
"awsRegion": pab_region,
"availabilityZone": "Not Applicable",
"configuration": self.backends[
"global"
].account_public_access_block.to_config_dict(),
"supplementaryConfiguration": {},
}
# The 'configuration' field is also a JSON string:
config_data["configuration"] = json.dumps(config_data["configuration"])
return config_data
s3_config_query = S3ConfigQuery(s3_backends) s3_config_query = S3ConfigQuery(s3_backends)
s3_account_public_access_block_query = S3AccountPublicAccessBlockConfigQuery(
s3_backends
)


@ -127,6 +127,18 @@ class InvalidRequest(S3ClientError):
) )
class IllegalLocationConstraintException(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
super(IllegalLocationConstraintException, self).__init__(
"IllegalLocationConstraintException",
"The unspecified location constraint is incompatible for the region specific endpoint this request was sent to.",
*args,
**kwargs
)
class MalformedXML(S3ClientError): class MalformedXML(S3ClientError):
code = 400 code = 400
@ -347,3 +359,12 @@ class InvalidPublicAccessBlockConfiguration(S3ClientError):
*args, *args,
**kwargs **kwargs
) )
class WrongPublicAccessBlockAccountIdError(S3ClientError):
code = 403
def __init__(self):
super(WrongPublicAccessBlockAccountIdError, self).__init__(
"AccessDenied", "Access Denied"
)


@ -19,7 +19,7 @@ import uuid
import six import six
from bisect import insort from bisect import insort
from moto.core import BaseBackend, BaseModel from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds, rfc_1123_datetime from moto.core.utils import iso_8601_datetime_with_milliseconds, rfc_1123_datetime
from .exceptions import ( from .exceptions import (
BucketAlreadyExists, BucketAlreadyExists,
@ -37,6 +37,7 @@ from .exceptions import (
CrossLocationLoggingProhibitted, CrossLocationLoggingProhibitted,
NoSuchPublicAccessBlockConfiguration, NoSuchPublicAccessBlockConfiguration,
InvalidPublicAccessBlockConfiguration, InvalidPublicAccessBlockConfiguration,
WrongPublicAccessBlockAccountIdError,
) )
from .utils import clean_key_name, _VersionedKeyStore from .utils import clean_key_name, _VersionedKeyStore
@ -58,6 +59,13 @@ DEFAULT_TEXT_ENCODING = sys.getdefaultencoding()
OWNER = "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a" OWNER = "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a"
def get_moto_s3_account_id():
"""This makes it easy for mocking AWS Account IDs when using AWS Config
-- Simply mock.patch the ACCOUNT_ID here, and Config gets it for free.
"""
return ACCOUNT_ID
class FakeDeleteMarker(BaseModel): class FakeDeleteMarker(BaseModel):
def __init__(self, key): def __init__(self, key):
self.key = key self.key = key
@ -1163,6 +1171,7 @@ class FakeBucket(BaseModel):
class S3Backend(BaseBackend): class S3Backend(BaseBackend):
def __init__(self): def __init__(self):
self.buckets = {} self.buckets = {}
self.account_public_access_block = None
def create_bucket(self, bucket_name, region_name): def create_bucket(self, bucket_name, region_name):
if bucket_name in self.buckets: if bucket_name in self.buckets:
@ -1264,6 +1273,16 @@ class S3Backend(BaseBackend):
return bucket.public_access_block return bucket.public_access_block
def get_account_public_access_block(self, account_id):
# The account ID should equal the account id that is set for Moto:
if account_id != ACCOUNT_ID:
raise WrongPublicAccessBlockAccountIdError()
if not self.account_public_access_block:
raise NoSuchPublicAccessBlockConfiguration()
return self.account_public_access_block
def set_key( def set_key(
self, bucket_name, key_name, value, storage=None, etag=None, multipart=None self, bucket_name, key_name, value, storage=None, etag=None, multipart=None
): ):
@ -1356,6 +1375,13 @@ class S3Backend(BaseBackend):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
bucket.public_access_block = None bucket.public_access_block = None
def delete_account_public_access_block(self, account_id):
# The account ID should equal the account id that is set for Moto:
if account_id != ACCOUNT_ID:
raise WrongPublicAccessBlockAccountIdError()
self.account_public_access_block = None
def put_bucket_notification_configuration(self, bucket_name, notification_config): def put_bucket_notification_configuration(self, bucket_name, notification_config):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
bucket.set_notification_configuration(notification_config) bucket.set_notification_configuration(notification_config)
@ -1384,6 +1410,21 @@ class S3Backend(BaseBackend):
pub_block_config.get("RestrictPublicBuckets"), pub_block_config.get("RestrictPublicBuckets"),
) )
def put_account_public_access_block(self, account_id, pub_block_config):
# The account ID should equal the account id that is set for Moto:
if account_id != ACCOUNT_ID:
raise WrongPublicAccessBlockAccountIdError()
if not pub_block_config:
raise InvalidPublicAccessBlockConfiguration()
self.account_public_access_block = PublicAccessBlock(
pub_block_config.get("BlockPublicAcls"),
pub_block_config.get("IgnorePublicAcls"),
pub_block_config.get("BlockPublicPolicy"),
pub_block_config.get("RestrictPublicBuckets"),
)
def initiate_multipart(self, bucket_name, key_name, metadata): def initiate_multipart(self, bucket_name, key_name, metadata):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
new_multipart = FakeMultipart(key_name, metadata) new_multipart = FakeMultipart(key_name, metadata)
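A hedged sketch of the account-level public access block round trip in decorator mode (the account id is moto's default `ACCOUNT_ID`):

```python
import boto3
from moto import mock_s3


@mock_s3
def test_account_level_public_access_block_round_trip():
    client = boto3.client("s3control", region_name="us-west-2")
    account_id = "123456789012"  # moto's default ACCOUNT_ID

    client.put_public_access_block(
        AccountId=account_id,
        PublicAccessBlockConfiguration={
            "BlockPublicAcls": True,
            "IgnorePublicAcls": False,
            "BlockPublicPolicy": True,
            "RestrictPublicBuckets": False,
        },
    )

    config = client.get_public_access_block(AccountId=account_id)[
        "PublicAccessBlockConfiguration"
    ]
    assert config["BlockPublicAcls"] is True
    assert config["IgnorePublicAcls"] is False

    # Clearing the block resets account_public_access_block on the backend.
    client.delete_public_access_block(AccountId=account_id)
```

Passing any `AccountId` other than moto's `ACCOUNT_ID` is expected to come back as AccessDenied (`WrongPublicAccessBlockAccountIdError`).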


@ -4,6 +4,7 @@ import re
import sys import sys
import six import six
from botocore.awsrequest import AWSPreparedRequest
from moto.core.utils import str_to_rfc_1123_datetime, py2_strip_unicode_keys from moto.core.utils import str_to_rfc_1123_datetime, py2_strip_unicode_keys
from six.moves.urllib.parse import parse_qs, urlparse, unquote from six.moves.urllib.parse import parse_qs, urlparse, unquote
@ -29,6 +30,7 @@ from .exceptions import (
InvalidPartOrder, InvalidPartOrder,
MalformedXML, MalformedXML,
MalformedACLError, MalformedACLError,
IllegalLocationConstraintException,
InvalidNotificationARN, InvalidNotificationARN,
InvalidNotificationEvent, InvalidNotificationEvent,
ObjectNotInActiveTierError, ObjectNotInActiveTierError,
@ -122,6 +124,11 @@ ACTION_MAP = {
"uploadId": "PutObject", "uploadId": "PutObject",
}, },
}, },
"CONTROL": {
"GET": {"publicAccessBlock": "GetPublicAccessBlock"},
"PUT": {"publicAccessBlock": "PutPublicAccessBlock"},
"DELETE": {"publicAccessBlock": "DeletePublicAccessBlock"},
},
} }
@ -167,7 +174,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
or host.startswith("localhost") or host.startswith("localhost")
or host.startswith("localstack") or host.startswith("localstack")
or re.match(r"^[^.]+$", host) or re.match(r"^[^.]+$", host)
or re.match(r"^.*\.svc\.cluster\.local$", host) or re.match(r"^.*\.svc\.cluster\.local:?\d*$", host)
): ):
# Default to path-based buckets for (1) localhost, (2) localstack hosts (e.g. localstack.dev), # Default to path-based buckets for (1) localhost, (2) localstack hosts (e.g. localstack.dev),
# (3) local host names that do not contain a "." (e.g., Docker container host names), or # (3) local host names that do not contain a "." (e.g., Docker container host names), or
@ -219,7 +226,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# Depending on which calling format the client is using, we don't know # Depending on which calling format the client is using, we don't know
# if this is a bucket or key request so we have to check # if this is a bucket or key request so we have to check
if self.subdomain_based_buckets(request): if self.subdomain_based_buckets(request):
return self.key_response(request, full_url, headers) return self.key_or_control_response(request, full_url, headers)
else: else:
# Using path-based buckets # Using path-based buckets
return self.bucket_response(request, full_url, headers) return self.bucket_response(request, full_url, headers)
@ -286,7 +293,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return self._bucket_response_post(request, body, bucket_name) return self._bucket_response_post(request, body, bucket_name)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Method {0} has not been impelemented in the S3 backend yet".format( "Method {0} has not been implemented in the S3 backend yet".format(
method method
) )
) )
@ -585,6 +592,29 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
next_continuation_token = None next_continuation_token = None
return result_keys, is_truncated, next_continuation_token return result_keys, is_truncated, next_continuation_token
def _body_contains_location_constraint(self, body):
if body:
try:
xmltodict.parse(body)["CreateBucketConfiguration"]["LocationConstraint"]
return True
except KeyError:
pass
return False
def _parse_pab_config(self, body):
parsed_xml = xmltodict.parse(body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
# If Python 2, fix the unicode strings:
if sys.version_info[0] < 3:
parsed_xml = {
"PublicAccessBlockConfiguration": py2_strip_unicode_keys(
dict(parsed_xml["PublicAccessBlockConfiguration"])
)
}
return parsed_xml
def _bucket_response_put( def _bucket_response_put(
self, request, body, region_name, bucket_name, querystring self, request, body, region_name, bucket_name, querystring
): ):
@ -663,27 +693,23 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
raise e raise e
elif "publicAccessBlock" in querystring: elif "publicAccessBlock" in querystring:
parsed_xml = xmltodict.parse(body) pab_config = self._parse_pab_config(body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
# If Python 2, fix the unicode strings:
if sys.version_info[0] < 3:
parsed_xml = {
"PublicAccessBlockConfiguration": py2_strip_unicode_keys(
dict(parsed_xml["PublicAccessBlockConfiguration"])
)
}
self.backend.put_bucket_public_access_block( self.backend.put_bucket_public_access_block(
bucket_name, parsed_xml["PublicAccessBlockConfiguration"] bucket_name, pab_config["PublicAccessBlockConfiguration"]
) )
return "" return ""
else: else:
if body:
# us-east-1, the default AWS region behaves a bit differently # us-east-1, the default AWS region behaves a bit differently
# - you should not use it as a location constraint --> it fails # - you should not use it as a location constraint --> it fails
# - querying the location constraint returns None # - querying the location constraint returns None
# - LocationConstraint has to be specified if outside us-east-1
if (
region_name != DEFAULT_REGION_NAME
and not self._body_contains_location_constraint(body)
):
raise IllegalLocationConstraintException()
if body:
try: try:
forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][ forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][
"LocationConstraint" "LocationConstraint"
@ -854,14 +880,20 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
return 206, response_headers, response_content[begin : end + 1] return 206, response_headers, response_content[begin : end + 1]
def key_response(self, request, full_url, headers): def key_or_control_response(self, request, full_url, headers):
# Key and Control are lumped in because splitting out the regex is too much of a pain :/
self.method = request.method self.method = request.method
self.path = self._get_path(request) self.path = self._get_path(request)
self.headers = request.headers self.headers = request.headers
if "host" not in self.headers: if "host" not in self.headers:
self.headers["host"] = urlparse(full_url).netloc self.headers["host"] = urlparse(full_url).netloc
response_headers = {} response_headers = {}
try: try:
# Is this an S3 control response?
if isinstance(request, AWSPreparedRequest) and "s3-control" in request.url:
response = self._control_response(request, full_url, headers)
else:
response = self._key_response(request, full_url, headers) response = self._key_response(request, full_url, headers)
except S3ClientError as s3error: except S3ClientError as s3error:
response = s3error.code, {}, s3error.description response = s3error.code, {}, s3error.description
@ -878,6 +910,94 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
return status_code, response_headers, response_content return status_code, response_headers, response_content
def _control_response(self, request, full_url, headers):
parsed_url = urlparse(full_url)
query = parse_qs(parsed_url.query, keep_blank_values=True)
method = request.method
if hasattr(request, "body"):
# Boto
body = request.body
if hasattr(body, "read"):
body = body.read()
else:
# Flask server
body = request.data
if body is None:
body = b""
if method == "GET":
return self._control_response_get(request, query, headers)
elif method == "PUT":
return self._control_response_put(request, body, query, headers)
elif method == "DELETE":
return self._control_response_delete(request, query, headers)
else:
raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(
method
)
)
def _control_response_get(self, request, query, headers):
action = self.path.split("?")[0].split("/")[
-1
] # Gets the action out of the URL sans query params.
self._set_action("CONTROL", "GET", action)
self._authenticate_and_authorize_s3_action()
response_headers = {}
if "publicAccessBlock" in action:
public_block_config = self.backend.get_account_public_access_block(
headers["x-amz-account-id"]
)
template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION)
return (
200,
response_headers,
template.render(public_block_config=public_block_config),
)
raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(action)
)
def _control_response_put(self, request, body, query, headers):
action = self.path.split("?")[0].split("/")[
-1
] # Gets the action out of the URL sans query params.
self._set_action("CONTROL", "PUT", action)
self._authenticate_and_authorize_s3_action()
response_headers = {}
if "publicAccessBlock" in action:
pab_config = self._parse_pab_config(body)
self.backend.put_account_public_access_block(
headers["x-amz-account-id"],
pab_config["PublicAccessBlockConfiguration"],
)
return 200, response_headers, ""
raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(action)
)
def _control_response_delete(self, request, query, headers):
action = self.path.split("?")[0].split("/")[
-1
] # Gets the action out of the URL sans query params.
self._set_action("CONTROL", "DELETE", action)
self._authenticate_and_authorize_s3_action()
response_headers = {}
if "publicAccessBlock" in action:
self.backend.delete_account_public_access_block(headers["x-amz-account-id"])
return 200, response_headers, ""
raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(action)
)
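The `_control_response_*` handlers above route the account-level public access block through the same S3 response object. A rough sketch of exercising them from boto3's `s3control` client; the account id is illustrative and it is assumed these handlers are reachable under `mock_s3` in decorator mode:

```python
import boto3
from moto import mock_s3


@mock_s3
def account_public_access_block_roundtrip():
    # Illustrative account id; moto's default test account is assumed here.
    account_id = "123456789012"
    control = boto3.client("s3control", region_name="us-east-1")
    control.put_public_access_block(
        AccountId=account_id,
        PublicAccessBlockConfiguration={
            "BlockPublicAcls": True,
            "IgnorePublicAcls": True,
            "BlockPublicPolicy": True,
            "RestrictPublicBuckets": True,
        },
    )
    # GET returns the stored configuration; DELETE removes it again.
    config = control.get_public_access_block(AccountId=account_id)
    print(config["PublicAccessBlockConfiguration"])
    control.delete_public_access_block(AccountId=account_id)


account_public_access_block_roundtrip()
```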
def _key_response(self, request, full_url, headers): def _key_response(self, request, full_url, headers):
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
query = parse_qs(parsed_url.query, keep_blank_values=True) query = parse_qs(parsed_url.query, keep_blank_values=True)
@ -1082,6 +1202,10 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if mdirective is not None and mdirective == "REPLACE": if mdirective is not None and mdirective == "REPLACE":
metadata = metadata_from_headers(request.headers) metadata = metadata_from_headers(request.headers)
new_key.set_metadata(metadata, replace=True) new_key.set_metadata(metadata, replace=True)
tdirective = request.headers.get("x-amz-tagging-directive")
if tdirective == "REPLACE":
tagging = self._tagging_from_headers(request.headers)
new_key.set_tagging(tagging)
template = self.response_template(S3_OBJECT_COPY_RESPONSE) template = self.response_template(S3_OBJECT_COPY_RESPONSE)
response_headers.update(new_key.response_dict) response_headers.update(new_key.response_dict)
return 200, response_headers, template.render(key=new_key) return 200, response_headers, template.render(key=new_key)
@ -1482,7 +1606,7 @@ S3_ALL_BUCKETS = """<ListAllMyBucketsResult xmlns="http://s3.amazonaws.com/doc/2
{% for bucket in buckets %} {% for bucket in buckets %}
<Bucket> <Bucket>
<Name>{{ bucket.name }}</Name> <Name>{{ bucket.name }}</Name>
<CreationDate>{{ bucket.creation_date }}</CreationDate> <CreationDate>{{ bucket.creation_date.isoformat() }}</CreationDate>
</Bucket> </Bucket>
{% endfor %} {% endfor %}
</Buckets> </Buckets>
@ -1869,7 +1993,6 @@ S3_MULTIPART_LIST_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ID>75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a</ID> <ID>75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a</ID>
<DisplayName>webfile</DisplayName> <DisplayName>webfile</DisplayName>
</Owner> </Owner>
<StorageClass>STANDARD</StorageClass>
<PartNumberMarker>1</PartNumberMarker> <PartNumberMarker>1</PartNumberMarker>
<NextPartNumberMarker>{{ count }}</NextPartNumberMarker> <NextPartNumberMarker>{{ count }}</NextPartNumberMarker>
<MaxParts>{{ count }}</MaxParts> <MaxParts>{{ count }}</MaxParts>
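The `x-amz-tagging-directive` handling added to the copy path means a copy request can now replace the source object's tags. A minimal sketch, assuming the mocked backend accepts `Tagging` on both `put_object` and `copy_object`; bucket and key names are illustrative:

```python
import boto3
from moto import mock_s3


@mock_s3
def copy_with_replaced_tags():
    s3 = boto3.client("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="demo-bucket")
    s3.put_object(Bucket="demo-bucket", Key="src", Body=b"data", Tagging="env=dev")
    s3.copy_object(
        Bucket="demo-bucket",
        Key="dst",
        CopySource={"Bucket": "demo-bucket", "Key": "src"},
        Tagging="env=prod",
        TaggingDirective="REPLACE",
    )
    # The copy now carries the tags from the request, not the source object.
    print(s3.get_object_tagging(Bucket="demo-bucket", Key="dst")["TagSet"])


copy_with_replaced_tags()
```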

View File

@ -13,7 +13,7 @@ url_paths = {
# subdomain key of path-based bucket # subdomain key of path-based bucket
"{0}/(?P<key_or_bucket_name>[^/]+)/?$": S3ResponseInstance.ambiguous_response, "{0}/(?P<key_or_bucket_name>[^/]+)/?$": S3ResponseInstance.ambiguous_response,
# path-based bucket + key # path-based bucket + key
"{0}/(?P<bucket_name_path>[^/]+)/(?P<key_name>.+)": S3ResponseInstance.key_response, "{0}/(?P<bucket_name_path>[^/]+)/(?P<key_name>.+)": S3ResponseInstance.key_or_control_response,
# subdomain bucket + key with empty first part of path # subdomain bucket + key with empty first part of path
"{0}//(?P<key_name>.*)$": S3ResponseInstance.key_response, "{0}//(?P<key_name>.*)$": S3ResponseInstance.key_or_control_response,
} }

View File

@ -37,7 +37,7 @@ def bucket_name_from_url(url):
REGION_URL_REGEX = re.compile( REGION_URL_REGEX = re.compile(
r"^https?://(s3[-\.](?P<region1>.+)\.amazonaws\.com/(.+)|" r"^https?://(s3[-\.](?P<region1>.+)\.amazonaws\.com/(.+)|"
r"(.+)\.s3-(?P<region2>.+)\.amazonaws\.com)/?" r"(.+)\.s3[-\.](?P<region2>.+)\.amazonaws\.com)/?"
) )
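The widened `REGION_URL_REGEX` now accepts the dotted regional endpoints (`bucket.s3.us-west-2.amazonaws.com`) as well as the older dashed form. A quick standalone check using the pattern exactly as it appears on the right-hand side above:

```python
import re

# The updated pattern, copied verbatim from the diff's right-hand side.
REGION_URL_REGEX = re.compile(
    r"^https?://(s3[-\.](?P<region1>.+)\.amazonaws\.com/(.+)|"
    r"(.+)\.s3[-\.](?P<region2>.+)\.amazonaws\.com)/?"
)

for url in (
    "https://bucket.s3-us-west-2.amazonaws.com/key",  # old dashed style
    "https://bucket.s3.us-west-2.amazonaws.com/key",  # dotted style, newly matched
):
    print(REGION_URL_REGEX.match(url).group("region2"))  # "us-west-2" both times
```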

View File

@ -148,11 +148,15 @@ class SESBackend(BaseBackend):
def __type_of_message__(self, destinations): def __type_of_message__(self, destinations):
"""Checks the destination for any special address that could indicate delivery, """Checks the destination for any special address that could indicate delivery,
complaint or bounce like in SES simulator""" complaint or bounce like in SES simulator"""
if isinstance(destinations, list):
alladdress = destinations
else:
alladdress = ( alladdress = (
destinations.get("ToAddresses", []) destinations.get("ToAddresses", [])
+ destinations.get("CcAddresses", []) + destinations.get("CcAddresses", [])
+ destinations.get("BccAddresses", []) + destinations.get("BccAddresses", [])
) )
for addr in alladdress: for addr in alladdress:
if SESFeedback.SUCCESS_ADDR in addr: if SESFeedback.SUCCESS_ADDR in addr:
return SESFeedback.DELIVERY return SESFeedback.DELIVERY
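With the list check above, `__type_of_message__` also accepts a plain list of recipient addresses, which is what `send_raw_email` passes, instead of only the `ToAddresses`/`CcAddresses`/`BccAddresses` dict used by `send_email`. A small sketch under `mock_ses`; addresses are illustrative, and the sender is verified first so the mocked send is accepted:

```python
import boto3
from moto import mock_ses


@mock_ses
def raw_email_with_list_destinations():
    ses = boto3.client("ses", region_name="us-east-1")
    ses.verify_email_identity(EmailAddress="sender@example.com")
    raw = (
        "From: sender@example.com\n"
        "To: receiver@example.com\n"
        "Subject: hello\n"
        "\n"
        "body"
    )
    # Destinations here is a bare list, not the {To,Cc,Bcc}Addresses dict.
    ses.send_raw_email(
        Source="sender@example.com",
        Destinations=["receiver@example.com"],
        RawMessage={"Data": raw},
    )


raw_email_with_list_destinations()
```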

View File

@ -99,3 +99,28 @@ class InvalidAttributeName(RESTError):
super(InvalidAttributeName, self).__init__( super(InvalidAttributeName, self).__init__(
"InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name) "InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name)
) )
class InvalidParameterValue(RESTError):
code = 400
def __init__(self, message):
super(InvalidParameterValue, self).__init__("InvalidParameterValue", message)
class MissingParameter(RESTError):
code = 400
def __init__(self):
super(MissingParameter, self).__init__(
"MissingParameter", "The request must contain the parameter Actions."
)
class OverLimit(RESTError):
code = 403
def __init__(self, count):
super(OverLimit, self).__init__(
"OverLimit", "{} Actions were found, maximum allowed is 7.".format(count)
)

View File

@ -30,6 +30,9 @@ from .exceptions import (
BatchEntryIdsNotDistinct, BatchEntryIdsNotDistinct,
TooManyEntriesInBatchRequest, TooManyEntriesInBatchRequest,
InvalidAttributeName, InvalidAttributeName,
InvalidParameterValue,
MissingParameter,
OverLimit,
) )
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
@ -183,6 +186,8 @@ class Queue(BaseModel):
"MaximumMessageSize", "MaximumMessageSize",
"MessageRetentionPeriod", "MessageRetentionPeriod",
"QueueArn", "QueueArn",
"Policy",
"RedrivePolicy",
"ReceiveMessageWaitTimeSeconds", "ReceiveMessageWaitTimeSeconds",
"VisibilityTimeout", "VisibilityTimeout",
] ]
@ -194,6 +199,8 @@ class Queue(BaseModel):
"DeleteMessage", "DeleteMessage",
"GetQueueAttributes", "GetQueueAttributes",
"GetQueueUrl", "GetQueueUrl",
"ListDeadLetterSourceQueues",
"PurgeQueue",
"ReceiveMessage", "ReceiveMessage",
"SendMessage", "SendMessage",
) )
@ -272,7 +279,7 @@ class Queue(BaseModel):
if key in bool_fields: if key in bool_fields:
value = value == "true" value = value == "true"
if key == "RedrivePolicy" and value is not None: if key in ["Policy", "RedrivePolicy"] and value is not None:
continue continue
setattr(self, camelcase_to_underscores(key), value) setattr(self, camelcase_to_underscores(key), value)
@ -280,6 +287,9 @@ class Queue(BaseModel):
if attributes.get("RedrivePolicy", None): if attributes.get("RedrivePolicy", None):
self._setup_dlq(attributes["RedrivePolicy"]) self._setup_dlq(attributes["RedrivePolicy"])
if attributes.get("Policy"):
self.policy = attributes["Policy"]
self.last_modified_timestamp = now self.last_modified_timestamp = now
def _setup_dlq(self, policy): def _setup_dlq(self, policy):
@ -471,6 +481,24 @@ class Queue(BaseModel):
return self.name return self.name
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@property
def policy(self):
if self._policy_json.get("Statement"):
return json.dumps(self._policy_json)
else:
return None
@policy.setter
def policy(self, policy):
if policy:
self._policy_json = json.loads(policy)
else:
self._policy_json = {
"Version": "2012-10-17",
"Id": "{}/SQSDefaultPolicy".format(self.queue_arn),
"Statement": [],
}
class SQSBackend(BaseBackend): class SQSBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name):
@ -539,7 +567,7 @@ class SQSBackend(BaseBackend):
for name, q in self.queues.items(): for name, q in self.queues.items():
if prefix_re.search(name): if prefix_re.search(name):
qs.append(q) qs.append(q)
return qs return qs[:1000]
def get_queue(self, queue_name): def get_queue(self, queue_name):
queue = self.queues.get(queue_name) queue = self.queues.get(queue_name)
@ -801,25 +829,75 @@ class SQSBackend(BaseBackend):
def add_permission(self, queue_name, actions, account_ids, label): def add_permission(self, queue_name, actions, account_ids, label):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if actions is None or len(actions) == 0: if not actions:
raise RESTError("InvalidParameterValue", "Need at least one Action") raise MissingParameter()
if account_ids is None or len(account_ids) == 0:
raise RESTError("InvalidParameterValue", "Need at least one Account ID")
if not all([item in Queue.ALLOWED_PERMISSIONS for item in actions]): if not account_ids:
raise RESTError("InvalidParameterValue", "Invalid permissions") raise InvalidParameterValue(
"Value [] for parameter PrincipalId is invalid. Reason: Unable to verify."
)
queue.permissions[label] = (account_ids, actions) count = len(actions)
if count > 7:
raise OverLimit(count)
invalid_action = next(
(action for action in actions if action not in Queue.ALLOWED_PERMISSIONS),
None,
)
if invalid_action:
raise InvalidParameterValue(
"Value SQS:{} for parameter ActionName is invalid. "
"Reason: Only the queue owner is allowed to invoke this action.".format(
invalid_action
)
)
policy = queue._policy_json
statement = next(
(
statement
for statement in policy["Statement"]
if statement["Sid"] == label
),
None,
)
if statement:
raise InvalidParameterValue(
"Value {} for parameter Label is invalid. "
"Reason: Already exists.".format(label)
)
principals = [
"arn:aws:iam::{}:root".format(account_id) for account_id in account_ids
]
actions = ["SQS:{}".format(action) for action in actions]
statement = {
"Sid": label,
"Effect": "Allow",
"Principal": {"AWS": principals[0] if len(principals) == 1 else principals},
"Action": actions[0] if len(actions) == 1 else actions,
"Resource": queue.queue_arn,
}
queue._policy_json["Statement"].append(statement)
def remove_permission(self, queue_name, label): def remove_permission(self, queue_name, label):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if label not in queue.permissions: statements = queue._policy_json["Statement"]
raise RESTError( statements_new = [
"InvalidParameterValue", "Permission doesnt exist for the given label" statement for statement in statements if statement["Sid"] != label
]
if len(statements) == len(statements_new):
raise InvalidParameterValue(
"Value {} for parameter Label is invalid. "
"Reason: can't find label on existing policy.".format(label)
) )
del queue.permissions[label] queue._policy_json["Statement"] = statements_new
def tag_queue(self, queue_name, tags): def tag_queue(self, queue_name, tags):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
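The rewritten `add_permission`/`remove_permission` now maintain a real policy document, and `Policy` is exposed as a queue attribute. A minimal sketch of the round trip; queue name, label, and account id are illustrative:

```python
import boto3
from moto import mock_sqs


@mock_sqs
def show_queue_policy():
    sqs = boto3.client("sqs", region_name="us-east-1")
    url = sqs.create_queue(QueueName="demo-queue")["QueueUrl"]
    sqs.add_permission(
        QueueUrl=url,
        Label="crossaccount-send",
        AWSAccountIds=["111111111111"],
        Actions=["SendMessage"],
    )
    attrs = sqs.get_queue_attributes(QueueUrl=url, AttributeNames=["Policy"])
    # A JSON policy document containing the "crossaccount-send" Sid.
    print(attrs["Attributes"]["Policy"])


show_queue_policy()
```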

View File

@ -127,6 +127,10 @@ class WorkflowExecution(BaseModel):
"executionInfo": self.to_medium_dict(), "executionInfo": self.to_medium_dict(),
"executionConfiguration": {"taskList": {"name": self.task_list}}, "executionConfiguration": {"taskList": {"name": self.task_list}},
} }
# info
if self.execution_status == "CLOSED":
hsh["executionInfo"]["closeStatus"] = self.close_status
hsh["executionInfo"]["closeTimestamp"] = self.close_timestamp
# configuration # configuration
for key in self._configuration_keys: for key in self._configuration_keys:
attr = camelcase_to_underscores(key) attr = camelcase_to_underscores(key)

View File

@ -1,5 +1,5 @@
class TaggingService: class TaggingService:
def __init__(self, tagName='Tags', keyName='Key', valueName='Value'): def __init__(self, tagName="Tags", keyName="Key", valueName="Value"):
self.tagName = tagName self.tagName = tagName
self.keyName = keyName self.keyName = keyName
self.valueName = valueName self.valueName = valueName
@ -12,6 +12,12 @@ class TaggingService:
result.append({self.keyName: k, self.valueName: v}) result.append({self.keyName: k, self.valueName: v})
return {self.tagName: result} return {self.tagName: result}
def delete_all_tags_for_resource(self, arn):
del self.tags[arn]
def has_tags(self, arn):
return arn in self.tags
def tag_resource(self, arn, tags): def tag_resource(self, arn, tags):
if arn not in self.tags: if arn not in self.tags:
self.tags[arn] = {} self.tags[arn] = {}
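For reference, the new `TaggingService` helpers can be exercised directly. This sketch assumes the class lives at `moto.utilities.tagging_service` (the module path is not shown in this diff) and that tags are passed as `Key`/`Value` pairs, matching the class's defaults:

```python
# Assumed import path -- the diff above does not show the module location.
from moto.utilities.tagging_service import TaggingService

svc = TaggingService()
arn = "arn:aws:sqs:us-east-1:123456789012:demo"  # illustrative ARN

# tag_resource is assumed to take the same [{"Key": ..., "Value": ...}] shape
# the service produces, based on the keyName/valueName defaults above.
svc.tag_resource(arn, [{"Key": "team", "Value": "platform"}])
print(svc.has_tags(arn))  # True

svc.delete_all_tags_for_resource(arn)
print(svc.has_tags(arn))  # False
```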

View File

@ -20,8 +20,8 @@ import jinja2
from prompt_toolkit import ( from prompt_toolkit import (
prompt prompt
) )
from prompt_toolkit.contrib.completers import WordCompleter from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.shortcuts import print_tokens from prompt_toolkit.shortcuts import print_formatted_text
from botocore import xform_name from botocore import xform_name
from botocore.session import Session from botocore.session import Session
@ -149,12 +149,12 @@ def append_mock_dict_to_backends_py(service):
with open(path) as f: with open(path) as f:
lines = [_.replace('\n', '') for _ in f.readlines()] lines = [_.replace('\n', '') for _ in f.readlines()]
if any(_ for _ in lines if re.match(".*'{}': {}_backends.*".format(service, service), _)): if any(_ for _ in lines if re.match(".*\"{}\": {}_backends.*".format(service, service), _)):
return return
filtered_lines = [_ for _ in lines if re.match(".*'.*':.*_backends.*", _)] filtered_lines = [_ for _ in lines if re.match(".*\".*\":.*_backends.*", _)]
last_elem_line_index = lines.index(filtered_lines[-1]) last_elem_line_index = lines.index(filtered_lines[-1])
new_line = " '{}': {}_backends,".format(service, get_escaped_service(service)) new_line = " \"{}\": {}_backends,".format(service, get_escaped_service(service))
prev_line = lines[last_elem_line_index] prev_line = lines[last_elem_line_index]
if not prev_line.endswith('{') and not prev_line.endswith(','): if not prev_line.endswith('{') and not prev_line.endswith(','):
lines[last_elem_line_index] += ',' lines[last_elem_line_index] += ','

View File

@ -39,11 +39,11 @@ install_requires = [
"werkzeug", "werkzeug",
"PyYAML>=5.1", "PyYAML>=5.1",
"pytz", "pytz",
"python-dateutil<2.8.1,>=2.1", "python-dateutil<3.0.0,>=2.1",
"python-jose<4.0.0", "python-jose<4.0.0",
"mock", "mock",
"docker>=2.5.1", "docker>=2.5.1",
"jsondiff==1.1.2", "jsondiff>=1.1.2",
"aws-xray-sdk!=0.96,>=0.93", "aws-xray-sdk!=0.96,>=0.93",
"responses>=0.9.0", "responses>=0.9.0",
"idna<2.9,>=2.5", "idna<2.9,>=2.5",

View File

@ -26,7 +26,14 @@ def test_create_and_get_rest_api():
response.pop("ResponseMetadata") response.pop("ResponseMetadata")
response.pop("createdDate") response.pop("createdDate")
response.should.equal( response.should.equal(
{"id": api_id, "name": "my_api", "description": "this is my api"} {
"id": api_id,
"name": "my_api",
"description": "this is my api",
"apiKeySource": "HEADER",
"endpointConfiguration": {"types": ["EDGE"]},
"tags": {},
}
) )
@ -47,6 +54,114 @@ def test_list_and_delete_apis():
len(response["items"]).should.equal(1) len(response["items"]).should.equal(1)
@mock_apigateway
def test_create_rest_api_with_tags():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(
name="my_api", description="this is my api", tags={"MY_TAG1": "MY_VALUE1"}
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
assert "tags" in response
response["tags"].should.equal({"MY_TAG1": "MY_VALUE1"})
@mock_apigateway
def test_create_rest_api_invalid_apikeysource():
client = boto3.client("apigateway", region_name="us-west-2")
with assert_raises(ClientError) as ex:
client.create_rest_api(
name="my_api",
description="this is my api",
apiKeySource="not a valid api key source",
)
ex.exception.response["Error"]["Code"].should.equal("ValidationException")
@mock_apigateway
def test_create_rest_api_valid_apikeysources():
client = boto3.client("apigateway", region_name="us-west-2")
# 1. test creating rest api with HEADER apiKeySource
response = client.create_rest_api(
name="my_api", description="this is my api", apiKeySource="HEADER",
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["apiKeySource"].should.equal("HEADER")
# 2. test creating rest api with AUTHORIZER apiKeySource
response = client.create_rest_api(
name="my_api2", description="this is my api", apiKeySource="AUTHORIZER",
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["apiKeySource"].should.equal("AUTHORIZER")
@mock_apigateway
def test_create_rest_api_invalid_endpointconfiguration():
client = boto3.client("apigateway", region_name="us-west-2")
with assert_raises(ClientError) as ex:
client.create_rest_api(
name="my_api",
description="this is my api",
endpointConfiguration={"types": ["INVALID"]},
)
ex.exception.response["Error"]["Code"].should.equal("ValidationException")
@mock_apigateway
def test_create_rest_api_valid_endpointconfigurations():
client = boto3.client("apigateway", region_name="us-west-2")
# 1. test creating rest api with PRIVATE endpointConfiguration
response = client.create_rest_api(
name="my_api",
description="this is my api",
endpointConfiguration={"types": ["PRIVATE"]},
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["endpointConfiguration"].should.equal(
{"types": ["PRIVATE"],}
)
# 2. test creating rest api with REGIONAL endpointConfiguration
response = client.create_rest_api(
name="my_api2",
description="this is my api",
endpointConfiguration={"types": ["REGIONAL"]},
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["endpointConfiguration"].should.equal(
{"types": ["REGIONAL"],}
)
# 3. test creating rest api with EDGE endpointConfiguration
response = client.create_rest_api(
name="my_api3",
description="this is my api",
endpointConfiguration={"types": ["EDGE"]},
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["endpointConfiguration"].should.equal(
{"types": ["EDGE"],}
)
@mock_apigateway @mock_apigateway
def test_create_resource__validate_name(): def test_create_resource__validate_name():
client = boto3.client("apigateway", region_name="us-west-2") client = boto3.client("apigateway", region_name="us-west-2")
@ -58,15 +173,15 @@ def test_create_resource__validate_name():
0 0
]["id"] ]["id"]
invalid_names = ["/users", "users/", "users/{user_id}", "us{er"] invalid_names = ["/users", "users/", "users/{user_id}", "us{er", "us+er"]
valid_names = ["users", "{user_id}", "user_09", "good-dog"] valid_names = ["users", "{user_id}", "{proxy+}", "user_09", "good-dog"]
# All invalid names should throw an exception # All invalid names should throw an exception
for name in invalid_names: for name in invalid_names:
with assert_raises(ClientError) as ex: with assert_raises(ClientError) as ex:
client.create_resource(restApiId=api_id, parentId=root_id, pathPart=name) client.create_resource(restApiId=api_id, parentId=root_id, pathPart=name)
ex.exception.response["Error"]["Code"].should.equal("BadRequestException") ex.exception.response["Error"]["Code"].should.equal("BadRequestException")
ex.exception.response["Error"]["Message"].should.equal( ex.exception.response["Error"]["Message"].should.equal(
"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end." "Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace."
) )
# All valid names should go through # All valid names should go through
for name in valid_names: for name in valid_names:
@ -89,12 +204,7 @@ def test_create_resource():
root_resource["ResponseMetadata"].pop("HTTPHeaders", None) root_resource["ResponseMetadata"].pop("HTTPHeaders", None)
root_resource["ResponseMetadata"].pop("RetryAttempts", None) root_resource["ResponseMetadata"].pop("RetryAttempts", None)
root_resource.should.equal( root_resource.should.equal(
{ {"path": "/", "id": root_id, "ResponseMetadata": {"HTTPStatusCode": 200},}
"path": "/",
"id": root_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
"resourceMethods": {"GET": {}},
}
) )
client.create_resource(restApiId=api_id, parentId=root_id, pathPart="users") client.create_resource(restApiId=api_id, parentId=root_id, pathPart="users")
@ -142,7 +252,6 @@ def test_child_resource():
"parentId": users_id, "parentId": users_id,
"id": tags_id, "id": tags_id,
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
"resourceMethods": {"GET": {}},
} }
) )
@ -171,6 +280,41 @@ def test_create_method():
{ {
"httpMethod": "GET", "httpMethod": "GET",
"authorizationType": "none", "authorizationType": "none",
"apiKeyRequired": False,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
)
@mock_apigateway
def test_create_method_apikeyrequired():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
resources = client.get_resources(restApiId=api_id)
root_id = [resource for resource in resources["items"] if resource["path"] == "/"][
0
]["id"]
client.put_method(
restApiId=api_id,
resourceId=root_id,
httpMethod="GET",
authorizationType="none",
apiKeyRequired=True,
)
response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET")
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"httpMethod": "GET",
"authorizationType": "none",
"apiKeyRequired": True,
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
} }
) )

View File

@ -706,14 +706,14 @@ def test_create_autoscaling_group_boto3():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
}, },
{ {
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "not-propogated-tag-key", "Key": "not-propogated-tag-key",
"Value": "not-propogate-tag-value", "Value": "not-propagate-tag-value",
"PropagateAtLaunch": False, "PropagateAtLaunch": False,
}, },
], ],
@ -744,14 +744,14 @@ def test_create_autoscaling_group_from_instance():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
}, },
{ {
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "not-propogated-tag-key", "Key": "not-propogated-tag-key",
"Value": "not-propogate-tag-value", "Value": "not-propagate-tag-value",
"PropagateAtLaunch": False, "PropagateAtLaunch": False,
}, },
], ],
@ -1062,7 +1062,7 @@ def test_detach_one_instance_decrement():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
} }
], ],
@ -1116,7 +1116,7 @@ def test_detach_one_instance():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
} }
], ],
@ -1169,7 +1169,7 @@ def test_attach_one_instance():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
} }
], ],

View File

@ -58,8 +58,7 @@ def lambda_handler(event, context):
volume_id = event.get('volume_id') volume_id = event.get('volume_id')
vol = ec2.Volume(volume_id) vol = ec2.Volume(volume_id)
print('get volume details for %s\\nVolume - %s state=%s, size=%s' % (volume_id, volume_id, vol.state, vol.size)) return {{'id': vol.id, 'state': vol.state, 'size': vol.size}}
return event
""".format( """.format(
base_url="motoserver:5000" base_url="motoserver:5000"
if settings.TEST_SERVER_MODE if settings.TEST_SERVER_MODE
@ -87,14 +86,14 @@ def lambda_handler(event, context):
@mock_lambda @mock_lambda
def test_list_functions(): def test_list_functions():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
result = conn.list_functions() result = conn.list_functions()
result["Functions"].should.have.length_of(0) result["Functions"].should.have.length_of(0)
@mock_lambda @mock_lambda
def test_invoke_requestresponse_function(): def test_invoke_requestresponse_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python2.7", Runtime="python2.7",
@ -114,7 +113,44 @@ def test_invoke_requestresponse_function():
Payload=json.dumps(in_data), Payload=json.dumps(in_data),
) )
success_result["StatusCode"].should.equal(202) success_result["StatusCode"].should.equal(200)
result_obj = json.loads(
base64.b64decode(success_result["LogResult"]).decode("utf-8")
)
result_obj.should.equal(in_data)
payload = success_result["Payload"].read().decode("utf-8")
json.loads(payload).should.equal(in_data)
@mock_lambda
def test_invoke_requestresponse_function_with_arn():
from moto.awslambda.models import ACCOUNT_ID
conn = boto3.client("lambda", "us-west-2")
conn.create_function(
FunctionName="testFunction",
Runtime="python2.7",
Role=get_role_name(),
Handler="lambda_function.lambda_handler",
Code={"ZipFile": get_test_zip_file1()},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
in_data = {"msg": "So long and thanks for all the fish"}
success_result = conn.invoke(
FunctionName="arn:aws:lambda:us-west-2:{}:function:testFunction".format(
ACCOUNT_ID
),
InvocationType="RequestResponse",
Payload=json.dumps(in_data),
)
success_result["StatusCode"].should.equal(200)
result_obj = json.loads( result_obj = json.loads(
base64.b64decode(success_result["LogResult"]).decode("utf-8") base64.b64decode(success_result["LogResult"]).decode("utf-8")
) )
@ -127,7 +163,7 @@ def test_invoke_requestresponse_function():
@mock_lambda @mock_lambda
def test_invoke_event_function(): def test_invoke_event_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python2.7", Runtime="python2.7",
@ -149,7 +185,35 @@ def test_invoke_event_function():
FunctionName="testFunction", InvocationType="Event", Payload=json.dumps(in_data) FunctionName="testFunction", InvocationType="Event", Payload=json.dumps(in_data)
) )
success_result["StatusCode"].should.equal(202) success_result["StatusCode"].should.equal(202)
json.loads(success_result["Payload"].read().decode("utf-8")).should.equal({}) json.loads(success_result["Payload"].read().decode("utf-8")).should.equal(in_data)
@mock_lambda
def test_invoke_dryrun_function():
conn = boto3.client("lambda", _lambda_region)
conn.create_function(
FunctionName="testFunction",
Runtime="python2.7",
Role=get_role_name(),
Handler="lambda_function.lambda_handler",
Code={"ZipFile": get_test_zip_file1(),},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
conn.invoke.when.called_with(
FunctionName="notAFunction", InvocationType="Event", Payload="{}"
).should.throw(botocore.client.ClientError)
in_data = {"msg": "So long and thanks for all the fish"}
success_result = conn.invoke(
FunctionName="testFunction",
InvocationType="DryRun",
Payload=json.dumps(in_data),
)
success_result["StatusCode"].should.equal(204)
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
@ -157,11 +221,11 @@ if settings.TEST_SERVER_MODE:
@mock_ec2 @mock_ec2
@mock_lambda @mock_lambda
def test_invoke_function_get_ec2_volume(): def test_invoke_function_get_ec2_volume():
conn = boto3.resource("ec2", "us-west-2") conn = boto3.resource("ec2", _lambda_region)
vol = conn.create_volume(Size=99, AvailabilityZone="us-west-2") vol = conn.create_volume(Size=99, AvailabilityZone=_lambda_region)
vol = conn.Volume(vol.id) vol = conn.Volume(vol.id)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python3.7", Runtime="python3.7",
@ -180,28 +244,10 @@ if settings.TEST_SERVER_MODE:
InvocationType="RequestResponse", InvocationType="RequestResponse",
Payload=json.dumps(in_data), Payload=json.dumps(in_data),
) )
result["StatusCode"].should.equal(202) result["StatusCode"].should.equal(200)
msg = "get volume details for %s\nVolume - %s state=%s, size=%s\n%s" % ( actual_payload = json.loads(result["Payload"].read().decode("utf-8"))
vol.id, expected_payload = {"id": vol.id, "state": vol.state, "size": vol.size}
vol.id, actual_payload.should.equal(expected_payload)
vol.state,
vol.size,
json.dumps(in_data).replace(
" ", ""
), # Makes the tests pass as the result is missing the whitespace
)
log_result = base64.b64decode(result["LogResult"]).decode("utf-8")
# The Docker lambda invocation will return an additional '\n', so need to replace it:
log_result = log_result.replace("\n\n", "\n")
log_result.should.equal(msg)
payload = result["Payload"].read().decode("utf-8")
# The Docker lambda invocation will return an additional '\n', so need to replace it:
payload = payload.replace("\n\n", "\n")
payload.should.equal(msg)
@mock_logs @mock_logs
@ -209,14 +255,14 @@ if settings.TEST_SERVER_MODE:
@mock_ec2 @mock_ec2
@mock_lambda @mock_lambda
def test_invoke_function_from_sns(): def test_invoke_function_from_sns():
logs_conn = boto3.client("logs", region_name="us-west-2") logs_conn = boto3.client("logs", region_name=_lambda_region)
sns_conn = boto3.client("sns", region_name="us-west-2") sns_conn = boto3.client("sns", region_name=_lambda_region)
sns_conn.create_topic(Name="some-topic") sns_conn.create_topic(Name="some-topic")
topics_json = sns_conn.list_topics() topics_json = sns_conn.list_topics()
topics = topics_json["Topics"] topics = topics_json["Topics"]
topic_arn = topics[0]["TopicArn"] topic_arn = topics[0]["TopicArn"]
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
result = conn.create_function( result = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python2.7", Runtime="python2.7",
@ -259,7 +305,7 @@ def test_invoke_function_from_sns():
@mock_lambda @mock_lambda
def test_create_based_on_s3_with_missing_bucket(): def test_create_based_on_s3_with_missing_bucket():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function.when.called_with( conn.create_function.when.called_with(
FunctionName="testFunction", FunctionName="testFunction",
@ -279,12 +325,15 @@ def test_create_based_on_s3_with_missing_bucket():
@mock_s3 @mock_s3
@freeze_time("2015-01-01 00:00:00") @freeze_time("2015-01-01 00:00:00")
def test_create_function_from_aws_bucket(): def test_create_function_from_aws_bucket():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
result = conn.create_function( result = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -324,6 +373,7 @@ def test_create_function_from_aws_bucket():
"VpcId": "vpc-123abc", "VpcId": "vpc-123abc",
}, },
"ResponseMetadata": {"HTTPStatusCode": 201}, "ResponseMetadata": {"HTTPStatusCode": 201},
"State": "Active",
} }
) )
@ -331,7 +381,7 @@ def test_create_function_from_aws_bucket():
@mock_lambda @mock_lambda
@freeze_time("2015-01-01 00:00:00") @freeze_time("2015-01-01 00:00:00")
def test_create_function_from_zipfile(): def test_create_function_from_zipfile():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
result = conn.create_function( result = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -367,6 +417,7 @@ def test_create_function_from_zipfile():
"Version": "1", "Version": "1",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"ResponseMetadata": {"HTTPStatusCode": 201}, "ResponseMetadata": {"HTTPStatusCode": 201},
"State": "Active",
} }
) )
@ -375,12 +426,15 @@ def test_create_function_from_zipfile():
@mock_s3 @mock_s3
@freeze_time("2015-01-01 00:00:00") @freeze_time("2015-01-01 00:00:00")
def test_get_function(): def test_get_function():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -435,7 +489,7 @@ def test_get_function():
) )
# Test get function when can't find function name # Test get function when can't find function name
with assert_raises(ClientError): with assert_raises(conn.exceptions.ResourceNotFoundException):
conn.get_function(FunctionName="junk", Qualifier="$LATEST") conn.get_function(FunctionName="junk", Qualifier="$LATEST")
@ -444,7 +498,10 @@ def test_get_function():
def test_get_function_by_arn(): def test_get_function_by_arn():
bucket_name = "test-bucket" bucket_name = "test-bucket"
s3_conn = boto3.client("s3", "us-east-1") s3_conn = boto3.client("s3", "us-east-1")
s3_conn.create_bucket(Bucket=bucket_name) s3_conn.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content)
@ -469,12 +526,15 @@ def test_get_function_by_arn():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_delete_function(): def test_delete_function():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -505,7 +565,10 @@ def test_delete_function():
def test_delete_function_by_arn(): def test_delete_function_by_arn():
bucket_name = "test-bucket" bucket_name = "test-bucket"
s3_conn = boto3.client("s3", "us-east-1") s3_conn = boto3.client("s3", "us-east-1")
s3_conn.create_bucket(Bucket=bucket_name) s3_conn.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content)
@ -530,7 +593,7 @@ def test_delete_function_by_arn():
@mock_lambda @mock_lambda
def test_delete_unknown_function(): def test_delete_unknown_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.delete_function.when.called_with( conn.delete_function.when.called_with(
FunctionName="testFunctionThatDoesntExist" FunctionName="testFunctionThatDoesntExist"
).should.throw(botocore.client.ClientError) ).should.throw(botocore.client.ClientError)
@ -539,12 +602,15 @@ def test_delete_unknown_function():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_publish(): def test_publish():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -589,12 +655,15 @@ def test_list_create_list_get_delete_list():
test `list -> create -> list -> get -> delete -> list` integration test `list -> create -> list -> get -> delete -> list` integration
""" """
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.list_functions()["Functions"].should.have.length_of(0) conn.list_functions()["Functions"].should.have.length_of(0)
@ -631,6 +700,7 @@ def test_list_create_list_get_delete_list():
"Timeout": 3, "Timeout": 3,
"Version": "$LATEST", "Version": "$LATEST",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
}, },
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
} }
@ -690,12 +760,15 @@ def test_tags():
""" """
test list_tags -> tag_resource -> list_tags -> tag_resource -> list_tags -> untag_resource -> list_tags integration test list_tags -> tag_resource -> list_tags -> tag_resource -> list_tags -> untag_resource -> list_tags integration
""" """
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
function = conn.create_function( function = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -747,7 +820,7 @@ def test_tags_not_found():
""" """
Test list_tags and tag_resource when the lambda with the given arn does not exist Test list_tags and tag_resource when the lambda with the given arn does not exist
""" """
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.list_tags.when.called_with( conn.list_tags.when.called_with(
Resource="arn:aws:lambda:{}:function:not-found".format(ACCOUNT_ID) Resource="arn:aws:lambda:{}:function:not-found".format(ACCOUNT_ID)
).should.throw(botocore.client.ClientError) ).should.throw(botocore.client.ClientError)
@ -765,7 +838,7 @@ def test_tags_not_found():
@mock_lambda @mock_lambda
def test_invoke_async_function(): def test_invoke_async_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python2.7", Runtime="python2.7",
@ -788,7 +861,7 @@ def test_invoke_async_function():
@mock_lambda @mock_lambda
@freeze_time("2015-01-01 00:00:00") @freeze_time("2015-01-01 00:00:00")
def test_get_function_created_with_zipfile(): def test_get_function_created_with_zipfile():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
result = conn.create_function( result = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -827,13 +900,14 @@ def test_get_function_created_with_zipfile():
"Timeout": 3, "Timeout": 3,
"Version": "$LATEST", "Version": "$LATEST",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@mock_lambda @mock_lambda
def test_add_function_permission(): def test_add_function_permission():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -864,7 +938,7 @@ def test_add_function_permission():
@mock_lambda @mock_lambda
def test_get_function_policy(): def test_get_function_policy():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -899,12 +973,15 @@ def test_get_function_policy():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_list_versions_by_function(): def test_list_versions_by_function():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -955,12 +1032,15 @@ def test_list_versions_by_function():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_create_function_with_already_exists(): def test_create_function_with_already_exists():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -992,7 +1072,7 @@ def test_create_function_with_already_exists():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_list_versions_by_function_for_nonexistent_function(): def test_list_versions_by_function_for_nonexistent_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
versions = conn.list_versions_by_function(FunctionName="testFunction") versions = conn.list_versions_by_function(FunctionName="testFunction")
assert len(versions["Versions"]) == 0 assert len(versions["Versions"]) == 0
@ -1341,12 +1421,15 @@ def test_delete_event_source_mapping():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_update_configuration(): def test_update_configuration():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
fxn = conn.create_function( fxn = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -1389,7 +1472,7 @@ def test_update_configuration():
@mock_lambda @mock_lambda
def test_update_function_zip(): def test_update_function_zip():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content_one = get_test_zip_file1() zip_content_one = get_test_zip_file1()
@ -1436,6 +1519,7 @@ def test_update_function_zip():
"Timeout": 3, "Timeout": 3,
"Version": "2", "Version": "2",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@ -1443,13 +1527,16 @@ def test_update_function_zip():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_update_function_s3(): def test_update_function_s3():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
fxn = conn.create_function( fxn = conn.create_function(
FunctionName="testFunctionS3", FunctionName="testFunctionS3",
@ -1498,6 +1585,7 @@ def test_update_function_s3():
"Timeout": 3, "Timeout": 3,
"Version": "2", "Version": "2",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@ -1529,7 +1617,7 @@ def test_create_function_with_unknown_arn():
def create_invalid_lambda(role): def create_invalid_lambda(role):
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
with assert_raises(ClientError) as err: with assert_raises(ClientError) as err:
conn.create_function( conn.create_function(
@ -1548,7 +1636,7 @@ def create_invalid_lambda(role):
def get_role_name(): def get_role_name():
with mock_iam(): with mock_iam():
iam = boto3.client("iam", region_name="us-west-2") iam = boto3.client("iam", region_name=_lambda_region)
try: try:
return iam.get_role(RoleName="my-role")["Role"]["Arn"] return iam.get_role(RoleName="my-role")["Role"]["Arn"]
except ClientError: except ClientError:

View File

@ -94,7 +94,7 @@ def test_lambda_can_be_deleted_by_cloudformation():
# Verify function was deleted # Verify function was deleted
with assert_raises(ClientError) as e: with assert_raises(ClientError) as e:
lmbda.get_function(FunctionName=created_fn_name) lmbda.get_function(FunctionName=created_fn_name)
e.exception.response["Error"]["Code"].should.equal("404") e.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException")
def create_stack(cf, s3): def create_stack(cf, s3):

View File

@ -0,0 +1,49 @@
from __future__ import unicode_literals
import json
import sure
from moto.awslambda.policy import Policy
class MockLambdaFunction:
def __init__(self, arn):
self.function_arn = arn
self.policy = None
def test_policy():
policy = Policy(MockLambdaFunction("arn"))
statement = {
"StatementId": "statement0",
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": "events.amazonaws.com",
"SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name",
"SourceAccount": "111111111111",
}
expected = {
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": {"Service": "events.amazonaws.com"},
"Effect": "Allow",
"Resource": "arn:$LATEST",
"Sid": "statement0",
"Condition": {
"ArnLike": {
"AWS:SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name",
},
"StringEquals": {"AWS:SourceAccount": "111111111111"},
},
}
policy.add_statement(json.dumps(statement))
expected.should.be.equal(policy.statements[0])
sid = statement.get("StatementId", None)
if sid is None:
raise Exception("TestCase.statement does not contain StatementId")
policy.del_statement(sid)
[].should.be.equal(policy.statements)

View File

@ -143,7 +143,7 @@ def test_create_stack_with_notification_arn():
@mock_s3_deprecated @mock_s3_deprecated
def test_create_stack_from_s3_url(): def test_create_stack_from_s3_url():
s3_conn = boto.s3.connect_to_region("us-west-1") s3_conn = boto.s3.connect_to_region("us-west-1")
bucket = s3_conn.create_bucket("foobar") bucket = s3_conn.create_bucket("foobar", location="us-west-1")
key = boto.s3.key.Key(bucket) key = boto.s3.key.Key(bucket)
key.key = "template-key" key.key = "template-key"
key.set_contents_from_string(dummy_template_json) key.set_contents_from_string(dummy_template_json)

View File

@ -522,6 +522,13 @@ def test_boto3_list_stack_set_operations():
list_operation["Summaries"][-1]["Action"].should.equal("UPDATE") list_operation["Summaries"][-1]["Action"].should.equal("UPDATE")
@mock_cloudformation
def test_boto3_bad_list_stack_resources():
cf_conn = boto3.client("cloudformation", region_name="us-east-1")
with assert_raises(ClientError):
cf_conn.list_stack_resources(StackName="test_stack_set")
@mock_cloudformation @mock_cloudformation
def test_boto3_delete_stack_set(): def test_boto3_delete_stack_set():
cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn = boto3.client("cloudformation", region_name="us-east-1")

View File

@ -27,6 +27,11 @@ def test_create_user_pool():
result["UserPool"]["Id"].should_not.be.none result["UserPool"]["Id"].should_not.be.none
result["UserPool"]["Id"].should.match(r"[\w-]+_[0-9a-zA-Z]+") result["UserPool"]["Id"].should.match(r"[\w-]+_[0-9a-zA-Z]+")
result["UserPool"]["Arn"].should.equal(
"arn:aws:cognito-idp:us-west-2:{}:userpool/{}".format(
ACCOUNT_ID, result["UserPool"]["Id"]
)
)
result["UserPool"]["Name"].should.equal(name) result["UserPool"]["Name"].should.equal(name)
result["UserPool"]["LambdaConfig"]["PreSignUp"].should.equal(value) result["UserPool"]["LambdaConfig"]["PreSignUp"].should.equal(value)
@ -911,6 +916,55 @@ def test_admin_create_existing_user():
caught.should.be.true caught.should.be.true
@mock_cognitoidp
def test_admin_resend_invitation_existing_user():
conn = boto3.client("cognito-idp", "us-west-2")
username = str(uuid.uuid4())
value = str(uuid.uuid4())
user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"]
conn.admin_create_user(
UserPoolId=user_pool_id,
Username=username,
UserAttributes=[{"Name": "thing", "Value": value}],
)
caught = False
try:
conn.admin_create_user(
UserPoolId=user_pool_id,
Username=username,
UserAttributes=[{"Name": "thing", "Value": value}],
MessageAction="RESEND",
)
except conn.exceptions.UsernameExistsException:
caught = True
caught.should.be.false
@mock_cognitoidp
def test_admin_resend_invitation_missing_user():
conn = boto3.client("cognito-idp", "us-west-2")
username = str(uuid.uuid4())
value = str(uuid.uuid4())
user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"]
caught = False
try:
conn.admin_create_user(
UserPoolId=user_pool_id,
Username=username,
UserAttributes=[{"Name": "thing", "Value": value}],
MessageAction="RESEND",
)
except conn.exceptions.UserNotFoundException:
caught = True
caught.should.be.true
@mock_cognitoidp @mock_cognitoidp
def test_admin_get_user(): def test_admin_get_user():
conn = boto3.client("cognito-idp", "us-west-2") conn = boto3.client("cognito-idp", "us-west-2")
@ -958,6 +1012,18 @@ def test_list_users():
result["Users"].should.have.length_of(1) result["Users"].should.have.length_of(1)
result["Users"][0]["Username"].should.equal(username) result["Users"][0]["Username"].should.equal(username)
username_bis = str(uuid.uuid4())
conn.admin_create_user(
UserPoolId=user_pool_id,
Username=username_bis,
UserAttributes=[{"Name": "phone_number", "Value": "+33666666666"}],
)
result = conn.list_users(
UserPoolId=user_pool_id, Filter='phone_number="+33666666666"'
)
result["Users"].should.have.length_of(1)
result["Users"][0]["Username"].should.equal(username_bis)
@mock_cognitoidp @mock_cognitoidp
def test_list_users_returns_limit_items(): def test_list_users_returns_limit_items():
@ -1142,11 +1208,13 @@ def test_token_legitimacy():
id_claims = json.loads(jws.verify(id_token, json_web_key, "RS256")) id_claims = json.loads(jws.verify(id_token, json_web_key, "RS256"))
id_claims["iss"].should.equal(issuer) id_claims["iss"].should.equal(issuer)
id_claims["aud"].should.equal(client_id) id_claims["aud"].should.equal(client_id)
id_claims["token_use"].should.equal("id")
for k, v in outputs["additional_fields"].items():
id_claims[k].should.equal(v)
access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256")) access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256"))
access_claims["iss"].should.equal(issuer) access_claims["iss"].should.equal(issuer)
access_claims["aud"].should.equal(client_id) access_claims["aud"].should.equal(client_id)
for k, v in outputs["additional_fields"].items(): access_claims["token_use"].should.equal("access")
access_claims[k].should.equal(v)
@mock_cognitoidp @mock_cognitoidp

View File

@ -46,4 +46,4 @@ def test_domain_dispatched_with_service():
dispatcher = DomainDispatcherApplication(create_backend_app, service="s3") dispatcher = DomainDispatcherApplication(create_backend_app, service="s3")
backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"}) backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"})
keys = set(backend_app.view_functions.keys()) keys = set(backend_app.view_functions.keys())
keys.should.contain("ResponseObject.key_response") keys.should.contain("ResponseObject.key_or_control_response")

View File

@ -1719,6 +1719,32 @@ def test_scan_filter4():
assert response["Count"] == 0 assert response["Count"] == 0
@mock_dynamodb2
def test_scan_filter_should_not_return_non_existing_attributes():
table_name = "my-table"
item = {"partitionKey": "pk-2", "my-attr": 42}
# Create table
res = boto3.resource("dynamodb", region_name="us-east-1")
res.create_table(
TableName=table_name,
KeySchema=[{"AttributeName": "partitionKey", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "partitionKey", "AttributeType": "S"}],
BillingMode="PAY_PER_REQUEST",
)
table = res.Table(table_name)
# Insert items
table.put_item(Item={"partitionKey": "pk-1"})
table.put_item(Item=item)
# Verify a few operations
# Assert we only find the item that has this attribute
table.scan(FilterExpression=Attr("my-attr").lt(43))["Items"].should.equal([item])
table.scan(FilterExpression=Attr("my-attr").lte(42))["Items"].should.equal([item])
table.scan(FilterExpression=Attr("my-attr").gte(42))["Items"].should.equal([item])
table.scan(FilterExpression=Attr("my-attr").gt(41))["Items"].should.equal([item])
# Sanity check that we can't find the item if the FE is wrong
table.scan(FilterExpression=Attr("my-attr").gt(43))["Items"].should.equal([])
@mock_dynamodb2 @mock_dynamodb2
def test_bad_scan_filter(): def test_bad_scan_filter():
client = boto3.client("dynamodb", region_name="us-east-1") client = boto3.client("dynamodb", region_name="us-east-1")
@ -2505,6 +2531,48 @@ def test_condition_expressions():
) )
@mock_dynamodb2
def test_condition_expression_numerical_attribute():
dynamodb = boto3.resource("dynamodb", region_name="us-east-1")
dynamodb.create_table(
TableName="my-table",
KeySchema=[{"AttributeName": "partitionKey", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "partitionKey", "AttributeType": "S"}],
)
table = dynamodb.Table("my-table")
table.put_item(Item={"partitionKey": "pk-pos", "myAttr": 5})
table.put_item(Item={"partitionKey": "pk-neg", "myAttr": -5})
# try to update the item we put in the table using numerical condition expression
# Specifically, verify that we can compare with a zero-value
# First verify that > and >= work on positive numbers
update_numerical_con_expr(
key="pk-pos", con_expr="myAttr > :zero", res="6", table=table
)
update_numerical_con_expr(
key="pk-pos", con_expr="myAttr >= :zero", res="7", table=table
)
# Second verify that < and <= work on negative numbers
update_numerical_con_expr(
key="pk-neg", con_expr="myAttr < :zero", res="-4", table=table
)
update_numerical_con_expr(
key="pk-neg", con_expr="myAttr <= :zero", res="-3", table=table
)
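# Helper: increment myAttr by one for the given key, guarded by the supplied
# ConditionExpression, then verify the stored value equals the expected result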
def update_numerical_con_expr(key, con_expr, res, table):
table.update_item(
Key={"partitionKey": key},
UpdateExpression="ADD myAttr :one",
ExpressionAttributeValues={":zero": 0, ":one": 1},
ConditionExpression=con_expr,
)
table.get_item(Key={"partitionKey": key})["Item"]["myAttr"].should.equal(
Decimal(res)
)
@mock_dynamodb2 @mock_dynamodb2
def test_condition_expression__attr_doesnt_exist(): def test_condition_expression__attr_doesnt_exist():
client = boto3.client("dynamodb", region_name="us-east-1") client = boto3.client("dynamodb", region_name="us-east-1")
@ -3489,6 +3557,83 @@ def test_update_supports_nested_list_append_onto_another_list():
) )
@mock_dynamodb2
def test_update_supports_list_append_maps():
client = boto3.client("dynamodb", region_name="us-west-1")
client.create_table(
AttributeDefinitions=[
{"AttributeName": "id", "AttributeType": "S"},
{"AttributeName": "rid", "AttributeType": "S"},
],
TableName="TestTable",
KeySchema=[
{"AttributeName": "id", "KeyType": "HASH"},
{"AttributeName": "rid", "KeyType": "RANGE"},
],
ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
)
client.put_item(
TableName="TestTable",
Item={
"id": {"S": "nested_list_append"},
"rid": {"S": "range_key"},
"a": {"L": [{"M": {"b": {"S": "bar1"}}}]},
},
)
# Update item using list_append expression
client.update_item(
TableName="TestTable",
Key={"id": {"S": "nested_list_append"}, "rid": {"S": "range_key"}},
UpdateExpression="SET a = list_append(a, :i)",
ExpressionAttributeValues={":i": {"L": [{"M": {"b": {"S": "bar2"}}}]}},
)
# Verify item is appended to the existing list
result = client.query(
TableName="TestTable",
KeyConditionExpression="id = :i AND begins_with(rid, :r)",
ExpressionAttributeValues={
":i": {"S": "nested_list_append"},
":r": {"S": "range_key"},
},
)["Items"]
result.should.equal(
[
{
"a": {"L": [{"M": {"b": {"S": "bar1"}}}, {"M": {"b": {"S": "bar2"}}}]},
"rid": {"S": "range_key"},
"id": {"S": "nested_list_append"},
}
]
)
@mock_dynamodb2
def test_update_supports_list_append_with_nested_if_not_exists_operation():
dynamo = boto3.resource("dynamodb", region_name="us-west-1")
table_name = "test"
dynamo.create_table(
TableName=table_name,
AttributeDefinitions=[{"AttributeName": "Id", "AttributeType": "S"}],
KeySchema=[{"AttributeName": "Id", "KeyType": "HASH"}],
ProvisionedThroughput={"ReadCapacityUnits": 20, "WriteCapacityUnits": 20},
)
table = dynamo.Table(table_name)
table.put_item(Item={"Id": "item-id", "nest1": {"nest2": {}}})
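# list_append over if_not_exists(..., :empty_list) should create the nested list
# on the first update instead of failing when the attribute is missing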
table.update_item(
Key={"Id": "item-id"},
UpdateExpression="SET nest1.nest2.event_history = list_append(if_not_exists(nest1.nest2.event_history, :empty_list), :new_value)",
ExpressionAttributeValues={":empty_list": [], ":new_value": ["some_value"]},
)
table.get_item(Key={"Id": "item-id"})["Item"].should.equal(
{"Id": "item-id", "nest1": {"nest2": {"event_history": ["some_value"]}}}
)
@mock_dynamodb2 @mock_dynamodb2
def test_update_catches_invalid_list_append_operation(): def test_update_catches_invalid_list_append_operation():
client = boto3.client("dynamodb", region_name="us-east-1") client = boto3.client("dynamodb", region_name="us-east-1")
@ -3601,3 +3746,24 @@ def test_allow_update_to_item_with_different_type():
table.get_item(Key={"job_id": "b"})["Item"]["job_details"][ table.get_item(Key={"job_id": "b"})["Item"]["job_details"][
"job_name" "job_name"
].should.be.equal({"nested": "yes"}) ].should.be.equal({"nested": "yes"})
@mock_dynamodb2
def test_query_catches_when_no_filters():
dynamo = boto3.resource("dynamodb", region_name="eu-central-1")
dynamo.create_table(
AttributeDefinitions=[{"AttributeName": "job_id", "AttributeType": "S"}],
TableName="origin-rbu-dev",
KeySchema=[{"AttributeName": "job_id", "KeyType": "HASH"}],
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
)
table = dynamo.Table("origin-rbu-dev")
with assert_raises(ClientError) as ex:
table.query(TableName="original-rbu-dev")
ex.exception.response["Error"]["Code"].should.equal("ValidationException")
ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.exception.response["Error"]["Message"].should.equal(
"Either KeyConditions or QueryFilter should be present"
)

View File

@ -12,6 +12,7 @@ import sure # noqa
from moto import mock_ec2_deprecated, mock_ec2 from moto import mock_ec2_deprecated, mock_ec2
from moto.ec2.models import AMIS, OWNER_ID from moto.ec2.models import AMIS, OWNER_ID
from moto.iam.models import ACCOUNT_ID
from tests.helpers import requires_boto_gte from tests.helpers import requires_boto_gte
@ -251,6 +252,19 @@ def test_ami_pulls_attributes_from_instance():
image.kernel_id.should.equal("test-kernel") image.kernel_id.should.equal("test-kernel")
@mock_ec2_deprecated
def test_ami_uses_account_id_if_valid_access_key_is_supplied():
access_key = "AKIAXXXXXXXXXXXXXXXX"
conn = boto.connect_ec2(access_key, "the_secret")
reservation = conn.run_instances("ami-1234abcd")
instance = reservation.instances[0]
instance.modify_attribute("kernel", "test-kernel")
image_id = conn.create_image(instance.id, "test-ami", "this is a test ami")
images = conn.get_all_images(owners=["self"])
[(ami.id, ami.owner_id) for ami in images].should.equal([(image_id, ACCOUNT_ID)])
@mock_ec2_deprecated @mock_ec2_deprecated
def test_ami_filters(): def test_ami_filters():
conn = boto.connect_ec2("the_key", "the_secret") conn = boto.connect_ec2("the_key", "the_secret")
@ -773,7 +787,7 @@ def test_ami_filter_wildcard():
instance.create_image(Name="not-matching-image") instance.create_image(Name="not-matching-image")
my_images = ec2_client.describe_images( my_images = ec2_client.describe_images(
Owners=["111122223333"], Filters=[{"Name": "name", "Values": ["test*"]}] Owners=[ACCOUNT_ID], Filters=[{"Name": "name", "Values": ["test*"]}]
)["Images"] )["Images"]
my_images.should.have.length_of(1) my_images.should.have.length_of(1)

View File

@ -236,8 +236,8 @@ def test_route_table_associations():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_route_table_replace_route_table_association(): def test_route_table_replace_route_table_association():
""" """
Note: Boto has deprecated replace_route_table_assocation (which returns status) Note: Boto has deprecated replace_route_table_association (which returns status)
and now uses replace_route_table_assocation_with_assoc (which returns association ID). and now uses replace_route_table_association_with_assoc (which returns association ID).
""" """
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpc = conn.create_vpc("10.0.0.0/16") vpc = conn.create_vpc("10.0.0.0/16")

View File

@ -77,7 +77,7 @@ def test_describe_repositories():
response = client.describe_repositories() response = client.describe_repositories()
len(response["repositories"]).should.equal(2) len(response["repositories"]).should.equal(2)
respository_arns = [ repository_arns = [
"arn:aws:ecr:us-east-1:012345678910:repository/test_repository1", "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1",
"arn:aws:ecr:us-east-1:012345678910:repository/test_repository0", "arn:aws:ecr:us-east-1:012345678910:repository/test_repository0",
] ]
@ -86,9 +86,9 @@ def test_describe_repositories():
response["repositories"][0]["repositoryArn"], response["repositories"][0]["repositoryArn"],
response["repositories"][1]["repositoryArn"], response["repositories"][1]["repositoryArn"],
] ]
).should.equal(set(respository_arns)) ).should.equal(set(repository_arns))
respository_uris = [ repository_uris = [
"012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1", "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1",
"012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0", "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0",
] ]
@ -97,7 +97,7 @@ def test_describe_repositories():
response["repositories"][0]["repositoryUri"], response["repositories"][0]["repositoryUri"],
response["repositories"][1]["repositoryUri"], response["repositories"][1]["repositoryUri"],
] ]
).should.equal(set(respository_uris)) ).should.equal(set(repository_uris))
@mock_ecr @mock_ecr
@ -108,7 +108,7 @@ def test_describe_repositories_1():
response = client.describe_repositories(registryId="012345678910") response = client.describe_repositories(registryId="012345678910")
len(response["repositories"]).should.equal(2) len(response["repositories"]).should.equal(2)
respository_arns = [ repository_arns = [
"arn:aws:ecr:us-east-1:012345678910:repository/test_repository1", "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1",
"arn:aws:ecr:us-east-1:012345678910:repository/test_repository0", "arn:aws:ecr:us-east-1:012345678910:repository/test_repository0",
] ]
@ -117,9 +117,9 @@ def test_describe_repositories_1():
response["repositories"][0]["repositoryArn"], response["repositories"][0]["repositoryArn"],
response["repositories"][1]["repositoryArn"], response["repositories"][1]["repositoryArn"],
] ]
).should.equal(set(respository_arns)) ).should.equal(set(repository_arns))
respository_uris = [ repository_uris = [
"012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1", "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1",
"012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0", "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0",
] ]
@ -128,7 +128,7 @@ def test_describe_repositories_1():
response["repositories"][0]["repositoryUri"], response["repositories"][0]["repositoryUri"],
response["repositories"][1]["repositoryUri"], response["repositories"][1]["repositoryUri"],
] ]
).should.equal(set(respository_uris)) ).should.equal(set(repository_uris))
@mock_ecr @mock_ecr
@ -147,11 +147,11 @@ def test_describe_repositories_3():
_ = client.create_repository(repositoryName="test_repository0") _ = client.create_repository(repositoryName="test_repository0")
response = client.describe_repositories(repositoryNames=["test_repository1"]) response = client.describe_repositories(repositoryNames=["test_repository1"])
len(response["repositories"]).should.equal(1) len(response["repositories"]).should.equal(1)
respository_arn = "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1" repository_arn = "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1"
response["repositories"][0]["repositoryArn"].should.equal(respository_arn) response["repositories"][0]["repositoryArn"].should.equal(repository_arn)
respository_uri = "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1" repository_uri = "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1"
response["repositories"][0]["repositoryUri"].should.equal(respository_uri) response["repositories"][0]["repositoryUri"].should.equal(repository_uri)
@mock_ecr @mock_ecr

View File

@ -94,6 +94,7 @@ def test_register_task_definition():
"logConfiguration": {"logDriver": "json-file"}, "logConfiguration": {"logDriver": "json-file"},
} }
], ],
networkMode="bridge",
tags=[ tags=[
{"key": "createdBy", "value": "moto-unittest"}, {"key": "createdBy", "value": "moto-unittest"},
{"key": "foo", "value": "bar"}, {"key": "foo", "value": "bar"},
@ -124,6 +125,7 @@ def test_register_task_definition():
response["taskDefinition"]["containerDefinitions"][0]["logConfiguration"][ response["taskDefinition"]["containerDefinitions"][0]["logConfiguration"][
"logDriver" "logDriver"
].should.equal("json-file") ].should.equal("json-file")
response["taskDefinition"]["networkMode"].should.equal("bridge")
@mock_ecs @mock_ecs
@ -724,7 +726,7 @@ def test_delete_service():
@mock_ecs @mock_ecs
def test_update_non_existant_service(): def test_update_non_existent_service():
client = boto3.client("ecs", region_name="us-east-1") client = boto3.client("ecs", region_name="us-east-1")
try: try:
client.update_service( client.update_service(

View File

@ -1391,7 +1391,7 @@ def test_set_security_groups():
len(resp["LoadBalancers"][0]["SecurityGroups"]).should.equal(2) len(resp["LoadBalancers"][0]["SecurityGroups"]).should.equal(2)
with assert_raises(ClientError): with assert_raises(ClientError):
client.set_security_groups(LoadBalancerArn=arn, SecurityGroups=["non_existant"]) client.set_security_groups(LoadBalancerArn=arn, SecurityGroups=["non_existent"])
@mock_elbv2 @mock_elbv2

View File

@ -1,14 +1,18 @@
import random from moto.events.models import EventsBackend
import boto3
import json
import sure # noqa
from moto.events import mock_events from moto.events import mock_events
import json
import random
import unittest
import boto3
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from nose.tools import assert_raises from nose.tools import assert_raises
from moto.core import ACCOUNT_ID from moto.core import ACCOUNT_ID
from moto.events.models import EventsBackend
RULES = [ RULES = [
{"Name": "test1", "ScheduleExpression": "rate(5 minutes)"}, {"Name": "test1", "ScheduleExpression": "rate(5 minutes)"},
@ -456,6 +460,11 @@ def test_delete_event_bus_errors():
ClientError, "Cannot delete event bus default." ClientError, "Cannot delete event bus default."
) )
@mock_events @mock_events
def test_rule_tagging_happy(): def test_rule_tagging_happy():
client = generate_environment() client = generate_environment()
@ -466,7 +475,12 @@ def test_rule_tagging_happy():
client.tag_resource(ResourceARN=rule_arn, Tags=tags) client.tag_resource(ResourceARN=rule_arn, Tags=tags)
actual = client.list_tags_for_resource(ResourceARN=rule_arn).get("Tags") actual = client.list_tags_for_resource(ResourceARN=rule_arn).get("Tags")
assert tags == actual tc = unittest.TestCase("__init__")
expected = [{"Value": "value1", "Key": "key1"}, {"Value": "value2", "Key": "key2"}]
tc.assertTrue(
(expected[0] == actual[0] and expected[1] == actual[1])
or (expected[1] == actual[0] and expected[0] == actual[1])
)
client.untag_resource(ResourceARN=rule_arn, TagKeys=["key1"]) client.untag_resource(ResourceARN=rule_arn, TagKeys=["key1"])
@ -474,24 +488,25 @@ def test_rule_tagging_happy():
expected = [{"Key": "key2", "Value": "value2"}] expected = [{"Key": "key2", "Value": "value2"}]
assert expected == actual assert expected == actual
@mock_events @mock_events
def test_rule_tagging_sad(): def test_rule_tagging_sad():
b = EventsBackend("us-west-2") back_end = EventsBackend("us-west-2")
try: try:
b.tag_resource('unknown', []) back_end.tag_resource("unknown", [])
raise 'tag_resource should fail if ResourceARN is not known' raise Exception("tag_resource should fail if ResourceARN is not known")
except JsonRESTError: except JsonRESTError:
pass pass
try: try:
b.untag_resource('unknown', []) back_end.untag_resource("unknown", [])
raise 'untag_resource should fail if ResourceARN is not known' raise Exception("untag_resource should fail if ResourceARN is not known")
except JsonRESTError: except JsonRESTError:
pass pass
try: try:
b.list_tags_for_resource('unknown') back_end.list_tags_for_resource("unknown")
raise 'list_tags_for_resource should fail if ResourceARN is not known' raise Exception("list_tags_for_resource should fail if ResourceARN is not known")
except JsonRESTError: except JsonRESTError:
pass pass

View File

@ -132,7 +132,7 @@ def test_get_table_versions():
helpers.update_table(client, database_name, table_name, table_input) helpers.update_table(client, database_name, table_name, table_input)
version_inputs["2"] = table_input version_inputs["2"] = table_input
# Updateing with an indentical input should still create a new version # Updating with an identical input should still create a new version
helpers.update_table(client, database_name, table_name, table_input) helpers.update_table(client, database_name, table_name, table_input)
version_inputs["3"] = table_input version_inputs["3"] = table_input

View File

@ -785,7 +785,7 @@ def test_delete_login_profile():
conn.delete_login_profile("my-user") conn.delete_login_profile("my-user")
@mock_iam() @mock_iam
def test_create_access_key(): def test_create_access_key():
conn = boto3.client("iam", region_name="us-east-1") conn = boto3.client("iam", region_name="us-east-1")
with assert_raises(ClientError): with assert_raises(ClientError):
@ -798,6 +798,19 @@ def test_create_access_key():
access_key["AccessKeyId"].should.have.length_of(20) access_key["AccessKeyId"].should.have.length_of(20)
access_key["SecretAccessKey"].should.have.length_of(40) access_key["SecretAccessKey"].should.have.length_of(40)
assert access_key["AccessKeyId"].startswith("AKIA") assert access_key["AccessKeyId"].startswith("AKIA")
conn = boto3.client(
"iam",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
)
access_key = conn.create_access_key()["AccessKey"]
(
datetime.utcnow() - access_key["CreateDate"].replace(tzinfo=None)
).seconds.should.be.within(0, 10)
access_key["AccessKeyId"].should.have.length_of(20)
access_key["SecretAccessKey"].should.have.length_of(40)
assert access_key["AccessKeyId"].startswith("AKIA")
@mock_iam_deprecated() @mock_iam_deprecated()
@ -825,8 +838,35 @@ def test_get_all_access_keys():
) )
@mock_iam
def test_list_access_keys():
conn = boto3.client("iam", region_name="us-east-1")
conn.create_user(UserName="my-user")
response = conn.list_access_keys(UserName="my-user")
assert_equals(
response["AccessKeyMetadata"], [],
)
access_key = conn.create_access_key(UserName="my-user")["AccessKey"]
response = conn.list_access_keys(UserName="my-user")
assert_equals(
sorted(response["AccessKeyMetadata"][0].keys()),
sorted(["Status", "CreateDate", "UserName", "AccessKeyId"]),
)
conn = boto3.client(
"iam",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
)
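# When UserName is omitted, IAM resolves the user from the access key used to
# sign the request, so the key's owner should still see it listed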
response = conn.list_access_keys()
assert_equals(
sorted(response["AccessKeyMetadata"][0].keys()),
sorted(["Status", "CreateDate", "UserName", "AccessKeyId"]),
)
@mock_iam_deprecated() @mock_iam_deprecated()
def test_delete_access_key(): def test_delete_access_key_deprecated():
conn = boto.connect_iam() conn = boto.connect_iam()
conn.create_user("my-user") conn.create_user("my-user")
access_key_id = conn.create_access_key("my-user")["create_access_key_response"][ access_key_id = conn.create_access_key("my-user")["create_access_key_response"][
@ -835,6 +875,16 @@ def test_delete_access_key():
conn.delete_access_key(access_key_id, "my-user") conn.delete_access_key(access_key_id, "my-user")
@mock_iam
def test_delete_access_key():
conn = boto3.client("iam", region_name="us-east-1")
conn.create_user(UserName="my-user")
key = conn.create_access_key(UserName="my-user")["AccessKey"]
conn.delete_access_key(AccessKeyId=key["AccessKeyId"], UserName="my-user")
key = conn.create_access_key(UserName="my-user")["AccessKey"]
conn.delete_access_key(AccessKeyId=key["AccessKeyId"])
@mock_iam() @mock_iam()
def test_mfa_devices(): def test_mfa_devices():
# Test enable device # Test enable device
@ -1326,6 +1376,9 @@ def test_update_access_key():
) )
resp = client.list_access_keys(UserName=username) resp = client.list_access_keys(UserName=username)
resp["AccessKeyMetadata"][0]["Status"].should.equal("Inactive") resp["AccessKeyMetadata"][0]["Status"].should.equal("Inactive")
client.update_access_key(AccessKeyId=key["AccessKeyId"], Status="Active")
resp = client.list_access_keys(UserName=username)
resp["AccessKeyMetadata"][0]["Status"].should.equal("Active")
@mock_iam @mock_iam

View File

@ -9,6 +9,173 @@ from botocore.exceptions import ClientError
from nose.tools import assert_raises from nose.tools import assert_raises
@mock_iot
def test_attach_policy():
client = boto3.client("iot", region_name="ap-northeast-1")
policy_name = "my-policy"
doc = "{}"
cert = client.create_keys_and_certificate(setAsActive=True)
cert_arn = cert["certificateArn"]
client.create_policy(policyName=policy_name, policyDocument=doc)
client.attach_policy(policyName=policy_name, target=cert_arn)
res = client.list_attached_policies(target=cert_arn)
res.should.have.key("policies").which.should.have.length_of(1)
res["policies"][0]["policyName"].should.equal("my-policy")
@mock_iot
def test_detach_policy():
client = boto3.client("iot", region_name="ap-northeast-1")
policy_name = "my-policy"
doc = "{}"
cert = client.create_keys_and_certificate(setAsActive=True)
cert_arn = cert["certificateArn"]
client.create_policy(policyName=policy_name, policyDocument=doc)
client.attach_policy(policyName=policy_name, target=cert_arn)
res = client.list_attached_policies(target=cert_arn)
res.should.have.key("policies").which.should.have.length_of(1)
res["policies"][0]["policyName"].should.equal("my-policy")
client.detach_policy(policyName=policy_name, target=cert_arn)
res = client.list_attached_policies(target=cert_arn)
res.should.have.key("policies").which.should.be.empty
@mock_iot
def test_list_attached_policies():
client = boto3.client("iot", region_name="ap-northeast-1")
cert = client.create_keys_and_certificate(setAsActive=True)
policies = client.list_attached_policies(target=cert["certificateArn"])
policies["policies"].should.be.empty
@mock_iot
def test_policy_versions():
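# Walk the policy-version lifecycle: create extra versions, switch the default
# version, and delete non-default versions (the default itself requires delete_policy)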
client = boto3.client("iot", region_name="ap-northeast-1")
policy_name = "my-policy"
doc = "{}"
policy = client.create_policy(policyName=policy_name, policyDocument=doc)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(json.dumps({}))
policy.should.have.key("policyVersionId").which.should.equal("1")
policy = client.get_policy(policyName=policy_name)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(json.dumps({}))
policy.should.have.key("defaultVersionId").which.should.equal(
policy["defaultVersionId"]
)
policy1 = client.create_policy_version(
policyName=policy_name,
policyDocument=json.dumps({"version": "version_1"}),
setAsDefault=True,
)
policy1.should.have.key("policyArn").which.should_not.be.none
policy1.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_1"})
)
policy1.should.have.key("policyVersionId").which.should.equal("2")
policy1.should.have.key("isDefaultVersion").which.should.equal(True)
policy2 = client.create_policy_version(
policyName=policy_name,
policyDocument=json.dumps({"version": "version_2"}),
setAsDefault=False,
)
policy2.should.have.key("policyArn").which.should_not.be.none
policy2.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_2"})
)
policy2.should.have.key("policyVersionId").which.should.equal("3")
policy2.should.have.key("isDefaultVersion").which.should.equal(False)
policy = client.get_policy(policyName=policy_name)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_1"})
)
policy.should.have.key("defaultVersionId").which.should.equal(
policy1["policyVersionId"]
)
policy_versions = client.list_policy_versions(policyName=policy_name)
policy_versions.should.have.key("policyVersions").which.should.have.length_of(3)
list(
map(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])
).count(True).should.equal(1)
default_policy = list(
filter(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])
)
default_policy[0].should.have.key("versionId").should.equal(
policy1["policyVersionId"]
)
policy = client.get_policy(policyName=policy_name)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_1"})
)
policy.should.have.key("defaultVersionId").which.should.equal(
policy1["policyVersionId"]
)
client.set_default_policy_version(
policyName=policy_name, policyVersionId=policy2["policyVersionId"]
)
policy_versions = client.list_policy_versions(policyName=policy_name)
policy_versions.should.have.key("policyVersions").which.should.have.length_of(3)
list(
map(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])
).count(True).should.equal(1)
default_policy = list(
filter(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])
)
default_policy[0].should.have.key("versionId").should.equal(
policy2["policyVersionId"]
)
policy = client.get_policy(policyName=policy_name)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_2"})
)
policy.should.have.key("defaultVersionId").which.should.equal(
policy2["policyVersionId"]
)
client.delete_policy_version(policyName=policy_name, policyVersionId="1")
policy_versions = client.list_policy_versions(policyName=policy_name)
policy_versions.should.have.key("policyVersions").which.should.have.length_of(2)
client.delete_policy_version(
policyName=policy_name, policyVersionId=policy1["policyVersionId"]
)
policy_versions = client.list_policy_versions(policyName=policy_name)
policy_versions.should.have.key("policyVersions").which.should.have.length_of(1)
# should fail as it"s the default policy. Should use delete_policy instead
try:
client.delete_policy_version(
policyName=policy_name, policyVersionId=policy2["policyVersionId"]
)
assert False, "Should have failed in previous call"
except Exception as exception:
exception.response["Error"]["Message"].should.equal(
"Cannot delete the default version of a policy"
)
@mock_iot @mock_iot
def test_things(): def test_things():
client = boto3.client("iot", region_name="ap-northeast-1") client = boto3.client("iot", region_name="ap-northeast-1")
@ -994,7 +1161,10 @@ def test_create_job():
client = boto3.client("iot", region_name="eu-west-1") client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing" name = "my-thing"
job_id = "TestJob" job_id = "TestJob"
# thing # thing
# job document
# job_document = {
#     "field": "value"
# }
thing = client.create_thing(thingName=name) thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name) thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn") thing.should.have.key("thingArn")
@ -1020,6 +1190,63 @@ def test_create_job():
job.should.have.key("description") job.should.have.key("description")
@mock_iot
def test_list_jobs():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
# job document
# job_document = {
#     "field": "value"
# }
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job1 = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job1.should.have.key("jobId").which.should.equal(job_id)
job1.should.have.key("jobArn")
job1.should.have.key("description")
job2 = client.create_job(
jobId=job_id + "1",
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job2.should.have.key("jobId").which.should.equal(job_id + "1")
job2.should.have.key("jobArn")
job2.should.have.key("description")
jobs = client.list_jobs()
jobs.should.have.key("jobs")
jobs.should_not.have.key("nextToken")
jobs["jobs"][0].should.have.key("jobId").which.should.equal(job_id)
jobs["jobs"][1].should.have.key("jobId").which.should.equal(job_id + "1")
@mock_iot @mock_iot
def test_describe_job(): def test_describe_job():
client = boto3.client("iot", region_name="eu-west-1") client = boto3.client("iot", region_name="eu-west-1")
@ -1124,3 +1351,387 @@ def test_describe_job_1():
job.should.have.key("job").which.should.have.key( job.should.have.key("job").which.should.have.key(
"jobExecutionsRolloutConfig" "jobExecutionsRolloutConfig"
).which.should.have.key("maximumPerMinute").which.should.equal(10) ).which.should.have.key("maximumPerMinute").which.should.equal(10)
@mock_iot
def test_delete_job():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job = client.describe_job(jobId=job_id)
job.should.have.key("job")
job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id)
client.delete_job(jobId=job_id)
client.list_jobs()["jobs"].should.have.length_of(0)
@mock_iot
def test_cancel_job():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job = client.describe_job(jobId=job_id)
job.should.have.key("job")
job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id)
job = client.cancel_job(jobId=job_id, reasonCode="Because", comment="You are")
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job = client.describe_job(jobId=job_id)
job.should.have.key("job")
job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("job").which.should.have.key("status").which.should.equal(
"CANCELED"
)
job.should.have.key("job").which.should.have.key(
"forceCanceled"
).which.should.equal(False)
job.should.have.key("job").which.should.have.key("reasonCode").which.should.equal(
"Because"
)
job.should.have.key("job").which.should.have.key("comment").which.should.equal(
"You are"
)
@mock_iot
def test_get_job_document_with_document_source():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job_document = client.get_job_document(jobId=job_id)
job_document.should.have.key("document").which.should.equal("")
@mock_iot
def test_get_job_document_with_document():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job_document = client.get_job_document(jobId=job_id)
job_document.should.have.key("document").which.should.equal('{"field": "value"}')
@mock_iot
def test_describe_job_execution():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
job_execution = client.describe_job_execution(jobId=job_id, thingName=name)
job_execution.should.have.key("execution")
job_execution["execution"].should.have.key("jobId").which.should.equal(job_id)
job_execution["execution"].should.have.key("status").which.should.equal("QUEUED")
job_execution["execution"].should.have.key("forceCanceled").which.should.equal(
False
)
job_execution["execution"].should.have.key("statusDetails").which.should.equal(
{"detailsMap": {}}
)
job_execution["execution"].should.have.key("thingArn").which.should.equal(
thing["thingArn"]
)
job_execution["execution"].should.have.key("queuedAt")
job_execution["execution"].should.have.key("startedAt")
job_execution["execution"].should.have.key("lastUpdatedAt")
job_execution["execution"].should.have.key("executionNumber").which.should.equal(
123
)
job_execution["execution"].should.have.key("versionNumber").which.should.equal(123)
job_execution["execution"].should.have.key(
"approximateSecondsBeforeTimedOut"
).which.should.equal(123)
job_execution = client.describe_job_execution(
jobId=job_id, thingName=name, executionNumber=123
)
job_execution.should.have.key("execution")
job_execution["execution"].should.have.key("jobId").which.should.equal(job_id)
job_execution["execution"].should.have.key("status").which.should.equal("QUEUED")
job_execution["execution"].should.have.key("forceCanceled").which.should.equal(
False
)
job_execution["execution"].should.have.key("statusDetails").which.should.equal(
{"detailsMap": {}}
)
job_execution["execution"].should.have.key("thingArn").which.should.equal(
thing["thingArn"]
)
job_execution["execution"].should.have.key("queuedAt")
job_execution["execution"].should.have.key("startedAt")
job_execution["execution"].should.have.key("lastUpdatedAt")
job_execution["execution"].should.have.key("executionNumber").which.should.equal(
123
)
job_execution["execution"].should.have.key("versionNumber").which.should.equal(123)
job_execution["execution"].should.have.key(
"approximateSecondsBeforeTimedOut"
).which.should.equal(123)
try:
client.describe_job_execution(jobId=job_id, thingName=name, executionNumber=456)
except ClientError as exc:
error_code = exc.response["Error"]["Code"]
error_code.should.equal("ResourceNotFoundException")
else:
raise Exception("Should have raised error")
@mock_iot
def test_cancel_job_execution():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
client.cancel_job_execution(jobId=job_id, thingName=name)
job_execution = client.describe_job_execution(jobId=job_id, thingName=name)
job_execution.should.have.key("execution")
job_execution["execution"].should.have.key("status").which.should.equal("CANCELED")
@mock_iot
def test_delete_job_execution():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
client.delete_job_execution(jobId=job_id, thingName=name, executionNumber=123)
try:
client.describe_job_execution(jobId=job_id, thingName=name, executionNumber=123)
except ClientError as exc:
error_code = exc.response["Error"]["Code"]
error_code.should.equal("ResourceNotFoundException")
else:
raise Exception("Should have raised error")
@mock_iot
def test_list_job_executions_for_job():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
job_execution = client.list_job_executions_for_job(jobId=job_id)
job_execution.should.have.key("executionSummaries")
job_execution["executionSummaries"][0].should.have.key(
"thingArn"
).which.should.equal(thing["thingArn"])
@mock_iot
def test_list_job_executions_for_thing():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
job_execution = client.list_job_executions_for_thing(thingName=name)
job_execution.should.have.key("executionSummaries")
job_execution["executionSummaries"][0].should.have.key("jobId").which.should.equal(
job_id
)

View File

@ -223,7 +223,7 @@ def test_create_stream_without_redshift():
@mock_kinesis @mock_kinesis
def test_deescribe_non_existant_stream(): def test_describe_non_existent_stream():
client = boto3.client("firehose", region_name="us-east-1") client = boto3.client("firehose", region_name="us-east-1")
client.describe_delivery_stream.when.called_with( client.describe_delivery_stream.when.called_with(

View File

@ -32,7 +32,7 @@ def test_create_cluster():
@mock_kinesis_deprecated @mock_kinesis_deprecated
def test_describe_non_existant_stream(): def test_describe_non_existent_stream():
conn = boto.kinesis.connect_to_region("us-east-1") conn = boto.kinesis.connect_to_region("us-east-1")
conn.describe_stream.when.called_with("not-a-stream").should.throw( conn.describe_stream.when.called_with("not-a-stream").should.throw(
ResourceNotFoundException ResourceNotFoundException

View File

@ -1,26 +1,19 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from datetime import date
from datetime import datetime
from dateutil.tz import tzutc
import base64 import base64
import os
import re import re
import boto3
import boto.kms import boto.kms
import botocore.exceptions
import six import six
import sure # noqa import sure # noqa
from boto.exception import JSONResponseError from boto.exception import JSONResponseError
from boto.kms.exceptions import AlreadyExistsException, NotFoundException from boto.kms.exceptions import AlreadyExistsException, NotFoundException
from freezegun import freeze_time
from nose.tools import assert_raises from nose.tools import assert_raises
from parameterized import parameterized from parameterized import parameterized
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from moto.kms.models import KmsBackend from moto.kms.models import KmsBackend
from moto.kms.exceptions import NotFoundException as MotoNotFoundException from moto.kms.exceptions import NotFoundException as MotoNotFoundException
from moto import mock_kms, mock_kms_deprecated from moto import mock_kms_deprecated
PLAINTEXT_VECTORS = ( PLAINTEXT_VECTORS = (
(b"some encodeable plaintext",), (b"some encodeable plaintext",),
@ -36,23 +29,6 @@ def _get_encoded_value(plaintext):
return plaintext.encode("utf-8") return plaintext.encode("utf-8")
@mock_kms
def test_create_key():
conn = boto3.client("kms", region_name="us-east-1")
with freeze_time("2015-01-01 00:00:00"):
key = conn.create_key(
Policy="my policy",
Description="my key",
KeyUsage="ENCRYPT_DECRYPT",
Tags=[{"TagKey": "project", "TagValue": "moto"}],
)
key["KeyMetadata"]["Description"].should.equal("my key")
key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT")
key["KeyMetadata"]["Enabled"].should.equal(True)
key["KeyMetadata"]["CreationDate"].should.be.a(date)
@mock_kms_deprecated @mock_kms_deprecated
def test_describe_key(): def test_describe_key():
conn = boto.kms.connect_to_region("us-west-2") conn = boto.kms.connect_to_region("us-west-2")
@ -97,22 +73,6 @@ def test_describe_key_via_alias_not_found():
) )
@parameterized(
(
("alias/does-not-exist",),
("arn:aws:kms:us-east-1:012345678912:alias/does-not-exist",),
("invalid",),
)
)
@mock_kms
def test_describe_key_via_alias_invalid_alias(key_id):
client = boto3.client("kms", region_name="us-east-1")
client.create_key(Description="key")
with assert_raises(client.exceptions.NotFoundException):
client.describe_key(KeyId=key_id)
@mock_kms_deprecated @mock_kms_deprecated
def test_describe_key_via_arn(): def test_describe_key_via_arn():
conn = boto.kms.connect_to_region("us-west-2") conn = boto.kms.connect_to_region("us-west-2")
@ -240,71 +200,6 @@ def test_generate_data_key():
response["KeyId"].should.equal(key_arn) response["KeyId"].should.equal(key_arn)
@mock_kms
def test_boto3_generate_data_key():
kms = boto3.client("kms", region_name="us-west-2")
key = kms.create_key()
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
response = kms.generate_data_key(KeyId=key_id, NumberOfBytes=32)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["CiphertextBlob"], validate=True)
# Plaintext must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["Plaintext"], validate=True)
response["KeyId"].should.equal(key_arn)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_encrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
response["CiphertextBlob"].should_not.equal(plaintext)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["CiphertextBlob"], validate=True)
response["KeyId"].should.equal(key_arn)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_decrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
encrypt_response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
client.create_key(Description="key")
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(encrypt_response["CiphertextBlob"], validate=True)
decrypt_response = client.decrypt(CiphertextBlob=encrypt_response["CiphertextBlob"])
# Plaintext must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(decrypt_response["Plaintext"], validate=True)
decrypt_response["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response["KeyId"].should.equal(key_arn)
@mock_kms_deprecated @mock_kms_deprecated
def test_disable_key_rotation_with_missing_key(): def test_disable_key_rotation_with_missing_key():
conn = boto.kms.connect_to_region("us-west-2") conn = boto.kms.connect_to_region("us-west-2")
@ -775,25 +670,6 @@ def test__list_aliases():
len(aliases).should.equal(7) len(aliases).should.equal(7)
@parameterized(
(
("not-a-uuid",),
("alias/DoesNotExist",),
("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
(
"arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",
),
)
)
@mock_kms
def test_invalid_key_ids(key_id):
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId=key_id, NumberOfBytes=5)
@mock_kms_deprecated @mock_kms_deprecated
def test__assert_default_policy(): def test__assert_default_policy():
from moto.kms.responses import _assert_default_policy from moto.kms.responses import _assert_default_policy
@ -804,431 +680,3 @@ def test__assert_default_policy():
_assert_default_policy.when.called_with("default").should_not.throw( _assert_default_policy.when.called_with("default").should_not.throw(
MotoNotFoundException MotoNotFoundException
) )
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_kms_encrypt_boto3(plaintext):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="key")
response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext)
response = client.decrypt(CiphertextBlob=response["CiphertextBlob"])
response["Plaintext"].should.equal(_get_encoded_value(plaintext))
@mock_kms
def test_disable_key():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="disable-key")
client.disable_key(KeyId=key["KeyMetadata"]["KeyId"])
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "Disabled"
@mock_kms
def test_enable_key():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="enable-key")
client.disable_key(KeyId=key["KeyMetadata"]["KeyId"])
client.enable_key(KeyId=key["KeyMetadata"]["KeyId"])
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == True
assert result["KeyMetadata"]["KeyState"] == "Enabled"
@mock_kms
def test_schedule_key_deletion():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="schedule-key-deletion")
if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false":
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
assert response["DeletionDate"] == datetime(
2015, 1, 31, 12, 0, tzinfo=tzutc()
)
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "PendingDeletion"
assert "DeletionDate" in result["KeyMetadata"]
@mock_kms
def test_schedule_key_deletion_custom():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="schedule-key-deletion")
if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false":
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(
KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7
)
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
assert response["DeletionDate"] == datetime(
2015, 1, 8, 12, 0, tzinfo=tzutc()
)
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(
KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7
)
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "PendingDeletion"
assert "DeletionDate" in result["KeyMetadata"]
@mock_kms
def test_cancel_key_deletion():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="cancel-key-deletion")
client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
response = client.cancel_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "Disabled"
assert "DeletionDate" not in result["KeyMetadata"]
@mock_kms
def test_update_key_description():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="old_description")
key_id = key["KeyMetadata"]["KeyId"]
result = client.update_key_description(KeyId=key_id, Description="new_description")
assert "ResponseMetadata" in result
@mock_kms
def test_key_tagging_happy():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="test-key-tagging")
key_id = key["KeyMetadata"]["KeyId"]
tags = [{"TagKey": "key1", "TagValue": "value1"}, {"TagKey": "key2", "TagValue": "value2"}]
client.tag_resource(KeyId=key_id, Tags=tags)
result = client.list_resource_tags(KeyId=key_id)
actual = result.get("Tags", [])
assert tags == actual
client.untag_resource(KeyId=key_id, TagKeys=["key1"])
actual = client.list_resource_tags(KeyId=key_id).get("Tags", [])
expected = [{"TagKey": "key2", "TagValue": "value2"}]
assert expected == actual
@mock_kms
def test_key_tagging_sad():
b = KmsBackend()
    try:
        b.tag_resource("unknown", [])
        raise AssertionError("tag_resource should fail if KeyId is not known")
    except JsonRESTError:
        pass
    try:
        b.untag_resource("unknown", [])
        raise AssertionError("untag_resource should fail if KeyId is not known")
    except JsonRESTError:
        pass
    try:
        b.list_resource_tags("unknown")
        raise AssertionError("list_resource_tags should fail if KeyId is not known")
    except JsonRESTError:
        pass
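For reference, the same negative checks could be expressed with the assert_raises context manager used elsewhere in these tests, which avoids the manual try/except scaffolding. This reformulation is only a sketch and assumes assert_raises (from nose.tools) is imported in this module.

# Illustrative alternative to test_key_tagging_sad, not part of the original file.
def _unknown_key_id_is_rejected():
    b = KmsBackend()
    with assert_raises(JsonRESTError):
        b.tag_resource("unknown", [])
    with assert_raises(JsonRESTError):
        b.untag_resource("unknown", [])
    with assert_raises(JsonRESTError):
        b.list_resource_tags("unknown")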
@parameterized(
(
(dict(KeySpec="AES_256"), 32),
(dict(KeySpec="AES_128"), 16),
(dict(NumberOfBytes=64), 64),
(dict(NumberOfBytes=1), 1),
(dict(NumberOfBytes=1024), 1024),
)
)
@mock_kms
def test_generate_data_key_sizes(kwargs, expected_key_length):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size")
response = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs)
assert len(response["Plaintext"]) == expected_key_length
@mock_kms
def test_generate_data_key_decrypt():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-decrypt")
resp1 = client.generate_data_key(
KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256"
)
resp2 = client.decrypt(CiphertextBlob=resp1["CiphertextBlob"])
assert resp1["Plaintext"] == resp2["Plaintext"]
@parameterized(
(
(dict(KeySpec="AES_257"),),
(dict(KeySpec="AES_128", NumberOfBytes=16),),
(dict(NumberOfBytes=2048),),
(dict(NumberOfBytes=0),),
(dict(),),
)
)
@mock_kms
def test_generate_data_key_invalid_size_params(kwargs):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size")
with assert_raises(
(botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError)
) as err:
client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs)
@parameterized(
(
("alias/DoesNotExist",),
("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
(
"arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",
),
)
)
@mock_kms
def test_generate_data_key_invalid_key(key_id):
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId=key_id, KeySpec="AES_256")
@parameterized(
(
("alias/DoesExist", False),
("arn:aws:kms:us-east-1:012345678912:alias/DoesExist", False),
("", True),
("arn:aws:kms:us-east-1:012345678912:key/", True),
)
)
@mock_kms
def test_generate_data_key_all_valid_key_ids(prefix, append_key_id):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key()
key_id = key["KeyMetadata"]["KeyId"]
client.create_alias(AliasName="alias/DoesExist", TargetKeyId=key_id)
target_id = prefix
if append_key_id:
target_id += key_id
    client.generate_data_key(KeyId=target_id, NumberOfBytes=32)
@mock_kms
def test_generate_data_key_without_plaintext_decrypt():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-decrypt")
resp1 = client.generate_data_key_without_plaintext(
KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256"
)
assert "Plaintext" not in resp1
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_re_encrypt_decrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key_1 = client.create_key(Description="key 1")
key_1_id = key_1["KeyMetadata"]["KeyId"]
key_1_arn = key_1["KeyMetadata"]["Arn"]
key_2 = client.create_key(Description="key 2")
key_2_id = key_2["KeyMetadata"]["KeyId"]
key_2_arn = key_2["KeyMetadata"]["Arn"]
encrypt_response = client.encrypt(
KeyId=key_1_id, Plaintext=plaintext, EncryptionContext={"encryption": "context"}
)
re_encrypt_response = client.re_encrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
SourceEncryptionContext={"encryption": "context"},
DestinationKeyId=key_2_id,
DestinationEncryptionContext={"another": "context"},
)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(re_encrypt_response["CiphertextBlob"], validate=True)
re_encrypt_response["SourceKeyId"].should.equal(key_1_arn)
re_encrypt_response["KeyId"].should.equal(key_2_arn)
decrypt_response_1 = client.decrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
EncryptionContext={"encryption": "context"},
)
decrypt_response_1["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_1["KeyId"].should.equal(key_1_arn)
decrypt_response_2 = client.decrypt(
CiphertextBlob=re_encrypt_response["CiphertextBlob"],
EncryptionContext={"another": "context"},
)
decrypt_response_2["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_2["KeyId"].should.equal(key_2_arn)
decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"])
@mock_kms
def test_re_encrypt_to_invalid_destination():
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key 1")
key_id = key["KeyMetadata"]["KeyId"]
encrypt_response = client.encrypt(KeyId=key_id, Plaintext=b"some plaintext")
with assert_raises(client.exceptions.NotFoundException):
client.re_encrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
DestinationKeyId="alias/DoesNotExist",
)
@parameterized(((12,), (44,), (91,), (1,), (1024,)))
@mock_kms
def test_generate_random(number_of_bytes):
client = boto3.client("kms", region_name="us-west-2")
response = client.generate_random(NumberOfBytes=number_of_bytes)
response["Plaintext"].should.be.a(bytes)
len(response["Plaintext"]).should.equal(number_of_bytes)
@parameterized(
(
(2048, botocore.exceptions.ClientError),
(1025, botocore.exceptions.ClientError),
(0, botocore.exceptions.ParamValidationError),
(-1, botocore.exceptions.ParamValidationError),
(-1024, botocore.exceptions.ParamValidationError),
)
)
@mock_kms
def test_generate_random_invalid_number_of_bytes(number_of_bytes, error_type):
client = boto3.client("kms", region_name="us-west-2")
with assert_raises(error_type):
client.generate_random(NumberOfBytes=number_of_bytes)
@mock_kms
def test_enable_key_rotation_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.enable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_disable_key_rotation_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.disable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_enable_key_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.enable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_disable_key_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.disable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_cancel_key_deletion_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.cancel_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_schedule_key_deletion_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.schedule_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_get_key_rotation_status_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.get_key_rotation_status(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_get_key_policy_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.get_key_policy(
KeyId="12366f9b-1230-123d-123e-123e6ae60c02", PolicyName="default"
)
@mock_kms
def test_list_key_policies_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.list_key_policies(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_put_key_policy_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.put_key_policy(
KeyId="00000000-0000-0000-0000-000000000000",
PolicyName="default",
Policy="new policy",
)

View File

@@ -0,0 +1,638 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from datetime import datetime
from dateutil.tz import tzutc
import base64
import os
import boto3
import botocore.exceptions
import six
import sure # noqa
from freezegun import freeze_time
from nose.tools import assert_raises
from parameterized import parameterized
from moto import mock_kms
PLAINTEXT_VECTORS = (
(b"some encodeable plaintext",),
(b"some unencodeable plaintext \xec\x8a\xcf\xb6r\xe9\xb5\xeb\xff\xa23\x16",),
("some unicode characters ø˚∆øˆˆ∆ßçøˆˆçßøˆ¨¥",),
)
def _get_encoded_value(plaintext):
    # KMS returns Plaintext as raw bytes, so encode unicode test vectors
    # before comparing them with encrypt/decrypt responses.
    if isinstance(plaintext, six.binary_type):
        return plaintext
    return plaintext.encode("utf-8")
@mock_kms
def test_create_key():
conn = boto3.client("kms", region_name="us-east-1")
key = conn.create_key(
Policy="my policy",
Description="my key",
KeyUsage="ENCRYPT_DECRYPT",
Tags=[{"TagKey": "project", "TagValue": "moto"}],
)
key["KeyMetadata"]["Arn"].should.equal(
"arn:aws:kms:us-east-1:123456789012:key/{}".format(key["KeyMetadata"]["KeyId"])
)
key["KeyMetadata"]["AWSAccountId"].should.equal("123456789012")
key["KeyMetadata"]["CreationDate"].should.be.a(datetime)
key["KeyMetadata"]["CustomerMasterKeySpec"].should.equal("SYMMETRIC_DEFAULT")
key["KeyMetadata"]["Description"].should.equal("my key")
key["KeyMetadata"]["Enabled"].should.be.ok
key["KeyMetadata"]["EncryptionAlgorithms"].should.equal(["SYMMETRIC_DEFAULT"])
key["KeyMetadata"]["KeyId"].should_not.be.empty
key["KeyMetadata"]["KeyManager"].should.equal("CUSTOMER")
key["KeyMetadata"]["KeyState"].should.equal("Enabled")
key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT")
key["KeyMetadata"]["Origin"].should.equal("AWS_KMS")
key["KeyMetadata"].should_not.have.key("SigningAlgorithms")
key = conn.create_key(KeyUsage="ENCRYPT_DECRYPT", CustomerMasterKeySpec="RSA_2048",)
sorted(key["KeyMetadata"]["EncryptionAlgorithms"]).should.equal(
["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"]
)
key["KeyMetadata"].should_not.have.key("SigningAlgorithms")
key = conn.create_key(KeyUsage="SIGN_VERIFY", CustomerMasterKeySpec="RSA_2048",)
key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms")
sorted(key["KeyMetadata"]["SigningAlgorithms"]).should.equal(
[
"RSASSA_PKCS1_V1_5_SHA_256",
"RSASSA_PKCS1_V1_5_SHA_384",
"RSASSA_PKCS1_V1_5_SHA_512",
"RSASSA_PSS_SHA_256",
"RSASSA_PSS_SHA_384",
"RSASSA_PSS_SHA_512",
]
)
key = conn.create_key(
KeyUsage="SIGN_VERIFY", CustomerMasterKeySpec="ECC_SECG_P256K1",
)
key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms")
key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_256"])
key = conn.create_key(
KeyUsage="SIGN_VERIFY", CustomerMasterKeySpec="ECC_NIST_P384",
)
key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms")
key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_384"])
key = conn.create_key(
KeyUsage="SIGN_VERIFY", CustomerMasterKeySpec="ECC_NIST_P521",
)
key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms")
key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_512"])
@mock_kms
def test_describe_key():
client = boto3.client("kms", region_name="us-east-1")
response = client.create_key(Description="my key", KeyUsage="ENCRYPT_DECRYPT",)
key_id = response["KeyMetadata"]["KeyId"]
response = client.describe_key(KeyId=key_id)
response["KeyMetadata"]["AWSAccountId"].should.equal("123456789012")
response["KeyMetadata"]["CreationDate"].should.be.a(datetime)
response["KeyMetadata"]["CustomerMasterKeySpec"].should.equal("SYMMETRIC_DEFAULT")
response["KeyMetadata"]["Description"].should.equal("my key")
response["KeyMetadata"]["Enabled"].should.be.ok
response["KeyMetadata"]["EncryptionAlgorithms"].should.equal(["SYMMETRIC_DEFAULT"])
response["KeyMetadata"]["KeyId"].should_not.be.empty
response["KeyMetadata"]["KeyManager"].should.equal("CUSTOMER")
response["KeyMetadata"]["KeyState"].should.equal("Enabled")
response["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT")
response["KeyMetadata"]["Origin"].should.equal("AWS_KMS")
response["KeyMetadata"].should_not.have.key("SigningAlgorithms")
@parameterized(
(
("alias/does-not-exist",),
("arn:aws:kms:us-east-1:012345678912:alias/does-not-exist",),
("invalid",),
)
)
@mock_kms
def test_describe_key_via_alias_invalid_alias(key_id):
client = boto3.client("kms", region_name="us-east-1")
client.create_key(Description="key")
with assert_raises(client.exceptions.NotFoundException):
client.describe_key(KeyId=key_id)
@mock_kms
def test_generate_data_key():
kms = boto3.client("kms", region_name="us-west-2")
key = kms.create_key()
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
response = kms.generate_data_key(KeyId=key_id, NumberOfBytes=32)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["CiphertextBlob"], validate=True)
# Plaintext must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["Plaintext"], validate=True)
response["KeyId"].should.equal(key_arn)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_encrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
response["CiphertextBlob"].should_not.equal(plaintext)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["CiphertextBlob"], validate=True)
response["KeyId"].should.equal(key_arn)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_decrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
encrypt_response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
client.create_key(Description="key")
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(encrypt_response["CiphertextBlob"], validate=True)
decrypt_response = client.decrypt(CiphertextBlob=encrypt_response["CiphertextBlob"])
# Plaintext must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(decrypt_response["Plaintext"], validate=True)
decrypt_response["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response["KeyId"].should.equal(key_arn)
@parameterized(
(
("not-a-uuid",),
("alias/DoesNotExist",),
("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
(
"arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",
),
)
)
@mock_kms
def test_invalid_key_ids(key_id):
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId=key_id, NumberOfBytes=5)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_kms_encrypt(plaintext):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="key")
response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext)
response = client.decrypt(CiphertextBlob=response["CiphertextBlob"])
response["Plaintext"].should.equal(_get_encoded_value(plaintext))
@mock_kms
def test_disable_key():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="disable-key")
client.disable_key(KeyId=key["KeyMetadata"]["KeyId"])
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "Disabled"
@mock_kms
def test_enable_key():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="enable-key")
client.disable_key(KeyId=key["KeyMetadata"]["KeyId"])
client.enable_key(KeyId=key["KeyMetadata"]["KeyId"])
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == True
assert result["KeyMetadata"]["KeyState"] == "Enabled"
@mock_kms
def test_schedule_key_deletion():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="schedule-key-deletion")
if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false":
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
assert response["DeletionDate"] == datetime(
2015, 1, 31, 12, 0, tzinfo=tzutc()
)
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "PendingDeletion"
assert "DeletionDate" in result["KeyMetadata"]
@mock_kms
def test_schedule_key_deletion_custom():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="schedule-key-deletion")
if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false":
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(
KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7
)
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
assert response["DeletionDate"] == datetime(
2015, 1, 8, 12, 0, tzinfo=tzutc()
)
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(
KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7
)
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "PendingDeletion"
assert "DeletionDate" in result["KeyMetadata"]
@mock_kms
def test_cancel_key_deletion():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="cancel-key-deletion")
client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
response = client.cancel_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "Disabled"
assert "DeletionDate" not in result["KeyMetadata"]
@mock_kms
def test_update_key_description():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="old_description")
key_id = key["KeyMetadata"]["KeyId"]
result = client.update_key_description(KeyId=key_id, Description="new_description")
assert "ResponseMetadata" in result
@mock_kms
def test_tag_resource():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="cancel-key-deletion")
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
keyid = response["KeyId"]
response = client.tag_resource(
KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]
)
# Shouldn't have any data, just header
assert len(response.keys()) == 1
@mock_kms
def test_list_resource_tags():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="cancel-key-deletion")
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
keyid = response["KeyId"]
response = client.tag_resource(
KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]
)
response = client.list_resource_tags(KeyId=keyid)
assert response["Tags"][0]["TagKey"] == "string"
assert response["Tags"][0]["TagValue"] == "string"
@parameterized(
(
(dict(KeySpec="AES_256"), 32),
(dict(KeySpec="AES_128"), 16),
(dict(NumberOfBytes=64), 64),
(dict(NumberOfBytes=1), 1),
(dict(NumberOfBytes=1024), 1024),
)
)
@mock_kms
def test_generate_data_key_sizes(kwargs, expected_key_length):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size")
response = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs)
assert len(response["Plaintext"]) == expected_key_length
@mock_kms
def test_generate_data_key_decrypt():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-decrypt")
resp1 = client.generate_data_key(
KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256"
)
resp2 = client.decrypt(CiphertextBlob=resp1["CiphertextBlob"])
assert resp1["Plaintext"] == resp2["Plaintext"]
@parameterized(
(
(dict(KeySpec="AES_257"),),
(dict(KeySpec="AES_128", NumberOfBytes=16),),
(dict(NumberOfBytes=2048),),
(dict(NumberOfBytes=0),),
(dict(),),
)
)
@mock_kms
def test_generate_data_key_invalid_size_params(kwargs):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size")
with assert_raises(
(botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError)
) as err:
client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs)
@parameterized(
(
("alias/DoesNotExist",),
("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
(
"arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",
),
)
)
@mock_kms
def test_generate_data_key_invalid_key(key_id):
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId=key_id, KeySpec="AES_256")
@parameterized(
(
("alias/DoesExist", False),
("arn:aws:kms:us-east-1:012345678912:alias/DoesExist", False),
("", True),
("arn:aws:kms:us-east-1:012345678912:key/", True),
)
)
@mock_kms
def test_generate_data_key_all_valid_key_ids(prefix, append_key_id):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key()
key_id = key["KeyMetadata"]["KeyId"]
client.create_alias(AliasName="alias/DoesExist", TargetKeyId=key_id)
target_id = prefix
if append_key_id:
target_id += key_id
    client.generate_data_key(KeyId=target_id, NumberOfBytes=32)
@mock_kms
def test_generate_data_key_without_plaintext_decrypt():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-decrypt")
resp1 = client.generate_data_key_without_plaintext(
KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256"
)
assert "Plaintext" not in resp1
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_re_encrypt_decrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key_1 = client.create_key(Description="key 1")
key_1_id = key_1["KeyMetadata"]["KeyId"]
key_1_arn = key_1["KeyMetadata"]["Arn"]
key_2 = client.create_key(Description="key 2")
key_2_id = key_2["KeyMetadata"]["KeyId"]
key_2_arn = key_2["KeyMetadata"]["Arn"]
encrypt_response = client.encrypt(
KeyId=key_1_id, Plaintext=plaintext, EncryptionContext={"encryption": "context"}
)
re_encrypt_response = client.re_encrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
SourceEncryptionContext={"encryption": "context"},
DestinationKeyId=key_2_id,
DestinationEncryptionContext={"another": "context"},
)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(re_encrypt_response["CiphertextBlob"], validate=True)
re_encrypt_response["SourceKeyId"].should.equal(key_1_arn)
re_encrypt_response["KeyId"].should.equal(key_2_arn)
decrypt_response_1 = client.decrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
EncryptionContext={"encryption": "context"},
)
decrypt_response_1["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_1["KeyId"].should.equal(key_1_arn)
decrypt_response_2 = client.decrypt(
CiphertextBlob=re_encrypt_response["CiphertextBlob"],
EncryptionContext={"another": "context"},
)
decrypt_response_2["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_2["KeyId"].should.equal(key_2_arn)
decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"])
@mock_kms
def test_re_encrypt_to_invalid_destination():
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key 1")
key_id = key["KeyMetadata"]["KeyId"]
encrypt_response = client.encrypt(KeyId=key_id, Plaintext=b"some plaintext")
with assert_raises(client.exceptions.NotFoundException):
client.re_encrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
DestinationKeyId="alias/DoesNotExist",
)
@parameterized(((12,), (44,), (91,), (1,), (1024,)))
@mock_kms
def test_generate_random(number_of_bytes):
client = boto3.client("kms", region_name="us-west-2")
response = client.generate_random(NumberOfBytes=number_of_bytes)
response["Plaintext"].should.be.a(bytes)
len(response["Plaintext"]).should.equal(number_of_bytes)
@parameterized(
(
(2048, botocore.exceptions.ClientError),
(1025, botocore.exceptions.ClientError),
(0, botocore.exceptions.ParamValidationError),
(-1, botocore.exceptions.ParamValidationError),
(-1024, botocore.exceptions.ParamValidationError),
)
)
@mock_kms
def test_generate_random_invalid_number_of_bytes(number_of_bytes, error_type):
client = boto3.client("kms", region_name="us-west-2")
with assert_raises(error_type):
client.generate_random(NumberOfBytes=number_of_bytes)
@mock_kms
def test_enable_key_rotation_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.enable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_disable_key_rotation_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.disable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_enable_key_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.enable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_disable_key_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.disable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_cancel_key_deletion_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.cancel_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_schedule_key_deletion_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.schedule_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_get_key_rotation_status_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.get_key_rotation_status(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_get_key_policy_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.get_key_policy(
KeyId="12366f9b-1230-123d-123e-123e6ae60c02", PolicyName="default"
)
@mock_kms
def test_list_key_policies_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.list_key_policies(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_put_key_policy_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.put_key_policy(
KeyId="00000000-0000-0000-0000-000000000000",
PolicyName="default",
Policy="new policy",
)

View File

@@ -102,7 +102,7 @@ def test_deserialize_ciphertext_blob(raw, serialized):
@parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS))
def test_encrypt_decrypt_cycle(encryption_context):
    plaintext = b"some secret plaintext"
-    master_key = Key("nop", "nop", "nop", "nop")
+    master_key = Key("nop", "nop", "nop", "nop", [], "nop")
    master_key_map = {master_key.id: master_key}
    ciphertext_blob = encrypt(
@@ -133,7 +133,7 @@ def test_encrypt_unknown_key_id():
def test_decrypt_invalid_ciphertext_format():
-    master_key = Key("nop", "nop", "nop", "nop")
+    master_key = Key("nop", "nop", "nop", "nop", [], "nop")
    master_key_map = {master_key.id: master_key}
    with assert_raises(InvalidCiphertextException):
@@ -153,7 +153,7 @@ def test_decrypt_unknwown_key_id():
def test_decrypt_invalid_ciphertext():
-    master_key = Key("nop", "nop", "nop", "nop")
+    master_key = Key("nop", "nop", "nop", "nop", [], "nop")
    master_key_map = {master_key.id: master_key}
    ciphertext_blob = (
        master_key.id.encode("utf-8") + b"123456789012"
@@ -171,7 +171,7 @@ def test_decrypt_invalid_ciphertext():
def test_decrypt_invalid_encryption_context():
    plaintext = b"some secret plaintext"
-    master_key = Key("nop", "nop", "nop", "nop")
+    master_key = Key("nop", "nop", "nop", "nop", [], "nop")
    master_key_map = {master_key.id: master_key}
    ciphertext_blob = encrypt(

View File

@@ -713,3 +713,41 @@ def test_untag_resource_errors():
    ex.response["Error"]["Message"].should.equal(
        "You provided a value that does not match the required pattern."
    )
@mock_organizations
def test_update_organizational_unit():
client = boto3.client("organizations", region_name="us-east-1")
org = client.create_organization(FeatureSet="ALL")["Organization"]
root_id = client.list_roots()["Roots"][0]["Id"]
ou_name = "ou01"
response = client.create_organizational_unit(ParentId=root_id, Name=ou_name)
validate_organizational_unit(org, response)
response["OrganizationalUnit"]["Name"].should.equal(ou_name)
new_ou_name = "ou02"
response = client.update_organizational_unit(
OrganizationalUnitId=response["OrganizationalUnit"]["Id"], Name=new_ou_name
)
validate_organizational_unit(org, response)
response["OrganizationalUnit"]["Name"].should.equal(new_ou_name)
@mock_organizations
def test_update_organizational_unit_duplicate_error():
client = boto3.client("organizations", region_name="us-east-1")
org = client.create_organization(FeatureSet="ALL")["Organization"]
root_id = client.list_roots()["Roots"][0]["Id"]
ou_name = "ou01"
response = client.create_organizational_unit(ParentId=root_id, Name=ou_name)
validate_organizational_unit(org, response)
response["OrganizationalUnit"]["Name"].should.equal(ou_name)
with assert_raises(ClientError) as e:
client.update_organizational_unit(
OrganizationalUnitId=response["OrganizationalUnit"]["Id"], Name=ou_name
)
exc = e.exception
exc.operation_name.should.equal("UpdateOrganizationalUnit")
exc.response["Error"]["Code"].should.contain("DuplicateOrganizationalUnitException")
exc.response["Error"]["Message"].should.equal(
"An OU with the same name already exists."
)

View File

@@ -68,7 +68,7 @@ def test_get_databases_paginated():
@mock_rds_deprecated
-def test_describe_non_existant_database():
+def test_describe_non_existent_database():
    conn = boto.rds.connect_to_region("us-west-2")
    conn.get_all_dbinstances.when.called_with("not-a-db").should.throw(BotoServerError)
@@ -86,7 +86,7 @@ def test_delete_database():
@mock_rds_deprecated
-def test_delete_non_existant_database():
+def test_delete_non_existent_database():
    conn = boto.rds.connect_to_region("us-west-2")
    conn.delete_dbinstance.when.called_with("not-a-db").should.throw(BotoServerError)
@@ -119,7 +119,7 @@ def test_get_security_groups():
@mock_rds_deprecated
-def test_get_non_existant_security_group():
+def test_get_non_existent_security_group():
    conn = boto.rds.connect_to_region("us-west-2")
    conn.get_all_dbsecurity_groups.when.called_with("not-a-sg").should.throw(
        BotoServerError
@@ -138,7 +138,7 @@ def test_delete_database_security_group():
@mock_rds_deprecated
-def test_delete_non_existant_security_group():
+def test_delete_non_existent_security_group():
    conn = boto.rds.connect_to_region("us-west-2")
    conn.delete_dbsecurity_group.when.called_with("not-a-db").should.throw(
        BotoServerError

View File

@@ -312,7 +312,7 @@ def test_get_databases_paginated():
@mock_rds2
-def test_describe_non_existant_database():
+def test_describe_non_existent_database():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.describe_db_instances.when.called_with(
        DBInstanceIdentifier="not-a-db"
@@ -378,7 +378,7 @@ def test_rename_db_instance():
@mock_rds2
-def test_modify_non_existant_database():
+def test_modify_non_existent_database():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.modify_db_instance.when.called_with(
        DBInstanceIdentifier="not-a-db", AllocatedStorage=20, ApplyImmediately=True
@@ -403,7 +403,7 @@ def test_reboot_db_instance():
@mock_rds2
-def test_reboot_non_existant_database():
+def test_reboot_non_existent_database():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.reboot_db_instance.when.called_with(
        DBInstanceIdentifier="not-a-db"
@@ -444,7 +444,7 @@ def test_delete_database():
@mock_rds2
-def test_delete_non_existant_database():
+def test_delete_non_existent_database():
    conn = boto3.client("rds2", region_name="us-west-2")
    conn.delete_db_instance.when.called_with(
        DBInstanceIdentifier="not-a-db"
@@ -663,7 +663,7 @@ def test_describe_option_group():
@mock_rds2
-def test_describe_non_existant_option_group():
+def test_describe_non_existent_option_group():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.describe_option_groups.when.called_with(
        OptionGroupName="not-a-option-group"
@@ -688,10 +688,10 @@ def test_delete_option_group():
@mock_rds2
-def test_delete_non_existant_option_group():
+def test_delete_non_existent_option_group():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.delete_option_group.when.called_with(
-        OptionGroupName="non-existant"
+        OptionGroupName="non-existent"
    ).should.throw(ClientError)
@@ -754,10 +754,10 @@ def test_modify_option_group_no_options():
@mock_rds2
-def test_modify_non_existant_option_group():
+def test_modify_non_existent_option_group():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.modify_option_group.when.called_with(
-        OptionGroupName="non-existant",
+        OptionGroupName="non-existent",
        OptionsToInclude=[
            (
                "OptionName",
@@ -771,7 +771,7 @@ def test_modify_non_existant_option_group():
@mock_rds2
-def test_delete_non_existant_database():
+def test_delete_non_existent_database():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.delete_db_instance.when.called_with(
        DBInstanceIdentifier="not-a-db"
@@ -1053,7 +1053,7 @@ def test_get_security_groups():
@mock_rds2
-def test_get_non_existant_security_group():
+def test_get_non_existent_security_group():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.describe_db_security_groups.when.called_with(
        DBSecurityGroupName="not-a-sg"
@@ -1076,7 +1076,7 @@ def test_delete_database_security_group():
@mock_rds2
-def test_delete_non_existant_security_group():
+def test_delete_non_existent_security_group():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.delete_db_security_group.when.called_with(
        DBSecurityGroupName="not-a-db"
@@ -1615,7 +1615,7 @@ def test_describe_db_parameter_group():
@mock_rds2
-def test_describe_non_existant_db_parameter_group():
+def test_describe_non_existent_db_parameter_group():
    conn = boto3.client("rds", region_name="us-west-2")
    db_parameter_groups = conn.describe_db_parameter_groups(DBParameterGroupName="test")
    len(db_parameter_groups["DBParameterGroups"]).should.equal(0)
@@ -1669,10 +1669,10 @@ def test_modify_db_parameter_group():
@mock_rds2
-def test_delete_non_existant_db_parameter_group():
+def test_delete_non_existent_db_parameter_group():
    conn = boto3.client("rds", region_name="us-west-2")
    conn.delete_db_parameter_group.when.called_with(
-        DBParameterGroupName="non-existant"
+        DBParameterGroupName="non-existent"
    ).should.throw(ClientError)
@@ -1689,3 +1689,36 @@ def test_create_parameter_group_with_tags():
        ResourceName="arn:aws:rds:us-west-2:1234567890:pg:test"
    )
    result["TagList"].should.equal([{"Value": "bar", "Key": "foo"}])
@mock_rds2
def test_create_db_with_iam_authentication():
conn = boto3.client("rds", region_name="us-west-2")
database = conn.create_db_instance(
DBInstanceIdentifier="rds",
DBInstanceClass="db.t1.micro",
Engine="postgres",
EnableIAMDatabaseAuthentication=True,
)
db_instance = database["DBInstance"]
db_instance["IAMDatabaseAuthenticationEnabled"].should.equal(True)
@mock_rds2
def test_create_db_snapshot_with_iam_authentication():
conn = boto3.client("rds", region_name="us-west-2")
conn.create_db_instance(
DBInstanceIdentifier="rds",
DBInstanceClass="db.t1.micro",
Engine="postgres",
EnableIAMDatabaseAuthentication=True,
)
snapshot = conn.create_db_snapshot(
DBInstanceIdentifier="rds", DBSnapshotIdentifier="snapshot"
).get("DBSnapshot")
snapshot.get("IAMDatabaseAuthenticationEnabled").should.equal(True)

View File

@@ -21,7 +21,10 @@ def test_get_resources_s3():
    # Create 4 buckets
    for i in range(1, 5):
        i_str = str(i)
-        s3_client.create_bucket(Bucket="test_bucket" + i_str)
+        s3_client.create_bucket(
+            Bucket="test_bucket" + i_str,
+            CreateBucketConfiguration={"LocationConstraint": "eu-central-1"},
+        )
        s3_client.put_bucket_tagging(
            Bucket="test_bucket" + i_str,
            Tagging={"TagSet": [{"Key": "key" + i_str, "Value": "value" + i_str}]},

View File

@@ -862,6 +862,8 @@ def test_list_resource_record_sets_name_type_filters():
        StartRecordName=all_records[start_with][1],
    )
+    response["IsTruncated"].should.equal(False)
    returned_records = [
        (record["Type"], record["Name"]) for record in response["ResourceRecordSets"]
    ]

File diff suppressed because it is too large

View File

@@ -16,7 +16,7 @@ from moto import mock_s3_deprecated, mock_s3
@mock_s3_deprecated
def test_lifecycle_create():
    conn = boto.s3.connect_to_region("us-west-1")
-    bucket = conn.create_bucket("foobar")
+    bucket = conn.create_bucket("foobar", location="us-west-1")
    lifecycle = Lifecycle()
    lifecycle.add_rule("myid", "", "Enabled", 30)
@@ -33,7 +33,9 @@ def test_lifecycle_create():
@mock_s3
def test_lifecycle_with_filters():
    client = boto3.client("s3")
-    client.create_bucket(Bucket="bucket")
+    client.create_bucket(
+        Bucket="bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-1"}
+    )
    # Create a lifecycle rule with a Filter (no tags):
    lfc = {
@@ -245,7 +247,9 @@ def test_lifecycle_with_filters():
@mock_s3
def test_lifecycle_with_eodm():
    client = boto3.client("s3")
-    client.create_bucket(Bucket="bucket")
+    client.create_bucket(
+        Bucket="bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-1"}
+    )
    lfc = {
        "Rules": [
@@ -293,7 +297,9 @@ def test_lifecycle_with_eodm():
@mock_s3
def test_lifecycle_with_nve():
    client = boto3.client("s3")
-    client.create_bucket(Bucket="bucket")
+    client.create_bucket(
+        Bucket="bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-1"}
+    )
    lfc = {
        "Rules": [
@@ -327,7 +333,9 @@ def test_lifecycle_with_nve():
@mock_s3
def test_lifecycle_with_nvt():
    client = boto3.client("s3")
-    client.create_bucket(Bucket="bucket")
+    client.create_bucket(
+        Bucket="bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-1"}
+    )
    lfc = {
        "Rules": [
@@ -393,7 +401,9 @@ def test_lifecycle_with_nvt():
@mock_s3
def test_lifecycle_with_aimu():
    client = boto3.client("s3")
-    client.create_bucket(Bucket="bucket")
+    client.create_bucket(
+        Bucket="bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-1"}
+    )
    lfc = {
        "Rules": [
@@ -432,7 +442,7 @@ def test_lifecycle_with_aimu():
@mock_s3_deprecated
def test_lifecycle_with_glacier_transition():
    conn = boto.s3.connect_to_region("us-west-1")
-    bucket = conn.create_bucket("foobar")
+    bucket = conn.create_bucket("foobar", location="us-west-1")
    lifecycle = Lifecycle()
    transition = Transition(days=30, storage_class="GLACIER")
@@ -451,7 +461,7 @@ def test_lifecycle_with_glacier_transition():
@mock_s3_deprecated
def test_lifecycle_multi():
    conn = boto.s3.connect_to_region("us-west-1")
-    bucket = conn.create_bucket("foobar")
+    bucket = conn.create_bucket("foobar", location="us-west-1")
    date = "2022-10-12T00:00:00.000Z"
    sc = "GLACIER"
@@ -493,7 +503,7 @@ def test_lifecycle_multi():
@mock_s3_deprecated
def test_lifecycle_delete():
    conn = boto.s3.connect_to_region("us-west-1")
-    bucket = conn.create_bucket("foobar")
+    bucket = conn.create_bucket("foobar", location="us-west-1")
    lifecycle = Lifecycle()
    lifecycle.add_rule(expiration=30)

View File

@@ -11,7 +11,7 @@ from moto import mock_s3
@mock_s3
def test_s3_storage_class_standard():
-    s3 = boto3.client("s3")
+    s3 = boto3.client("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="Bucket")
    # add an object to the bucket with standard storage
@@ -26,7 +26,9 @@ def test_s3_storage_class_standard():
@mock_s3
def test_s3_storage_class_infrequent_access():
    s3 = boto3.client("s3")
-    s3.create_bucket(Bucket="Bucket")
+    s3.create_bucket(
+        Bucket="Bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-2"}
+    )
    # add an object to the bucket with standard storage
@@ -46,7 +48,9 @@ def test_s3_storage_class_infrequent_access():
def test_s3_storage_class_intelligent_tiering():
    s3 = boto3.client("s3")
-    s3.create_bucket(Bucket="Bucket")
+    s3.create_bucket(
+        Bucket="Bucket", CreateBucketConfiguration={"LocationConstraint": "us-east-2"}
+    )
    s3.put_object(
        Bucket="Bucket",
        Key="my_key_infrequent",
@@ -61,7 +65,7 @@ def test_s3_storage_class_intelligent_tiering():
@mock_s3
def test_s3_storage_class_copy():
-    s3 = boto3.client("s3")
+    s3 = boto3.client("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="Bucket")
    s3.put_object(
        Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD"
@@ -86,7 +90,7 @@ def test_s3_storage_class_copy():
@mock_s3
def test_s3_invalid_copied_storage_class():
-    s3 = boto3.client("s3")
+    s3 = boto3.client("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="Bucket")
    s3.put_object(
        Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD"
@@ -119,7 +123,9 @@ def test_s3_invalid_copied_storage_class():
@mock_s3
def test_s3_invalid_storage_class():
    s3 = boto3.client("s3")
-    s3.create_bucket(Bucket="Bucket")
+    s3.create_bucket(
+        Bucket="Bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-1"}
+    )
    # Try to add an object with an invalid storage class
    with assert_raises(ClientError) as err:
@@ -137,7 +143,9 @@ def test_s3_invalid_storage_class():
@mock_s3
def test_s3_default_storage_class():
    s3 = boto3.client("s3")
-    s3.create_bucket(Bucket="Bucket")
+    s3.create_bucket(
+        Bucket="Bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-1"}
+    )
    s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body")
@@ -150,7 +158,9 @@ def test_s3_default_storage_class():
@mock_s3
def test_s3_copy_object_error_for_glacier_storage_class():
    s3 = boto3.client("s3")
-    s3.create_bucket(Bucket="Bucket")
+    s3.create_bucket(
+        Bucket="Bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-1"}
+    )
    s3.put_object(
        Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="GLACIER"
@@ -169,7 +179,9 @@ def test_s3_copy_object_error_for_glacier_storage_class():
@mock_s3
def test_s3_copy_object_error_for_deep_archive_storage_class():
    s3 = boto3.client("s3")
-    s3.create_bucket(Bucket="Bucket")
+    s3.create_bucket(
+        Bucket="Bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-1"}
+    )
    s3.put_object(
        Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="DEEP_ARCHIVE"

View File

@@ -174,7 +174,7 @@ def test_bucket_deletion():
    # Get non-existing bucket
    conn.get_bucket.when.called_with("foobar").should.throw(S3ResponseError)
-    # Delete non-existant bucket
+    # Delete non-existent bucket
    conn.delete_bucket.when.called_with("foobar").should.throw(S3ResponseError)

View File

@@ -40,6 +40,8 @@ def __setup_feedback_env__(
    )
    # Verify SES domain
    ses_conn.verify_domain_identity(Domain=domain)
+    # Specify email address to allow for raw e-mails to be processed
+    ses_conn.verify_email_identity(EmailAddress="test@example.com")
    # Setup SES notification topic
    if expected_msg is not None:
        ses_conn.set_identity_notification_topic(
@@ -47,7 +49,7 @@
    )
-def __test_sns_feedback__(addr, expected_msg):
+def __test_sns_feedback__(addr, expected_msg, raw_email=False):
    region_name = "us-east-1"
    ses_conn = boto3.client("ses", region_name=region_name)
    sns_conn = boto3.client("sns", region_name=region_name)
@@ -73,6 +75,17 @@ def __test_sns_feedback__(addr, expected_msg):
            "Body": {"Text": {"Data": "test body"}},
        },
    )
+    if raw_email:
+        kwargs.pop("Message")
+        kwargs.pop("Destination")
+        kwargs.update(
+            {
+                "Destinations": [addr + "@" + domain],
+                "RawMessage": {"Data": bytearray("raw_email", "utf-8")},
+            }
+        )
+        ses_conn.send_raw_email(**kwargs)
+    else:
-    ses_conn.send_email(**kwargs)
+        ses_conn.send_email(**kwargs)
    # Wait for messages in the queues
@@ -112,3 +125,12 @@ def test_sns_feedback_complaint():
@mock_ses
def test_sns_feedback_delivery():
    __test_sns_feedback__(SESFeedback.SUCCESS_ADDR, SESFeedback.DELIVERY)
@mock_sqs
@mock_sns
@mock_ses
def test_sns_feedback_delivery_raw_email():
__test_sns_feedback__(
SESFeedback.SUCCESS_ADDR, SESFeedback.DELIVERY, raw_email=True
)

View File

@@ -88,8 +88,8 @@ def test_list_platform_applications():
    conn.create_platform_application(name="application1", platform="APNS")
    conn.create_platform_application(name="application2", platform="APNS")
-    applications_repsonse = conn.list_platform_applications()
-    applications = applications_repsonse["ListPlatformApplicationsResponse"][
+    applications_response = conn.list_platform_applications()
+    applications = applications_response["ListPlatformApplicationsResponse"][
        "ListPlatformApplicationsResult"
    ]["PlatformApplications"]
    applications.should.have.length_of(2)
@@ -101,8 +101,8 @@ def test_delete_platform_application():
    conn.create_platform_application(name="application1", platform="APNS")
    conn.create_platform_application(name="application2", platform="APNS")
-    applications_repsonse = conn.list_platform_applications()
-    applications = applications_repsonse["ListPlatformApplicationsResponse"][
+    applications_response = conn.list_platform_applications()
+    applications = applications_response["ListPlatformApplicationsResponse"][
        "ListPlatformApplicationsResult"
    ]["PlatformApplications"]
    applications.should.have.length_of(2)
@@ -110,8 +110,8 @@ def test_delete_platform_application():
    application_arn = applications[0]["PlatformApplicationArn"]
    conn.delete_platform_application(application_arn)
-    applications_repsonse = conn.list_platform_applications()
-    applications = applications_repsonse["ListPlatformApplicationsResponse"][
+    applications_response = conn.list_platform_applications()
+    applications = applications_response["ListPlatformApplicationsResponse"][
        "ListPlatformApplicationsResult"
    ]["PlatformApplications"]
    applications.should.have.length_of(1)

View File

@@ -88,8 +88,8 @@ def test_list_platform_applications():
        Name="application2", Platform="APNS", Attributes={}
    )
-    applications_repsonse = conn.list_platform_applications()
-    applications = applications_repsonse["PlatformApplications"]
+    applications_response = conn.list_platform_applications()
+    applications = applications_response["PlatformApplications"]
    applications.should.have.length_of(2)
@@ -103,15 +103,15 @@ def test_delete_platform_application():
        Name="application2", Platform="APNS", Attributes={}
    )
-    applications_repsonse = conn.list_platform_applications()
-    applications = applications_repsonse["PlatformApplications"]
+    applications_response = conn.list_platform_applications()
+    applications = applications_response["PlatformApplications"]
    applications.should.have.length_of(2)
    application_arn = applications[0]["PlatformApplicationArn"]
    conn.delete_platform_application(PlatformApplicationArn=application_arn)
-    applications_repsonse = conn.list_platform_applications()
-    applications = applications_repsonse["PlatformApplications"]
+    applications_response = conn.list_platform_applications()
+    applications = applications_response["PlatformApplications"]
    applications.should.have.length_of(1)

View File

@@ -806,7 +806,7 @@ def test_filtering_string_array_with_string_no_array_no_match():
    topic.publish(
        Message="no_match",
        MessageAttributes={
-            "price": {"DataType": "String.Array", "StringValue": "one hundread"}
+            "price": {"DataType": "String.Array", "StringValue": "one hundred"}
        },
    )

View File

@@ -132,6 +132,35 @@ def test_create_queue_with_tags():
    )
@mock_sqs
def test_create_queue_with_policy():
client = boto3.client("sqs", region_name="us-east-1")
response = client.create_queue(
QueueName="test-queue",
Attributes={
"Policy": json.dumps(
{
"Version": "2012-10-17",
"Id": "test",
"Statement": [{"Effect": "Allow", "Principal": "*", "Action": "*"}],
}
)
},
)
queue_url = response["QueueUrl"]
response = client.get_queue_attributes(
QueueUrl=queue_url, AttributeNames=["Policy"]
)
json.loads(response["Attributes"]["Policy"]).should.equal(
{
"Version": "2012-10-17",
"Id": "test",
"Statement": [{"Effect": "Allow", "Principal": "*", "Action": "*"}],
}
)
@mock_sqs @mock_sqs
def test_get_queue_url(): def test_get_queue_url():
client = boto3.client("sqs", region_name="us-east-1") client = boto3.client("sqs", region_name="us-east-1")
@ -331,7 +360,20 @@ def test_delete_queue():
@mock_sqs @mock_sqs
def test_get_queue_attributes(): def test_get_queue_attributes():
client = boto3.client("sqs", region_name="us-east-1") client = boto3.client("sqs", region_name="us-east-1")
response = client.create_queue(QueueName="test-queue")
dlq_resp = client.create_queue(QueueName="test-dlr-queue")
dlq_arn1 = client.get_queue_attributes(QueueUrl=dlq_resp["QueueUrl"])["Attributes"][
"QueueArn"
]
response = client.create_queue(
QueueName="test-queue",
Attributes={
"RedrivePolicy": json.dumps(
{"deadLetterTargetArn": dlq_arn1, "maxReceiveCount": 2}
),
},
)
queue_url = response["QueueUrl"] queue_url = response["QueueUrl"]
response = client.get_queue_attributes(QueueUrl=queue_url) response = client.get_queue_attributes(QueueUrl=queue_url)
@ -356,6 +398,7 @@ def test_get_queue_attributes():
"ApproximateNumberOfMessages", "ApproximateNumberOfMessages",
"MaximumMessageSize", "MaximumMessageSize",
"QueueArn", "QueueArn",
"RedrivePolicy",
"VisibilityTimeout", "VisibilityTimeout",
], ],
) )
@ -366,6 +409,9 @@ def test_get_queue_attributes():
"MaximumMessageSize": "65536", "MaximumMessageSize": "65536",
"QueueArn": "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID), "QueueArn": "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID),
"VisibilityTimeout": "30", "VisibilityTimeout": "30",
"RedrivePolicy": json.dumps(
{"deadLetterTargetArn": dlq_arn1, "maxReceiveCount": 2}
),
} }
) )
@ -1169,18 +1215,169 @@ def test_permissions():
Actions=["SendMessage"], Actions=["SendMessage"],
) )
with assert_raises(ClientError): response = client.get_queue_attributes(
client.add_permission( QueueUrl=queue_url, AttributeNames=["Policy"]
QueueUrl=queue_url, )
Label="account2", policy = json.loads(response["Attributes"]["Policy"])
AWSAccountIds=["222211111111"], policy["Version"].should.equal("2012-10-17")
Actions=["SomeRubbish"], policy["Id"].should.equal(
"arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo/SQSDefaultPolicy"
)
sorted(policy["Statement"], key=lambda x: x["Sid"]).should.equal(
[
{
"Sid": "account1",
"Effect": "Allow",
"Principal": {"AWS": "arn:aws:iam::111111111111:root"},
"Action": "SQS:*",
"Resource": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo",
},
{
"Sid": "account2",
"Effect": "Allow",
"Principal": {"AWS": "arn:aws:iam::222211111111:root"},
"Action": "SQS:SendMessage",
"Resource": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo",
},
]
) )
client.remove_permission(QueueUrl=queue_url, Label="account2") client.remove_permission(QueueUrl=queue_url, Label="account2")
with assert_raises(ClientError): response = client.get_queue_attributes(
client.remove_permission(QueueUrl=queue_url, Label="non_existant") QueueUrl=queue_url, AttributeNames=["Policy"]
)
json.loads(response["Attributes"]["Policy"]).should.equal(
{
"Version": "2012-10-17",
"Id": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo/SQSDefaultPolicy",
"Statement": [
{
"Sid": "account1",
"Effect": "Allow",
"Principal": {"AWS": "arn:aws:iam::111111111111:root"},
"Action": "SQS:*",
"Resource": "arn:aws:sqs:us-east-1:123456789012:test-dlr-queue.fifo",
},
],
}
)
@mock_sqs
def test_add_permission_errors():
client = boto3.client("sqs", region_name="us-east-1")
response = client.create_queue(QueueName="test-queue")
queue_url = response["QueueUrl"]
client.add_permission(
QueueUrl=queue_url,
Label="test",
AWSAccountIds=["111111111111"],
Actions=["ReceiveMessage"],
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test",
AWSAccountIds=["111111111111"],
Actions=["ReceiveMessage", "SendMessage"],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidParameterValue")
ex.response["Error"]["Message"].should.equal(
"Value test for parameter Label is invalid. " "Reason: Already exists."
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test-2",
AWSAccountIds=["111111111111"],
Actions=["RemovePermission"],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidParameterValue")
ex.response["Error"]["Message"].should.equal(
"Value SQS:RemovePermission for parameter ActionName is invalid. "
"Reason: Only the queue owner is allowed to invoke this action."
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test-2",
AWSAccountIds=["111111111111"],
Actions=[],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("MissingParameter")
ex.response["Error"]["Message"].should.equal(
"The request must contain the parameter Actions."
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test-2",
AWSAccountIds=[],
Actions=["ReceiveMessage"],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidParameterValue")
ex.response["Error"]["Message"].should.equal(
"Value [] for parameter PrincipalId is invalid. Reason: Unable to verify."
)
with assert_raises(ClientError) as e:
client.add_permission(
QueueUrl=queue_url,
Label="test-2",
AWSAccountIds=["111111111111"],
Actions=[
"ChangeMessageVisibility",
"DeleteMessage",
"GetQueueAttributes",
"GetQueueUrl",
"ListDeadLetterSourceQueues",
"PurgeQueue",
"ReceiveMessage",
"SendMessage",
],
)
ex = e.exception
ex.operation_name.should.equal("AddPermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403)
ex.response["Error"]["Code"].should.contain("OverLimit")
ex.response["Error"]["Message"].should.equal(
"8 Actions were found, maximum allowed is 7."
)
@mock_sqs
def test_remove_permission_errors():
client = boto3.client("sqs", region_name="us-east-1")
response = client.create_queue(QueueName="test-queue")
queue_url = response["QueueUrl"]
with assert_raises(ClientError) as e:
client.remove_permission(QueueUrl=queue_url, Label="test")
ex = e.exception
ex.operation_name.should.equal("RemovePermission")
ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.response["Error"]["Code"].should.contain("InvalidParameterValue")
ex.response["Error"]["Message"].should.equal(
"Value test for parameter Label is invalid. "
"Reason: can't find label on existing policy."
)
@mock_sqs @mock_sqs
@ -1562,3 +1759,23 @@ def test_receive_message_for_queue_with_receive_message_wait_time_seconds_set():
) )
queue.receive_messages() queue.receive_messages()
@mock_sqs
def test_list_queues_limits_to_1000_queues():
client = boto3.client("sqs", region_name="us-east-1")
for i in range(1001):
client.create_queue(QueueName="test-queue-{0}".format(i))
client.list_queues()["QueueUrls"].should.have.length_of(1000)
client.list_queues(QueueNamePrefix="test-queue")["QueueUrls"].should.have.length_of(
1000
)
resource = boto3.resource("sqs", region_name="us-east-1")
list(resource.queues.all()).should.have.length_of(1000)
list(resource.queues.filter(QueueNamePrefix="test-queue")).should.have.length_of(
1000
)

View File

@ -148,6 +148,39 @@ def test_workflow_execution_full_dict_representation():
) )
def test_closed_workflow_execution_full_dict_representation():
domain = get_basic_domain()
wf_type = WorkflowType(
"test-workflow",
"v1.0",
task_list="queue",
default_child_policy="ABANDON",
default_execution_start_to_close_timeout="300",
default_task_start_to_close_timeout="300",
)
wfe = WorkflowExecution(domain, wf_type, "ab1234")
wfe.execution_status = "CLOSED"
wfe.close_status = "CANCELED"
wfe.close_timestamp = 1420066801.123
fd = wfe.to_full_dict()
medium_dict = wfe.to_medium_dict()
medium_dict["closeStatus"] = "CANCELED"
medium_dict["closeTimestamp"] = 1420066801.123
fd["executionInfo"].should.equal(medium_dict)
fd["openCounts"]["openTimers"].should.equal(0)
fd["openCounts"]["openDecisionTasks"].should.equal(0)
fd["openCounts"]["openActivityTasks"].should.equal(0)
fd["executionConfiguration"].should.equal(
{
"childPolicy": "ABANDON",
"executionStartToCloseTimeout": "300",
"taskList": {"name": "queue"},
"taskStartToCloseTimeout": "300",
}
)
def test_workflow_execution_list_dict_representation(): def test_workflow_execution_list_dict_representation():
domain = get_basic_domain() domain = get_basic_domain()
wf_type = WorkflowType( wf_type = WorkflowType(

View File

@ -5,7 +5,7 @@ from moto.swf.models import ActivityType, Domain, WorkflowType, WorkflowExecutio
# Some useful constants # Some useful constants
# Here are some activity timeouts we use in moto/swf tests ; they're extracted # Here are some activity timeouts we use in moto/swf tests ; they're extracted
# from semi-real world example, the goal is mostly to have predictible and # from semi-real world example, the goal is mostly to have predictable and
# intuitive behaviour in moto/swf own tests... # intuitive behaviour in moto/swf own tests...
ACTIVITY_TASK_TIMEOUTS = { ACTIVITY_TASK_TIMEOUTS = {
"heartbeatTimeout": "300", # 5 mins "heartbeatTimeout": "300", # 5 mins

View File

@ -1,53 +1,79 @@
import unittest import sure
from moto.utilities.tagging_service import TaggingService from moto.utilities.tagging_service import TaggingService
class TestTaggingService(unittest.TestCase): def test_list_empty():
def test_list_empty(self):
svc = TaggingService() svc = TaggingService()
result = svc.list_tags_for_resource('test') result = svc.list_tags_for_resource("test")
self.assertEqual(result, {'Tags': []})
def test_create_tag(self): {"Tags": []}.should.be.equal(result)
svc = TaggingService('TheTags', 'TagKey', 'TagValue')
tags = [{'TagKey': 'key_key', 'TagValue': 'value_value'}]
svc.tag_resource('arn', tags)
actual = svc.list_tags_for_resource('arn')
expected = {'TheTags': [{'TagKey': 'key_key', 'TagValue': 'value_value'}]}
self.assertDictEqual(expected, actual)
def test_create_tag_without_value(self):
def test_create_tag():
svc = TaggingService("TheTags", "TagKey", "TagValue")
tags = [{"TagKey": "key_key", "TagValue": "value_value"}]
svc.tag_resource("arn", tags)
actual = svc.list_tags_for_resource("arn")
expected = {"TheTags": [{"TagKey": "key_key", "TagValue": "value_value"}]}
expected.should.be.equal(actual)
def test_create_tag_without_value():
svc = TaggingService() svc = TaggingService()
tags = [{'Key': 'key_key'}] tags = [{"Key": "key_key"}]
svc.tag_resource('arn', tags) svc.tag_resource("arn", tags)
actual = svc.list_tags_for_resource('arn') actual = svc.list_tags_for_resource("arn")
expected = {'Tags': [{'Key': 'key_key', 'Value': ''}]} expected = {"Tags": [{"Key": "key_key", "Value": None}]}
self.assertDictEqual(expected, actual)
def test_delete_tag(self): expected.should.be.equal(actual)
svc = TaggingService()
tags = [{'Key': 'key_key', 'Value': 'value_value'}]
svc.tag_resource('arn', tags)
svc.untag_resource('arn', ['key_key'])
result = svc.list_tags_for_resource('arn')
self.assertEqual(
result, {'Tags': []})
def test_list_empty_delete(self):
svc = TaggingService()
svc.untag_resource('arn', ['key_key'])
result = svc.list_tags_for_resource('arn')
self.assertEqual(
result, {'Tags': []})
def test_extract_tag_names(self): def test_delete_tag_using_names():
svc = TaggingService() svc = TaggingService()
tags = [{'Key': 'key1', 'Value': 'value1'}, {'Key': 'key2', 'Value': 'value2'}] tags = [{"Key": "key_key", "Value": "value_value"}]
svc.tag_resource("arn", tags)
svc.untag_resource_using_names("arn", ["key_key"])
result = svc.list_tags_for_resource("arn")
{"Tags": []}.should.be.equal(result)
def test_delete_all_tags_for_resource():
svc = TaggingService()
tags = [{"Key": "key_key", "Value": "value_value"}]
tags2 = [{"Key": "key_key2", "Value": "value_value2"}]
svc.tag_resource("arn", tags)
svc.tag_resource("arn", tags2)
svc.delete_all_tags_for_resource("arn")
result = svc.list_tags_for_resource("arn")
{"Tags": []}.should.be.equal(result)
def test_list_empty_delete():
svc = TaggingService()
svc.untag_resource_using_names("arn", ["key_key"])
result = svc.list_tags_for_resource("arn")
{"Tags": []}.should.be.equal(result)
def test_delete_tag_using_tags():
svc = TaggingService()
tags = [{"Key": "key_key", "Value": "value_value"}]
svc.tag_resource("arn", tags)
svc.untag_resource_using_tags("arn", tags)
result = svc.list_tags_for_resource("arn")
{"Tags": []}.should.be.equal(result)
def test_extract_tag_names():
svc = TaggingService()
tags = [{"Key": "key1", "Value": "value1"}, {"Key": "key2", "Value": "value2"}]
actual = svc.extract_tag_names(tags) actual = svc.extract_tag_names(tags)
expected = ['key1', 'key2'] expected = ["key1", "key2"]
self.assertEqual(expected, actual)
expected.should.be.equal(actual)
if __name__ == '__main__':
unittest.main()