merging from master

This commit is contained in:
Bryan Alexander 2020-02-18 10:47:05 -06:00
commit 445f474534
125 changed files with 7407 additions and 3848 deletions

2
.gitignore vendored
View File

@ -20,3 +20,5 @@ env/
.vscode/ .vscode/
tests/file.tmp tests/file.tmp
.eggs/ .eggs/
.mypy_cache/
*.tmp

View File

@ -26,11 +26,12 @@ install:
fi fi
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${PYTHON_DOCKER_TAG} /moto/travis_moto_server.sh & docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${PYTHON_DOCKER_TAG} /moto/travis_moto_server.sh &
fi fi
travis_retry pip install -r requirements-dev.txt
travis_retry pip install boto==2.45.0 travis_retry pip install boto==2.45.0
travis_retry pip install boto3 travis_retry pip install boto3
travis_retry pip install dist/moto*.gz travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1 travis_retry pip install coveralls==1.1
travis_retry pip install -r requirements-dev.txt travis_retry pip install coverage==4.5.4
if [ "$TEST_SERVER_MODE" = "true" ]; then if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py python wait_for.py

View File

@ -283,14 +283,14 @@ def test_describe_instances_allowed():
] ]
} }
access_key = ... access_key = ...
# create access key for an IAM user/assumed role that has the policy above. # create access key for an IAM user/assumed role that has the policy above.
# this part should call __exactly__ 4 AWS actions, so that authentication and authorization starts exactly after this # this part should call __exactly__ 4 AWS actions, so that authentication and authorization starts exactly after this
client = boto3.client('ec2', region_name='us-east-1', client = boto3.client('ec2', region_name='us-east-1',
aws_access_key_id=access_key['AccessKeyId'], aws_access_key_id=access_key['AccessKeyId'],
aws_secret_access_key=access_key['SecretAccessKey']) aws_secret_access_key=access_key['SecretAccessKey'])
# if the IAM principal whose access key is used, does not have the permission to describe instances, this will fail # if the IAM principal whose access key is used, does not have the permission to describe instances, this will fail
instances = client.describe_instances()['Reservations'][0]['Instances'] instances = client.describe_instances()['Reservations'][0]['Instances']
assert len(instances) == 0 assert len(instances) == 0
``` ```
@ -310,16 +310,16 @@ You need to ensure that the mocks are actually in place. Changes made to recent
have altered some of the mock behavior. In short, you need to ensure that you _always_ do the following: have altered some of the mock behavior. In short, you need to ensure that you _always_ do the following:
1. Ensure that your tests have dummy environment variables set up: 1. Ensure that your tests have dummy environment variables set up:
export AWS_ACCESS_KEY_ID='testing' export AWS_ACCESS_KEY_ID='testing'
export AWS_SECRET_ACCESS_KEY='testing' export AWS_SECRET_ACCESS_KEY='testing'
export AWS_SECURITY_TOKEN='testing' export AWS_SECURITY_TOKEN='testing'
export AWS_SESSION_TOKEN='testing' export AWS_SESSION_TOKEN='testing'
1. __VERY IMPORTANT__: ensure that you have your mocks set up __BEFORE__ your `boto3` client is established. 1. __VERY IMPORTANT__: ensure that you have your mocks set up __BEFORE__ your `boto3` client is established.
This can typically happen if you import a module that has a `boto3` client instantiated outside of a function. This can typically happen if you import a module that has a `boto3` client instantiated outside of a function.
See the pesky imports section below on how to work around this. See the pesky imports section below on how to work around this.
### Example on usage? ### Example on usage?
If you are a user of [pytest](https://pytest.org/en/latest/), you can leverage [pytest fixtures](https://pytest.org/en/latest/fixture.html#fixture) If you are a user of [pytest](https://pytest.org/en/latest/), you can leverage [pytest fixtures](https://pytest.org/en/latest/fixture.html#fixture)
to help set up your mocks and other AWS resources that you would need. to help set up your mocks and other AWS resources that you would need.
@ -354,7 +354,7 @@ def cloudwatch(aws_credentials):
... etc. ... etc.
``` ```
In the code sample above, all of the AWS/mocked fixtures take in a parameter of `aws_credentials`, In the code sample above, all of the AWS/mocked fixtures take in a parameter of `aws_credentials`,
which sets the proper fake environment variables. The fake environment variables are used so that `botocore` doesn't try to locate real which sets the proper fake environment variables. The fake environment variables are used so that `botocore` doesn't try to locate real
credentials on your system. credentials on your system.
@ -364,7 +364,7 @@ def test_create_bucket(s3):
# s3 is a fixture defined above that yields a boto3 s3 client. # s3 is a fixture defined above that yields a boto3 s3 client.
# Feel free to instantiate another boto3 S3 client -- Keep note of the region though. # Feel free to instantiate another boto3 S3 client -- Keep note of the region though.
s3.create_bucket(Bucket="somebucket") s3.create_bucket(Bucket="somebucket")
result = s3.list_buckets() result = s3.list_buckets()
assert len(result['Buckets']) == 1 assert len(result['Buckets']) == 1
assert result['Buckets'][0]['Name'] == 'somebucket' assert result['Buckets'][0]['Name'] == 'somebucket'
@ -373,7 +373,7 @@ def test_create_bucket(s3):
### What about those pesky imports? ### What about those pesky imports?
Recall earlier, it was mentioned that mocks should be established __BEFORE__ the clients are set up. One way Recall earlier, it was mentioned that mocks should be established __BEFORE__ the clients are set up. One way
to avoid import issues is to make use of local Python imports -- i.e. import the module inside of the unit to avoid import issues is to make use of local Python imports -- i.e. import the module inside of the unit
test you want to run vs. importing at the top of the file. test you want to run vs. importing at the top of the file.
Example: Example:
```python ```python
@ -381,12 +381,12 @@ def test_something(s3):
from some.package.that.does.something.with.s3 import some_func # <-- Local import for unit test from some.package.that.does.something.with.s3 import some_func # <-- Local import for unit test
# ^^ Importing here ensures that the mock has been established. # ^^ Importing here ensures that the mock has been established.
sume_func() # The mock has been established from the "s3" pytest fixture, so this function that uses some_func() # The mock has been established from the "s3" pytest fixture, so this function that uses
# a package-level S3 client will properly use the mock and not reach out to AWS. # a package-level S3 client will properly use the mock and not reach out to AWS.
``` ```
### Other caveats ### Other caveats
For Tox, Travis CI, and other build systems, you might need to also perform a `touch ~/.aws/credentials` For Tox, Travis CI, and other build systems, you might need to also perform a `touch ~/.aws/credentials`
command before running the tests. As long as that file is present (empty preferably) and the environment command before running the tests. As long as that file is present (empty preferably) and the environment
variables above are set, you should be good to go. variables above are set, you should be good to go.
@ -450,6 +450,16 @@ boto3.resource(
) )
``` ```
### Caveats
The standalone server has some caveats with some services. The following services
require that you update your hosts file for your code to work properly:
1. `s3-control`
For the above services, this is required because the hostname is in the form of `AWS_ACCOUNT_ID.localhost`.
As a result, you need to add that entry to your host file for your tests to function properly.
## Install ## Install

View File

@ -56,9 +56,10 @@ author = 'Steve Pulec'
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = '0.4.10' import moto
version = moto.__version__
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = '0.4.10' release = moto.__version__
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.

View File

@ -24,8 +24,7 @@ For example, we have the following code we want to test:
.. sourcecode:: python .. sourcecode:: python
import boto import boto3
from boto.s3.key import Key
class MyModel(object): class MyModel(object):
def __init__(self, name, value): def __init__(self, name, value):
@ -33,11 +32,8 @@ For example, we have the following code we want to test:
self.value = value self.value = value
def save(self): def save(self):
conn = boto.connect_s3() s3 = boto3.client('s3', region_name='us-east-1')
bucket = conn.get_bucket('mybucket') s3.put_object(Bucket='mybucket', Key=self.name, Body=self.value)
k = Key(bucket)
k.key = self.name
k.set_contents_from_string(self.value)
There are several ways to do this, but you should keep in mind that Moto creates a full, blank environment. There are several ways to do this, but you should keep in mind that Moto creates a full, blank environment.
@ -48,20 +44,23 @@ With a decorator wrapping, all the calls to S3 are automatically mocked out.
.. sourcecode:: python .. sourcecode:: python
import boto import boto3
from moto import mock_s3 from moto import mock_s3
from mymodule import MyModel from mymodule import MyModel
@mock_s3 @mock_s3
def test_my_model_save(): def test_my_model_save():
conn = boto.connect_s3() conn = boto3.resource('s3', region_name='us-east-1')
# We need to create the bucket since this is all in Moto's 'virtual' AWS account # We need to create the bucket since this is all in Moto's 'virtual' AWS account
conn.create_bucket('mybucket') conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome') model_instance = MyModel('steve', 'is awesome')
model_instance.save() model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome' body = conn.Object('mybucket', 'steve').get()[
'Body'].read().decode("utf-8")
assert body == 'is awesome'
Context manager Context manager
~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~
@ -72,13 +71,16 @@ Same as the Decorator, every call inside the ``with`` statement is mocked out.
def test_my_model_save(): def test_my_model_save():
with mock_s3(): with mock_s3():
conn = boto.connect_s3() conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket('mybucket') conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome') model_instance = MyModel('steve', 'is awesome')
model_instance.save() model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome' body = conn.Object('mybucket', 'steve').get()[
'Body'].read().decode("utf-8")
assert body == 'is awesome'
Raw Raw
~~~ ~~~
@ -91,13 +93,16 @@ You can also start and stop the mocking manually.
mock = mock_s3() mock = mock_s3()
mock.start() mock.start()
conn = boto.connect_s3() conn = boto3.resource('s3', region_name='us-east-1')
conn.create_bucket('mybucket') conn.create_bucket(Bucket='mybucket')
model_instance = MyModel('steve', 'is awesome') model_instance = MyModel('steve', 'is awesome')
model_instance.save() model_instance.save()
assert conn.get_bucket('mybucket').get_key('steve').get_contents_as_string() == 'is awesome' body = conn.Object('mybucket', 'steve').get()[
'Body'].read().decode("utf-8")
assert body == 'is awesome'
mock.stop() mock.stop()

View File

@ -76,7 +76,7 @@ Currently implemented Services:
+---------------------------+-----------------------+------------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Logs | @mock_logs | basic endpoints done | | Logs | @mock_logs | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Organizations | @mock_organizations | some core edpoints done | | Organizations | @mock_organizations | some core endpoints done |
+---------------------------+-----------------------+------------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Polly | @mock_polly | all endpoints done | | Polly | @mock_polly | all endpoints done |
+---------------------------+-----------------------+------------------------------------+ +---------------------------+-----------------------+------------------------------------+

View File

@ -39,7 +39,7 @@ class InvalidResourcePathException(BadRequestException):
def __init__(self): def __init__(self):
super(InvalidResourcePathException, self).__init__( super(InvalidResourcePathException, self).__init__(
"BadRequestException", "BadRequestException",
"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end.", "Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace.",
) )

View File

@ -83,14 +83,14 @@ class MethodResponse(BaseModel, dict):
class Method(BaseModel, dict): class Method(BaseModel, dict):
def __init__(self, method_type, authorization_type): def __init__(self, method_type, authorization_type, **kwargs):
super(Method, self).__init__() super(Method, self).__init__()
self.update( self.update(
dict( dict(
httpMethod=method_type, httpMethod=method_type,
authorizationType=authorization_type, authorizationType=authorization_type,
authorizerId=None, authorizerId=None,
apiKeyRequired=None, apiKeyRequired=kwargs.get("api_key_required") or False,
requestParameters=None, requestParameters=None,
requestModels=None, requestModels=None,
methodIntegration=None, methodIntegration=None,
@ -117,14 +117,15 @@ class Resource(BaseModel):
self.api_id = api_id self.api_id = api_id
self.path_part = path_part self.path_part = path_part
self.parent_id = parent_id self.parent_id = parent_id
self.resource_methods = {"GET": {}} self.resource_methods = {}
def to_dict(self): def to_dict(self):
response = { response = {
"path": self.get_path(), "path": self.get_path(),
"id": self.id, "id": self.id,
"resourceMethods": self.resource_methods,
} }
if self.resource_methods:
response["resourceMethods"] = self.resource_methods
if self.parent_id: if self.parent_id:
response["parentId"] = self.parent_id response["parentId"] = self.parent_id
response["pathPart"] = self.path_part response["pathPart"] = self.path_part
@ -158,8 +159,12 @@ class Resource(BaseModel):
) )
return response.status_code, response.text return response.status_code, response.text
def add_method(self, method_type, authorization_type): def add_method(self, method_type, authorization_type, api_key_required):
method = Method(method_type=method_type, authorization_type=authorization_type) method = Method(
method_type=method_type,
authorization_type=authorization_type,
api_key_required=api_key_required,
)
self.resource_methods[method_type] = method self.resource_methods[method_type] = method
return method return method
@ -394,12 +399,17 @@ class UsagePlanKey(BaseModel, dict):
class RestAPI(BaseModel): class RestAPI(BaseModel):
def __init__(self, id, region_name, name, description): def __init__(self, id, region_name, name, description, **kwargs):
self.id = id self.id = id
self.region_name = region_name self.region_name = region_name
self.name = name self.name = name
self.description = description self.description = description
self.create_date = int(time.time()) self.create_date = int(time.time())
self.api_key_source = kwargs.get("api_key_source") or "HEADER"
self.endpoint_configuration = kwargs.get("endpoint_configuration") or {
"types": ["EDGE"]
}
self.tags = kwargs.get("tags") or {}
self.deployments = {} self.deployments = {}
self.stages = {} self.stages = {}
@ -416,6 +426,9 @@ class RestAPI(BaseModel):
"name": self.name, "name": self.name,
"description": self.description, "description": self.description,
"createdDate": int(time.time()), "createdDate": int(time.time()),
"apiKeySource": self.api_key_source,
"endpointConfiguration": self.endpoint_configuration,
"tags": self.tags,
} }
def add_child(self, path, parent_id=None): def add_child(self, path, parent_id=None):
@ -529,9 +542,24 @@ class APIGatewayBackend(BaseBackend):
self.__dict__ = {} self.__dict__ = {}
self.__init__(region_name) self.__init__(region_name)
def create_rest_api(self, name, description): def create_rest_api(
self,
name,
description,
api_key_source=None,
endpoint_configuration=None,
tags=None,
):
api_id = create_id() api_id = create_id()
rest_api = RestAPI(api_id, self.region_name, name, description) rest_api = RestAPI(
api_id,
self.region_name,
name,
description,
api_key_source=api_key_source,
endpoint_configuration=endpoint_configuration,
tags=tags,
)
self.apis[api_id] = rest_api self.apis[api_id] = rest_api
return rest_api return rest_api
@ -556,7 +584,7 @@ class APIGatewayBackend(BaseBackend):
return resource return resource
def create_resource(self, function_id, parent_resource_id, path_part): def create_resource(self, function_id, parent_resource_id, path_part):
if not re.match("^\\{?[a-zA-Z0-9._-]+\\}?$", path_part): if not re.match("^\\{?[a-zA-Z0-9._-]+\\+?\\}?$", path_part):
raise InvalidResourcePathException() raise InvalidResourcePathException()
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
child = api.add_child(path=path_part, parent_id=parent_resource_id) child = api.add_child(path=path_part, parent_id=parent_resource_id)
@ -571,9 +599,18 @@ class APIGatewayBackend(BaseBackend):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
return resource.get_method(method_type) return resource.get_method(method_type)
def create_method(self, function_id, resource_id, method_type, authorization_type): def create_method(
self,
function_id,
resource_id,
method_type,
authorization_type,
api_key_required=None,
):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
method = resource.add_method(method_type, authorization_type) method = resource.add_method(
method_type, authorization_type, api_key_required=api_key_required
)
return method return method
def get_stage(self, function_id, stage_name): def get_stage(self, function_id, stage_name):

View File

@ -12,6 +12,9 @@ from .exceptions import (
ApiKeyAlreadyExists, ApiKeyAlreadyExists,
) )
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"]
class APIGatewayResponse(BaseResponse): class APIGatewayResponse(BaseResponse):
def error(self, type_, message, status=400): def error(self, type_, message, status=400):
@ -45,7 +48,45 @@ class APIGatewayResponse(BaseResponse):
elif self.method == "POST": elif self.method == "POST":
name = self._get_param("name") name = self._get_param("name")
description = self._get_param("description") description = self._get_param("description")
rest_api = self.backend.create_rest_api(name, description) api_key_source = self._get_param("apiKeySource")
endpoint_configuration = self._get_param("endpointConfiguration")
tags = self._get_param("tags")
# Param validation
if api_key_source and api_key_source not in API_KEY_SOURCES:
return self.error(
"ValidationException",
(
"1 validation error detected: "
"Value '{api_key_source}' at 'createRestApiInput.apiKeySource' failed "
"to satisfy constraint: Member must satisfy enum value set: "
"[AUTHORIZER, HEADER]"
).format(api_key_source=api_key_source),
)
if endpoint_configuration and "types" in endpoint_configuration:
invalid_types = list(
set(endpoint_configuration["types"])
- set(ENDPOINT_CONFIGURATION_TYPES)
)
if invalid_types:
return self.error(
"ValidationException",
(
"1 validation error detected: Value '{endpoint_type}' "
"at 'createRestApiInput.endpointConfiguration.types' failed "
"to satisfy constraint: Member must satisfy enum value set: "
"[PRIVATE, EDGE, REGIONAL]"
).format(endpoint_type=invalid_types[0]),
)
rest_api = self.backend.create_rest_api(
name,
description,
api_key_source=api_key_source,
endpoint_configuration=endpoint_configuration,
tags=tags,
)
return 200, {}, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
def restapis_individual(self, request, full_url, headers): def restapis_individual(self, request, full_url, headers):
@ -104,8 +145,13 @@ class APIGatewayResponse(BaseResponse):
return 200, {}, json.dumps(method) return 200, {}, json.dumps(method)
elif self.method == "PUT": elif self.method == "PUT":
authorization_type = self._get_param("authorizationType") authorization_type = self._get_param("authorizationType")
api_key_required = self._get_param("apiKeyRequired")
method = self.backend.create_method( method = self.backend.create_method(
function_id, resource_id, method_type, authorization_type function_id,
resource_id,
method_type,
authorization_type,
api_key_required,
) )
return 200, {}, json.dumps(method) return 200, {}, json.dumps(method)

View File

@ -1,4 +1,5 @@
from botocore.client import ClientError from botocore.client import ClientError
from moto.core.exceptions import JsonRESTError
class LambdaClientError(ClientError): class LambdaClientError(ClientError):
@ -29,3 +30,12 @@ class InvalidRoleFormat(LambdaClientError):
role, InvalidRoleFormat.pattern role, InvalidRoleFormat.pattern
) )
super(InvalidRoleFormat, self).__init__("ValidationException", message) super(InvalidRoleFormat, self).__init__("ValidationException", message)
class PreconditionFailedException(JsonRESTError):
code = 412
def __init__(self, message):
super(PreconditionFailedException, self).__init__(
"PreconditionFailedException", message
)

View File

@ -25,6 +25,7 @@ import requests.adapters
from boto3 import Session from boto3 import Session
from moto.awslambda.policy import Policy
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend from moto.iam.models import iam_backend
@ -47,15 +48,11 @@ from moto.core import ACCOUNT_ID
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
except ImportError: except ImportError:
from backports.tempfile import TemporaryDirectory from backports.tempfile import TemporaryDirectory
# The lambci container is returning a special escape character for the "RequestID" fields. Unicode 033:
# _stderr_regex = re.compile(r"START|END|REPORT RequestId: .*")
_stderr_regex = re.compile(r"\033\[\d+.*")
_orig_adapter_send = requests.adapters.HTTPAdapter.send _orig_adapter_send = requests.adapters.HTTPAdapter.send
docker_3 = docker.__version__[0] >= "3" docker_3 = docker.__version__[0] >= "3"
@ -164,7 +161,8 @@ class LambdaFunction(BaseModel):
self.logs_backend = logs_backends[self.region] self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get("Environment", {}).get("Variables", {}) self.environment_vars = spec.get("Environment", {}).get("Variables", {})
self.docker_client = docker.from_env() self.docker_client = docker.from_env()
self.policy = "" self.policy = None
self.state = "Active"
# Unfortunately mocking replaces this method w/o fallback enabled, so we # Unfortunately mocking replaces this method w/o fallback enabled, so we
# need to replace it if we detect it's been mocked # need to replace it if we detect it's been mocked
@ -274,11 +272,11 @@ class LambdaFunction(BaseModel):
"MemorySize": self.memory_size, "MemorySize": self.memory_size,
"Role": self.role, "Role": self.role,
"Runtime": self.run_time, "Runtime": self.run_time,
"State": self.state,
"Timeout": self.timeout, "Timeout": self.timeout,
"Version": str(self.version), "Version": str(self.version),
"VpcConfig": self.vpc_config, "VpcConfig": self.vpc_config,
} }
if self.environment_vars: if self.environment_vars:
config["Environment"] = {"Variables": self.environment_vars} config["Environment"] = {"Variables": self.environment_vars}
@ -385,7 +383,7 @@ class LambdaFunction(BaseModel):
try: try:
# TODO: I believe we can keep the container running and feed events as needed # TODO: I believe we can keep the container running and feed events as needed
# also need to hook it up to the other services so it can make kws/s3 etc calls # also need to hook it up to the other services so it can make kws/s3 etc calls
# Should get invoke_id /RequestId from invovation # Should get invoke_id /RequestId from invocation
env_vars = { env_vars = {
"AWS_LAMBDA_FUNCTION_TIMEOUT": self.timeout, "AWS_LAMBDA_FUNCTION_TIMEOUT": self.timeout,
"AWS_LAMBDA_FUNCTION_NAME": self.function_name, "AWS_LAMBDA_FUNCTION_NAME": self.function_name,
@ -397,6 +395,7 @@ class LambdaFunction(BaseModel):
env_vars.update(self.environment_vars) env_vars.update(self.environment_vars)
container = output = exit_code = None container = output = exit_code = None
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
with _DockerDataVolumeContext(self) as data_vol: with _DockerDataVolumeContext(self) as data_vol:
try: try:
run_kwargs = ( run_kwargs = (
@ -412,6 +411,7 @@ class LambdaFunction(BaseModel):
volumes=["{}:/var/task".format(data_vol.name)], volumes=["{}:/var/task".format(data_vol.name)],
environment=env_vars, environment=env_vars,
detach=True, detach=True,
log_config=log_config,
**run_kwargs **run_kwargs
) )
finally: finally:
@ -453,14 +453,9 @@ class LambdaFunction(BaseModel):
if exit_code != 0: if exit_code != 0:
raise Exception("lambda invoke failed output: {}".format(output)) raise Exception("lambda invoke failed output: {}".format(output))
# strip out RequestId lines (TODO: This will return an additional '\n' in the response) # We only care about the response from the lambda
output = os.linesep.join( # Which is the last line of the output, according to https://github.com/lambci/docker-lambda/issues/25
[ output = output.splitlines()[-1]
line
for line in self.convert(output).splitlines()
if not _stderr_regex.match(line)
]
)
return output, False return output, False
except BaseException as e: except BaseException as e:
traceback.print_exc() traceback.print_exc()
@ -480,7 +475,7 @@ class LambdaFunction(BaseModel):
payload["result"] = response_headers["x-amz-log-result"] payload["result"] = response_headers["x-amz-log-result"]
result = res.encode("utf-8") result = res.encode("utf-8")
else: else:
result = json.dumps(payload) result = res
if errored: if errored:
response_headers["x-amz-function-error"] = "Handled" response_headers["x-amz-function-error"] = "Handled"
@ -709,7 +704,8 @@ class LambdaStorage(object):
"versions": [], "versions": [],
"alias": weakref.WeakValueDictionary(), "alias": weakref.WeakValueDictionary(),
} }
# instantiate a new policy for this version of the lambda
fn.policy = Policy(fn)
self._arns[fn.function_arn] = fn self._arns[fn.function_arn] = fn
def publish_function(self, name): def publish_function(self, name):
@ -1010,8 +1006,21 @@ class LambdaBackend(BaseBackend):
return True return True
return False return False
def add_policy(self, function_name, policy): def add_policy_statement(self, function_name, raw):
self.get_function(function_name).policy = policy fn = self.get_function(function_name)
fn.policy.add_statement(raw)
def del_policy_statement(self, function_name, sid, revision=""):
fn = self.get_function(function_name)
fn.policy.del_statement(sid, revision)
def get_policy(self, function_name):
fn = self.get_function(function_name)
return fn.policy.get_policy()
def get_policy_wire_format(self, function_name):
fn = self.get_function(function_name)
return fn.policy.wire_format()
def update_function_code(self, function_name, qualifier, body): def update_function_code(self, function_name, qualifier, body):
fn = self.get_function(function_name, qualifier) fn = self.get_function(function_name, qualifier)

134
moto/awslambda/policy.py Normal file
View File

@ -0,0 +1,134 @@
from __future__ import unicode_literals
import json
import uuid
from six import string_types
from moto.awslambda.exceptions import PreconditionFailedException
class Policy:
def __init__(self, parent):
self.revision = str(uuid.uuid4())
self.statements = []
self.parent = parent
def wire_format(self):
p = self.get_policy()
p["Policy"] = json.dumps(p["Policy"])
return json.dumps(p)
def get_policy(self):
return {
"Policy": {
"Version": "2012-10-17",
"Id": "default",
"Statement": self.statements,
},
"RevisionId": self.revision,
}
# adds the raw JSON statement to the policy
def add_statement(self, raw):
policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
self.statements.append(policy.statements[0])
self.revision = str(uuid.uuid4())
# removes the statement that matches 'sid' from the policy
def del_statement(self, sid, revision=""):
if len(revision) > 0 and self.revision != revision:
raise PreconditionFailedException(
"The RevisionId provided does not match the latest RevisionId"
" for the Lambda function or alias. Call the GetFunction or the GetAlias API to retrieve"
" the latest RevisionId for your resource."
)
for statement in self.statements:
if "Sid" in statement and statement["Sid"] == sid:
self.statements.remove(statement)
# converts AddPermission request to PolicyStatement
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
def decode_policy(self, obj):
# import pydevd
# pydevd.settrace("localhost", port=5678)
policy = Policy(self.parent)
policy.revision = obj.get("RevisionId", "")
# set some default values if these keys are not set
self.ensure_set(obj, "Effect", "Allow")
self.ensure_set(obj, "Resource", self.parent.function_arn + ":$LATEST")
self.ensure_set(obj, "StatementId", str(uuid.uuid4()))
# transform field names and values
self.transform_property(obj, "StatementId", "Sid", self.nop_formatter)
self.transform_property(obj, "Principal", "Principal", self.principal_formatter)
self.transform_property(
obj, "SourceArn", "SourceArn", self.source_arn_formatter
)
self.transform_property(
obj, "SourceAccount", "SourceAccount", self.source_account_formatter
)
# remove RevisionId and EventSourceToken if they are set
self.remove_if_set(obj, ["RevisionId", "EventSourceToken"])
# merge conditional statements into a single map under the Condition key
self.condition_merge(obj)
# append resulting statement to policy.statements
policy.statements.append(obj)
return policy
def nop_formatter(self, obj):
return obj
def ensure_set(self, obj, key, value):
if key not in obj:
obj[key] = value
def principal_formatter(self, obj):
if isinstance(obj, string_types):
if obj.endswith(".amazonaws.com"):
return {"Service": obj}
if obj.endswith(":root"):
return {"AWS": obj}
return obj
def source_account_formatter(self, obj):
return {"StringEquals": {"AWS:SourceAccount": obj}}
def source_arn_formatter(self, obj):
return {"ArnLike": {"AWS:SourceArn": obj}}
def transform_property(self, obj, old_name, new_name, formatter):
if old_name in obj:
obj[new_name] = formatter(obj[old_name])
if new_name != old_name:
del obj[old_name]
def remove_if_set(self, obj, keys):
for key in keys:
if key in obj:
del obj[key]
def condition_merge(self, obj):
if "SourceArn" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceArn"])
del obj["SourceArn"]
if "SourceAccount" in obj:
if "Condition" not in obj:
obj["Condition"] = {}
obj["Condition"].update(obj["SourceAccount"])
del obj["SourceAccount"]

View File

@ -120,8 +120,12 @@ class LambdaResponse(BaseResponse):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == "GET": if request.method == "GET":
return self._get_policy(request, full_url, headers) return self._get_policy(request, full_url, headers)
if request.method == "POST": elif request.method == "POST":
return self._add_policy(request, full_url, headers) return self._add_policy(request, full_url, headers)
elif request.method == "DELETE":
return self._del_policy(request, full_url, headers, self.querystring)
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def configuration(self, request, full_url, headers): def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -141,9 +145,9 @@ class LambdaResponse(BaseResponse):
path = request.path if hasattr(request, "path") else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
policy = self.body statement = self.body
self.lambda_backend.add_policy(function_name, policy) self.lambda_backend.add_policy_statement(function_name, statement)
return 200, {}, json.dumps(dict(Statement=policy)) return 200, {}, json.dumps({"Statement": statement})
else: else:
return 404, {}, "{}" return 404, {}, "{}"
@ -151,28 +155,42 @@ class LambdaResponse(BaseResponse):
path = request.path if hasattr(request, "path") else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
lambda_function = self.lambda_backend.get_function(function_name) out = self.lambda_backend.get_policy_wire_format(function_name)
return ( return 200, {}, out
200, else:
{}, return 404, {}, "{}"
json.dumps(
dict(Policy='{"Statement":[' + lambda_function.policy + "]}") def _del_policy(self, request, full_url, headers, querystring):
), path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-3]
statement_id = path.split("/")[-1].split("?")[0]
revision = querystring.get("RevisionId", "")
if self.lambda_backend.get_function(function_name):
self.lambda_backend.del_policy_statement(
function_name, statement_id, revision
) )
return 204, {}, "{}"
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _invoke(self, request, full_url): def _invoke(self, request, full_url):
response_headers = {} response_headers = {}
function_name = self.path.rsplit("/", 2)[-2] # URL Decode in case it's a ARN:
function_name = unquote(self.path.rsplit("/", 2)[-2])
qualifier = self._get_param("qualifier") qualifier = self._get_param("qualifier")
response_header, payload = self.lambda_backend.invoke( response_header, payload = self.lambda_backend.invoke(
function_name, qualifier, self.body, self.headers, response_headers function_name, qualifier, self.body, self.headers, response_headers
) )
if payload: if payload:
return 202, response_headers, payload if request.headers["X-Amz-Invocation-Type"] == "Event":
status_code = 202
elif request.headers["X-Amz-Invocation-Type"] == "DryRun":
status_code = 204
else:
status_code = 200
return status_code, response_headers, payload
else: else:
return 404, response_headers, "{}" return 404, response_headers, "{}"
@ -283,7 +301,7 @@ class LambdaResponse(BaseResponse):
code["Configuration"]["FunctionArn"] += ":$LATEST" code["Configuration"]["FunctionArn"] += ":$LATEST"
return 200, {}, json.dumps(code) return 200, {}, json.dumps(code)
else: else:
return 404, {}, "{}" return 404, {"x-amzn-ErrorType": "ResourceNotFoundException"}, "{}"
def _get_aws_region(self, full_url): def _get_aws_region(self, full_url):
region = self.region_regex.search(full_url) region = self.region_regex.search(full_url)

View File

@ -6,14 +6,16 @@ url_bases = ["https?://lambda.(.+).amazonaws.com"]
response = LambdaResponse() response = LambdaResponse()
url_paths = { url_paths = {
"{0}/(?P<api_version>[^/]+)/functions/?$": response.root, r"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings, r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$": response.event_source_mapping, r"{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$": response.event_source_mapping,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<resource_arn>.+)/invocations/?$": response.invoke,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async,
r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag, r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/(?P<statement_id>[\w_-]+)$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,

View File

@ -677,6 +677,8 @@ class CloudFormationBackend(BaseBackend):
def list_stack_resources(self, stack_name_or_id): def list_stack_resources(self, stack_name_or_id):
stack = self.get_stack(stack_name_or_id) stack = self.get_stack(stack_name_or_id)
if stack is None:
return None
return stack.stack_resources return stack.stack_resources
def delete_stack(self, name_or_stack_id): def delete_stack(self, name_or_stack_id):

View File

@ -229,6 +229,9 @@ class CloudFormationResponse(BaseResponse):
stack_name_or_id = self._get_param("StackName") stack_name_or_id = self._get_param("StackName")
resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id) resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id)
if resources is None:
raise ValidationError(stack_name_or_id)
template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE) template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE)
return template.render(resources=resources) return template.render(resources=resources)

View File

@ -14,6 +14,7 @@ from jose import jws
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
from .exceptions import ( from .exceptions import (
GroupExistsException, GroupExistsException,
NotAuthorizedError, NotAuthorizedError,
@ -69,6 +70,9 @@ class CognitoIdpUserPool(BaseModel):
def __init__(self, region, name, extended_config): def __init__(self, region, name, extended_config):
self.region = region self.region = region
self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex)) self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
self.arn = "arn:aws:cognito-idp:{}:{}:userpool/{}".format(
self.region, DEFAULT_ACCOUNT_ID, self.id
)
self.name = name self.name = name
self.status = None self.status = None
self.extended_config = extended_config or {} self.extended_config = extended_config or {}
@ -91,6 +95,7 @@ class CognitoIdpUserPool(BaseModel):
def _base_json(self): def _base_json(self):
return { return {
"Id": self.id, "Id": self.id,
"Arn": self.arn,
"Name": self.name, "Name": self.name,
"Status": self.status, "Status": self.status,
"CreationDate": time.mktime(self.creation_date.timetuple()), "CreationDate": time.mktime(self.creation_date.timetuple()),
@ -108,7 +113,9 @@ class CognitoIdpUserPool(BaseModel):
return user_pool_json return user_pool_json
def create_jwt(self, client_id, username, expires_in=60 * 60, extra_data={}): def create_jwt(
self, client_id, username, token_use, expires_in=60 * 60, extra_data={}
):
now = int(time.time()) now = int(time.time())
payload = { payload = {
"iss": "https://cognito-idp.{}.amazonaws.com/{}".format( "iss": "https://cognito-idp.{}.amazonaws.com/{}".format(
@ -116,7 +123,7 @@ class CognitoIdpUserPool(BaseModel):
), ),
"sub": self.users[username].id, "sub": self.users[username].id,
"aud": client_id, "aud": client_id,
"token_use": "id", "token_use": token_use,
"auth_time": now, "auth_time": now,
"exp": now + expires_in, "exp": now + expires_in,
} }
@ -125,7 +132,10 @@ class CognitoIdpUserPool(BaseModel):
return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in
def create_id_token(self, client_id, username): def create_id_token(self, client_id, username):
id_token, expires_in = self.create_jwt(client_id, username) extra_data = self.get_user_extra_data_by_client_id(client_id, username)
id_token, expires_in = self.create_jwt(
client_id, username, "id", extra_data=extra_data
)
self.id_tokens[id_token] = (client_id, username) self.id_tokens[id_token] = (client_id, username)
return id_token, expires_in return id_token, expires_in
@ -135,10 +145,7 @@ class CognitoIdpUserPool(BaseModel):
return refresh_token return refresh_token
def create_access_token(self, client_id, username): def create_access_token(self, client_id, username):
extra_data = self.get_user_extra_data_by_client_id(client_id, username) access_token, expires_in = self.create_jwt(client_id, username, "access")
access_token, expires_in = self.create_jwt(
client_id, username, extra_data=extra_data
)
self.access_tokens[access_token] = (client_id, username) self.access_tokens[access_token] = (client_id, username)
return access_token, expires_in return access_token, expires_in
@ -562,12 +569,17 @@ class CognitoIdpBackend(BaseBackend):
user.groups.discard(group) user.groups.discard(group)
# User # User
def admin_create_user(self, user_pool_id, username, temporary_password, attributes): def admin_create_user(
self, user_pool_id, username, message_action, temporary_password, attributes
):
user_pool = self.user_pools.get(user_pool_id) user_pool = self.user_pools.get(user_pool_id)
if not user_pool: if not user_pool:
raise ResourceNotFoundError(user_pool_id) raise ResourceNotFoundError(user_pool_id)
if username in user_pool.users: if message_action and message_action == "RESEND":
if username not in user_pool.users:
raise UserNotFoundError(username)
elif username in user_pool.users:
raise UsernameExistsException(username) raise UsernameExistsException(username)
user = CognitoIdpUser( user = CognitoIdpUser(

View File

@ -259,10 +259,12 @@ class CognitoIdpResponse(BaseResponse):
def admin_create_user(self): def admin_create_user(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
message_action = self._get_param("MessageAction")
temporary_password = self._get_param("TemporaryPassword") temporary_password = self._get_param("TemporaryPassword")
user = cognitoidp_backends[self.region].admin_create_user( user = cognitoidp_backends[self.region].admin_create_user(
user_pool_id, user_pool_id,
username, username,
message_action,
temporary_password, temporary_password,
self._get_param("UserAttributes", []), self._get_param("UserAttributes", []),
) )
@ -279,9 +281,18 @@ class CognitoIdpResponse(BaseResponse):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
limit = self._get_param("Limit") limit = self._get_param("Limit")
token = self._get_param("PaginationToken") token = self._get_param("PaginationToken")
filt = self._get_param("Filter")
users, token = cognitoidp_backends[self.region].list_users( users, token = cognitoidp_backends[self.region].list_users(
user_pool_id, limit=limit, pagination_token=token user_pool_id, limit=limit, pagination_token=token
) )
if filt:
name, value = filt.replace('"', "").split("=")
users = [
user
for user in users
for attribute in user.attributes
if attribute["Name"] == name and attribute["Value"] == value
]
response = {"Users": [user.to_json(extended=True) for user in users]} response = {"Users": [user.to_json(extended=True) for user in users]}
if token: if token:
response["PaginationToken"] = str(token) response["PaginationToken"] = str(token)

View File

@ -43,7 +43,7 @@ from moto.config.exceptions import (
) )
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.s3.config import s3_config_query from moto.s3.config import s3_account_public_access_block_query, s3_config_query
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
@ -58,7 +58,10 @@ POP_STRINGS = [
DEFAULT_PAGE_SIZE = 100 DEFAULT_PAGE_SIZE = 100
# Map the Config resource type to a backend: # Map the Config resource type to a backend:
RESOURCE_MAP = {"AWS::S3::Bucket": s3_config_query} RESOURCE_MAP = {
"AWS::S3::Bucket": s3_config_query,
"AWS::S3::AccountPublicAccessBlock": s3_account_public_access_block_query,
}
def datetime2int(date): def datetime2int(date):
@ -867,16 +870,17 @@ class ConfigBackend(BaseBackend):
backend_region=backend_query_region, backend_region=backend_query_region,
) )
result = { resource_identifiers = []
"resourceIdentifiers": [ for identifier in identifiers:
{ item = {"resourceType": identifier["type"], "resourceId": identifier["id"]}
"resourceType": identifier["type"],
"resourceId": identifier["id"], # Some resource types lack names:
"resourceName": identifier["name"], if identifier.get("name"):
} item["resourceName"] = identifier["name"]
for identifier in identifiers
] resource_identifiers.append(item)
}
result = {"resourceIdentifiers": resource_identifiers}
if new_token: if new_token:
result["nextToken"] = new_token result["nextToken"] = new_token
@ -927,18 +931,21 @@ class ConfigBackend(BaseBackend):
resource_region=resource_region, resource_region=resource_region,
) )
result = { resource_identifiers = []
"ResourceIdentifiers": [ for identifier in identifiers:
{ item = {
"SourceAccountId": DEFAULT_ACCOUNT_ID, "SourceAccountId": DEFAULT_ACCOUNT_ID,
"SourceRegion": identifier["region"], "SourceRegion": identifier["region"],
"ResourceType": identifier["type"], "ResourceType": identifier["type"],
"ResourceId": identifier["id"], "ResourceId": identifier["id"],
"ResourceName": identifier["name"], }
}
for identifier in identifiers if identifier.get("name"):
] item["ResourceName"] = identifier["name"]
}
resource_identifiers.append(item)
result = {"ResourceIdentifiers": resource_identifiers}
if new_token: if new_token:
result["NextToken"] = new_token result["NextToken"] = new_token

View File

@ -606,12 +606,13 @@ class ConfigQueryModel(object):
As such, the proper way to implement is to first obtain a full list of results from all the region backends, and then filter As such, the proper way to implement is to first obtain a full list of results from all the region backends, and then filter
from there. It may be valuable to make this a concatenation of the region and resource name. from there. It may be valuable to make this a concatenation of the region and resource name.
:param resource_region: :param resource_ids: A list of resource IDs
:param resource_ids: :param resource_name: The individual name of a resource
:param resource_name: :param limit: How many per page
:param limit: :param next_token: The item that will page on
:param next_token:
:param backend_region: The region for the backend to pull results from. Set to `None` if this is an aggregated query. :param backend_region: The region for the backend to pull results from. Set to `None` if this is an aggregated query.
:param resource_region: The region for where the resources reside to pull results from. Set to `None` if this is a
non-aggregated query.
:return: This should return a list of Dicts that have the following fields: :return: This should return a list of Dicts that have the following fields:
[ [
{ {

View File

@ -977,10 +977,8 @@ class OpLessThan(Op):
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially # In python3 None is not a valid comparator when using < or > so must be handled specially
if lhs and rhs: if lhs is not None and rhs is not None:
return lhs < rhs return lhs < rhs
elif lhs is None and rhs:
return True
else: else:
return False return False
@ -992,10 +990,8 @@ class OpGreaterThan(Op):
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially # In python3 None is not a valid comparator when using < or > so must be handled specially
if lhs and rhs: if lhs is not None and rhs is not None:
return lhs > rhs return lhs > rhs
elif lhs and rhs is None:
return True
else: else:
return False return False
@ -1025,10 +1021,8 @@ class OpLessThanOrEqual(Op):
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially # In python3 None is not a valid comparator when using < or > so must be handled specially
if lhs and rhs: if lhs is not None and rhs is not None:
return lhs <= rhs return lhs <= rhs
elif lhs is None and rhs or lhs is None and rhs is None:
return True
else: else:
return False return False
@ -1040,10 +1034,8 @@ class OpGreaterThanOrEqual(Op):
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) rhs = self.rhs.expr(item)
# In python3 None is not a valid comparator when using < or > so must be handled specially # In python3 None is not a valid comparator when using < or > so must be handled specially
if lhs and rhs: if lhs is not None and rhs is not None:
return lhs >= rhs return lhs >= rhs
elif lhs and rhs is None or lhs is None and rhs is None:
return True
else: else:
return False return False

View File

@ -448,16 +448,21 @@ class Item(BaseModel):
if list_append_re: if list_append_re:
new_value = expression_attribute_values[list_append_re.group(2).strip()] new_value = expression_attribute_values[list_append_re.group(2).strip()]
old_list_key = list_append_re.group(1) old_list_key = list_append_re.group(1)
# Get the existing value # old_key could be a function itself (if_not_exists)
old_list = self.attrs[old_list_key.split(".")[0]] if old_list_key.startswith("if_not_exists"):
if "." in old_list_key: old_list = DynamoType(
# Value is nested inside a map - find the appropriate child attr expression_attribute_values[self._get_default(old_list_key)]
old_list = old_list.child_attr(
".".join(old_list_key.split(".")[1:])
) )
else:
old_list = self.attrs[old_list_key.split(".")[0]]
if "." in old_list_key:
# Value is nested inside a map - find the appropriate child attr
old_list = old_list.child_attr(
".".join(old_list_key.split(".")[1:])
)
if not old_list.is_list(): if not old_list.is_list():
raise ParamValidationError raise ParamValidationError
old_list.value.extend(new_value["L"]) old_list.value.extend([DynamoType(v) for v in new_value["L"]])
value = old_list value = old_list
return value return value

View File

@ -508,6 +508,13 @@ class DynamoHandler(BaseResponse):
# 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}} # 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}}
key_conditions = self.body.get("KeyConditions") key_conditions = self.body.get("KeyConditions")
query_filters = self.body.get("QueryFilter") query_filters = self.body.get("QueryFilter")
if not (key_conditions or query_filters):
return self.error(
"com.amazonaws.dynamodb.v20111205#ValidationException",
"Either KeyConditions or QueryFilter should be present",
)
if key_conditions: if key_conditions:
( (
hash_key_name, hash_key_name,

View File

@ -27,6 +27,7 @@ from moto.core.utils import (
iso_8601_datetime_with_milliseconds, iso_8601_datetime_with_milliseconds,
camelcase_to_underscores, camelcase_to_underscores,
) )
from moto.iam.models import ACCOUNT_ID
from .exceptions import ( from .exceptions import (
CidrLimitExceeded, CidrLimitExceeded,
DependencyViolationError, DependencyViolationError,
@ -139,18 +140,23 @@ from .utils import (
rsa_public_key_fingerprint, rsa_public_key_fingerprint,
) )
INSTANCE_TYPES = json.load(
open(resource_filename(__name__, "resources/instance_types.json"), "r") def _load_resource(filename):
) with open(filename, "r") as f:
AMIS = json.load( return json.load(f)
open(
os.environ.get("MOTO_AMIS_PATH")
or resource_filename(__name__, "resources/amis.json"), INSTANCE_TYPES = _load_resource(
"r", resource_filename(__name__, "resources/instance_types.json")
)
) )
OWNER_ID = "111122223333" AMIS = _load_resource(
os.environ.get("MOTO_AMIS_PATH")
or resource_filename(__name__, "resources/amis.json"),
)
OWNER_ID = ACCOUNT_ID
def utc_date_and_time(): def utc_date_and_time():
@ -1336,7 +1342,7 @@ class AmiBackend(object):
source_ami=None, source_ami=None,
name=name, name=name,
description=description, description=description,
owner_id=context.get_current_user() if context else OWNER_ID, owner_id=OWNER_ID,
) )
self.amis[ami_id] = ami self.amis[ami_id] = ami
return ami return ami
@ -1387,14 +1393,7 @@ class AmiBackend(object):
# Limit by owner ids # Limit by owner ids
if owners: if owners:
# support filtering by Owners=['self'] # support filtering by Owners=['self']
owners = list( owners = list(map(lambda o: OWNER_ID if o == "self" else o, owners,))
map(
lambda o: context.get_current_user()
if context and o == "self"
else o,
owners,
)
)
images = [ami for ami in images if ami.owner_id in owners] images = [ami for ami in images if ami.owner_id in owners]
# Generic filters # Generic filters

View File

@ -104,7 +104,7 @@ class SecurityGroups(BaseResponse):
if self.is_not_dryrun("GrantSecurityGroupIngress"): if self.is_not_dryrun("GrantSecurityGroupIngress"):
for args in self._process_rules_from_querystring(): for args in self._process_rules_from_querystring():
self.ec2_backend.authorize_security_group_ingress(*args) self.ec2_backend.authorize_security_group_ingress(*args)
return AUTHORIZE_SECURITY_GROUP_INGRESS_REPONSE return AUTHORIZE_SECURITY_GROUP_INGRESS_RESPONSE
def create_security_group(self): def create_security_group(self):
name = self._get_param("GroupName") name = self._get_param("GroupName")
@ -158,7 +158,7 @@ class SecurityGroups(BaseResponse):
if self.is_not_dryrun("RevokeSecurityGroupIngress"): if self.is_not_dryrun("RevokeSecurityGroupIngress"):
for args in self._process_rules_from_querystring(): for args in self._process_rules_from_querystring():
self.ec2_backend.revoke_security_group_ingress(*args) self.ec2_backend.revoke_security_group_ingress(*args)
return REVOKE_SECURITY_GROUP_INGRESS_REPONSE return REVOKE_SECURITY_GROUP_INGRESS_RESPONSE
CREATE_SECURITY_GROUP_RESPONSE = """<CreateSecurityGroupResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> CREATE_SECURITY_GROUP_RESPONSE = """<CreateSecurityGroupResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
@ -265,12 +265,12 @@ DESCRIBE_SECURITY_GROUPS_RESPONSE = (
</DescribeSecurityGroupsResponse>""" </DescribeSecurityGroupsResponse>"""
) )
AUTHORIZE_SECURITY_GROUP_INGRESS_REPONSE = """<AuthorizeSecurityGroupIngressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> AUTHORIZE_SECURITY_GROUP_INGRESS_RESPONSE = """<AuthorizeSecurityGroupIngressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId> <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<return>true</return> <return>true</return>
</AuthorizeSecurityGroupIngressResponse>""" </AuthorizeSecurityGroupIngressResponse>"""
REVOKE_SECURITY_GROUP_INGRESS_REPONSE = """<RevokeSecurityGroupIngressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> REVOKE_SECURITY_GROUP_INGRESS_RESPONSE = """<RevokeSecurityGroupIngressResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId> <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<return>true</return> <return>true</return>
</RevokeSecurityGroupIngressResponse>""" </RevokeSecurityGroupIngressResponse>"""

View File

@ -118,6 +118,7 @@ class TaskDefinition(BaseObject):
revision, revision,
container_definitions, container_definitions,
region_name, region_name,
network_mode=None,
volumes=None, volumes=None,
tags=None, tags=None,
): ):
@ -132,6 +133,10 @@ class TaskDefinition(BaseObject):
self.volumes = [] self.volumes = []
else: else:
self.volumes = volumes self.volumes = volumes
if network_mode is None:
self.network_mode = "bridge"
else:
self.network_mode = network_mode
@property @property
def response_object(self): def response_object(self):
@ -553,7 +558,7 @@ class EC2ContainerServiceBackend(BaseBackend):
raise Exception("{0} is not a cluster".format(cluster_name)) raise Exception("{0} is not a cluster".format(cluster_name))
def register_task_definition( def register_task_definition(
self, family, container_definitions, volumes, tags=None self, family, container_definitions, volumes=None, network_mode=None, tags=None
): ):
if family in self.task_definitions: if family in self.task_definitions:
last_id = self._get_last_task_definition_revision_id(family) last_id = self._get_last_task_definition_revision_id(family)
@ -562,7 +567,13 @@ class EC2ContainerServiceBackend(BaseBackend):
self.task_definitions[family] = {} self.task_definitions[family] = {}
revision = 1 revision = 1
task_definition = TaskDefinition( task_definition = TaskDefinition(
family, revision, container_definitions, self.region_name, volumes, tags family,
revision,
container_definitions,
self.region_name,
volumes=volumes,
network_mode=network_mode,
tags=tags,
) )
self.task_definitions[family][revision] = task_definition self.task_definitions[family][revision] = task_definition

View File

@ -62,8 +62,13 @@ class EC2ContainerServiceResponse(BaseResponse):
container_definitions = self._get_param("containerDefinitions") container_definitions = self._get_param("containerDefinitions")
volumes = self._get_param("volumes") volumes = self._get_param("volumes")
tags = self._get_param("tags") tags = self._get_param("tags")
network_mode = self._get_param("networkMode")
task_definition = self.ecs_backend.register_task_definition( task_definition = self.ecs_backend.register_task_definition(
family, container_definitions, volumes, tags family,
container_definitions,
volumes=volumes,
network_mode=network_mode,
tags=tags,
) )
return json.dumps({"taskDefinition": task_definition.response_object}) return json.dumps({"taskDefinition": task_definition.response_object})

View File

@ -143,6 +143,9 @@ class EventsBackend(BaseBackend):
def delete_rule(self, name): def delete_rule(self, name):
self.rules_order.pop(self.rules_order.index(name)) self.rules_order.pop(self.rules_order.index(name))
arn = self.rules.get(name).arn
if self.tagger.has_tags(arn):
self.tagger.delete_all_tags_for_resource(arn)
return self.rules.pop(name) is not None return self.rules.pop(name) is not None
def describe_rule(self, name): def describe_rule(self, name):
@ -362,32 +365,41 @@ class EventsBackend(BaseBackend):
) )
self.event_buses.pop(name, None) self.event_buses.pop(name, None)
def list_tags_for_resource(self, arn): def list_tags_for_resource(self, arn):
name = arn.split('/')[-1] name = arn.split('/')[-1]
if name in self.rules: if name in self.rules:
return self.tagger.list_tags_for_resource(self.rules[name].arn) return self.tagger.list_tags_for_resource(self.rules[name].arn)
raise JsonRESTError( raise JsonRESTError(
"ResourceNotFoundException", "An entity that you specified does not exist." "ResourceNotFoundException", "An entity that you specified does not exist."
) )
def list_tags_for_resource(self, arn):
name = arn.split("/")[-1]
if name in self.rules:
return self.tagger.list_tags_for_resource(self.rules[name].arn)
raise JsonRESTError(
"ResourceNotFoundException", "An entity that you specified does not exist."
)
def tag_resource(self, arn, tags): def tag_resource(self, arn, tags):
name = arn.split('/')[-1] name = arn.split("/")[-1]
if name in self.rules: if name in self.rules:
self.tagger.tag_resource(self.rules[name].arn, tags) self.tagger.tag_resource(self.rules[name].arn, tags)
return {} return {}
raise JsonRESTError( raise JsonRESTError(
"ResourceNotFoundException", "An entity that you specified does not exist." "ResourceNotFoundException", "An entity that you specified does not exist."
) )
def untag_resource(self, arn, tag_names): def untag_resource(self, arn, tag_names):
name = arn.split('/')[-1] name = arn.split("/")[-1]
if name in self.rules: if name in self.rules:
self.tagger.untag_resource_using_names(self.rules[name].arn, tag_names) self.tagger.untag_resource_using_names(self.rules[name].arn, tag_names)
return {} return {}
raise JsonRESTError( raise JsonRESTError(
"ResourceNotFoundException", "An entity that you specified does not exist." "ResourceNotFoundException", "An entity that you specified does not exist."
) )
events_backends = {} events_backends = {}
for region in Session().get_available_regions("events"): for region in Session().get_available_regions("events"):

View File

@ -563,6 +563,10 @@ class IamResponse(BaseResponse):
def create_access_key(self): def create_access_key(self):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
if not user_name:
access_key_id = self.get_current_user()
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
key = iam_backend.create_access_key(user_name) key = iam_backend.create_access_key(user_name)
template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE) template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE)
@ -572,6 +576,10 @@ class IamResponse(BaseResponse):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
access_key_id = self._get_param("AccessKeyId") access_key_id = self._get_param("AccessKeyId")
status = self._get_param("Status") status = self._get_param("Status")
if not user_name:
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
iam_backend.update_access_key(user_name, access_key_id, status) iam_backend.update_access_key(user_name, access_key_id, status)
template = self.response_template(GENERIC_EMPTY_TEMPLATE) template = self.response_template(GENERIC_EMPTY_TEMPLATE)
return template.render(name="UpdateAccessKey") return template.render(name="UpdateAccessKey")
@ -587,6 +595,11 @@ class IamResponse(BaseResponse):
def list_access_keys(self): def list_access_keys(self):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
if not user_name:
access_key_id = self.get_current_user()
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
keys = iam_backend.get_all_access_keys(user_name) keys = iam_backend.get_all_access_keys(user_name)
template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE) template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE)
return template.render(user_name=user_name, keys=keys) return template.render(user_name=user_name, keys=keys)
@ -594,6 +607,9 @@ class IamResponse(BaseResponse):
def delete_access_key(self): def delete_access_key(self):
user_name = self._get_param("UserName") user_name = self._get_param("UserName")
access_key_id = self._get_param("AccessKeyId") access_key_id = self._get_param("AccessKeyId")
if not user_name:
access_key = iam_backend.get_access_key_last_used(access_key_id)
user_name = access_key["user_name"]
iam_backend.delete_access_key(access_key_id, user_name) iam_backend.delete_access_key(access_key_id, user_name)
template = self.response_template(GENERIC_EMPTY_TEMPLATE) template = self.response_template(GENERIC_EMPTY_TEMPLATE)

View File

@ -22,6 +22,15 @@ class InvalidRequestException(IoTClientError):
) )
class InvalidStateTransitionException(IoTClientError):
def __init__(self, msg=None):
self.code = 409
super(InvalidStateTransitionException, self).__init__(
"InvalidStateTransitionException",
msg or "An attempt was made to change to an invalid state.",
)
class VersionConflictException(IoTClientError): class VersionConflictException(IoTClientError):
def __init__(self, name): def __init__(self, name):
self.code = 409 self.code = 409

View File

@ -17,6 +17,7 @@ from .exceptions import (
DeleteConflictException, DeleteConflictException,
ResourceNotFoundException, ResourceNotFoundException,
InvalidRequestException, InvalidRequestException,
InvalidStateTransitionException,
VersionConflictException, VersionConflictException,
) )
@ -29,7 +30,7 @@ class FakeThing(BaseModel):
self.attributes = attributes self.attributes = attributes
self.arn = "arn:aws:iot:%s:1:thing/%s" % (self.region_name, thing_name) self.arn = "arn:aws:iot:%s:1:thing/%s" % (self.region_name, thing_name)
self.version = 1 self.version = 1
# TODO: we need to handle 'version'? # TODO: we need to handle "version"?
# for iot-data # for iot-data
self.thing_shadow = None self.thing_shadow = None
@ -174,18 +175,19 @@ class FakeCertificate(BaseModel):
class FakePolicy(BaseModel): class FakePolicy(BaseModel):
def __init__(self, name, document, region_name): def __init__(self, name, document, region_name, default_version_id="1"):
self.name = name self.name = name
self.document = document self.document = document
self.arn = "arn:aws:iot:%s:1:policy/%s" % (region_name, name) self.arn = "arn:aws:iot:%s:1:policy/%s" % (region_name, name)
self.version = "1" # TODO: handle version self.default_version_id = default_version_id
self.versions = [FakePolicyVersion(self.name, document, True, region_name)]
def to_get_dict(self): def to_get_dict(self):
return { return {
"policyName": self.name, "policyName": self.name,
"policyArn": self.arn, "policyArn": self.arn,
"policyDocument": self.document, "policyDocument": self.document,
"defaultVersionId": self.version, "defaultVersionId": self.default_version_id,
} }
def to_dict_at_creation(self): def to_dict_at_creation(self):
@ -193,13 +195,52 @@ class FakePolicy(BaseModel):
"policyName": self.name, "policyName": self.name,
"policyArn": self.arn, "policyArn": self.arn,
"policyDocument": self.document, "policyDocument": self.document,
"policyVersionId": self.version, "policyVersionId": self.default_version_id,
} }
def to_dict(self): def to_dict(self):
return {"policyName": self.name, "policyArn": self.arn} return {"policyName": self.name, "policyArn": self.arn}
class FakePolicyVersion(object):
def __init__(self, policy_name, document, is_default, region_name):
self.name = policy_name
self.arn = "arn:aws:iot:%s:1:policy/%s" % (region_name, policy_name)
self.document = document or {}
self.is_default = is_default
self.version_id = "1"
self.create_datetime = time.mktime(datetime(2015, 1, 1).timetuple())
self.last_modified_datetime = time.mktime(datetime(2015, 1, 2).timetuple())
def to_get_dict(self):
return {
"policyName": self.name,
"policyArn": self.arn,
"policyDocument": self.document,
"policyVersionId": self.version_id,
"isDefaultVersion": self.is_default,
"creationDate": self.create_datetime,
"lastModifiedDate": self.last_modified_datetime,
"generationId": self.version_id,
}
def to_dict_at_creation(self):
return {
"policyArn": self.arn,
"policyDocument": self.document,
"policyVersionId": self.version_id,
"isDefaultVersion": self.is_default,
}
def to_dict(self):
return {
"versionId": self.version_id,
"isDefaultVersion": self.is_default,
"createDate": self.create_datetime,
}
class FakeJob(BaseModel): class FakeJob(BaseModel):
JOB_ID_REGEX_PATTERN = "[a-zA-Z0-9_-]" JOB_ID_REGEX_PATTERN = "[a-zA-Z0-9_-]"
JOB_ID_REGEX = re.compile(JOB_ID_REGEX_PATTERN) JOB_ID_REGEX = re.compile(JOB_ID_REGEX_PATTERN)
@ -226,12 +267,14 @@ class FakeJob(BaseModel):
self.targets = targets self.targets = targets
self.document_source = document_source self.document_source = document_source
self.document = document self.document = document
self.force = False
self.description = description self.description = description
self.presigned_url_config = presigned_url_config self.presigned_url_config = presigned_url_config
self.target_selection = target_selection self.target_selection = target_selection
self.job_executions_rollout_config = job_executions_rollout_config self.job_executions_rollout_config = job_executions_rollout_config
self.status = None # IN_PROGRESS | CANCELED | COMPLETED self.status = "QUEUED" # IN_PROGRESS | CANCELED | COMPLETED
self.comment = None self.comment = None
self.reason_code = None
self.created_at = time.mktime(datetime(2015, 1, 1).timetuple()) self.created_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple()) self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.completed_at = None self.completed_at = None
@ -258,9 +301,11 @@ class FakeJob(BaseModel):
"jobExecutionsRolloutConfig": self.job_executions_rollout_config, "jobExecutionsRolloutConfig": self.job_executions_rollout_config,
"status": self.status, "status": self.status,
"comment": self.comment, "comment": self.comment,
"forceCanceled": self.force,
"reasonCode": self.reason_code,
"createdAt": self.created_at, "createdAt": self.created_at,
"lastUpdatedAt": self.last_updated_at, "lastUpdatedAt": self.last_updated_at,
"completedAt": self.completedAt, "completedAt": self.completed_at,
"jobProcessDetails": self.job_process_details, "jobProcessDetails": self.job_process_details,
"documentParameters": self.document_parameters, "documentParameters": self.document_parameters,
"document": self.document, "document": self.document,
@ -275,12 +320,67 @@ class FakeJob(BaseModel):
return regex_match and length_match return regex_match and length_match
class FakeJobExecution(BaseModel):
def __init__(
self,
job_id,
thing_arn,
status="QUEUED",
force_canceled=False,
status_details_map={},
):
self.job_id = job_id
self.status = status # IN_PROGRESS | CANCELED | COMPLETED
self.force_canceled = force_canceled
self.status_details_map = status_details_map
self.thing_arn = thing_arn
self.queued_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.started_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple())
self.execution_number = 123
self.version_number = 123
self.approximate_seconds_before_time_out = 123
def to_get_dict(self):
obj = {
"jobId": self.job_id,
"status": self.status,
"forceCanceled": self.force_canceled,
"statusDetails": {"detailsMap": self.status_details_map},
"thingArn": self.thing_arn,
"queuedAt": self.queued_at,
"startedAt": self.started_at,
"lastUpdatedAt": self.last_updated_at,
"executionNumber": self.execution_number,
"versionNumber": self.version_number,
"approximateSecondsBeforeTimedOut": self.approximate_seconds_before_time_out,
}
return obj
def to_dict(self):
obj = {
"jobId": self.job_id,
"thingArn": self.thing_arn,
"jobExecutionSummary": {
"status": self.status,
"queuedAt": self.queued_at,
"startedAt": self.started_at,
"lastUpdatedAt": self.last_updated_at,
"executionNumber": self.execution_number,
},
}
return obj
class IoTBackend(BaseBackend): class IoTBackend(BaseBackend):
def __init__(self, region_name=None): def __init__(self, region_name=None):
super(IoTBackend, self).__init__() super(IoTBackend, self).__init__()
self.region_name = region_name self.region_name = region_name
self.things = OrderedDict() self.things = OrderedDict()
self.jobs = OrderedDict() self.jobs = OrderedDict()
self.job_executions = OrderedDict()
self.thing_types = OrderedDict() self.thing_types = OrderedDict()
self.thing_groups = OrderedDict() self.thing_groups = OrderedDict()
self.certificates = OrderedDict() self.certificates = OrderedDict()
@ -535,6 +635,28 @@ class IoTBackend(BaseBackend):
self.policies[policy.name] = policy self.policies[policy.name] = policy
return policy return policy
def attach_policy(self, policy_name, target):
principal = self._get_principal(target)
policy = self.get_policy(policy_name)
k = (target, policy_name)
if k in self.principal_policies:
return
self.principal_policies[k] = (principal, policy)
def detach_policy(self, policy_name, target):
# this may raises ResourceNotFoundException
self._get_principal(target)
self.get_policy(policy_name)
k = (target, policy_name)
if k not in self.principal_policies:
raise ResourceNotFoundException()
del self.principal_policies[k]
def list_attached_policies(self, target):
policies = [v[1] for k, v in self.principal_policies.items() if k[0] == target]
return policies
def list_policies(self): def list_policies(self):
policies = self.policies.values() policies = self.policies.values()
return policies return policies
@ -559,6 +681,60 @@ class IoTBackend(BaseBackend):
policy = self.get_policy(policy_name) policy = self.get_policy(policy_name)
del self.policies[policy.name] del self.policies[policy.name]
def create_policy_version(self, policy_name, policy_document, set_as_default):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
version = FakePolicyVersion(
policy_name, policy_document, set_as_default, self.region_name
)
policy.versions.append(version)
version.version_id = "{0}".format(len(policy.versions))
if set_as_default:
self.set_default_policy_version(policy_name, version.version_id)
return version
def set_default_policy_version(self, policy_name, version_id):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
for version in policy.versions:
if version.version_id == version_id:
version.is_default = True
policy.default_version_id = version.version_id
policy.document = version.document
else:
version.is_default = False
def get_policy_version(self, policy_name, version_id):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
for version in policy.versions:
if version.version_id == version_id:
return version
raise ResourceNotFoundException()
def list_policy_versions(self, policy_name):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
return policy.versions
def delete_policy_version(self, policy_name, version_id):
policy = self.get_policy(policy_name)
if not policy:
raise ResourceNotFoundException()
if version_id == policy.default_version_id:
raise InvalidRequestException(
"Cannot delete the default version of a policy"
)
for i, v in enumerate(policy.versions):
if v.version_id == version_id:
del policy.versions[i]
return
raise ResourceNotFoundException()
def _get_principal(self, principal_arn): def _get_principal(self, principal_arn):
""" """
raise ResourceNotFoundException raise ResourceNotFoundException
@ -574,14 +750,6 @@ class IoTBackend(BaseBackend):
pass pass
raise ResourceNotFoundException() raise ResourceNotFoundException()
def attach_policy(self, policy_name, target):
principal = self._get_principal(target)
policy = self.get_policy(policy_name)
k = (target, policy_name)
if k in self.principal_policies:
return
self.principal_policies[k] = (principal, policy)
def attach_principal_policy(self, policy_name, principal_arn): def attach_principal_policy(self, policy_name, principal_arn):
principal = self._get_principal(principal_arn) principal = self._get_principal(principal_arn)
policy = self.get_policy(policy_name) policy = self.get_policy(policy_name)
@ -590,15 +758,6 @@ class IoTBackend(BaseBackend):
return return
self.principal_policies[k] = (principal, policy) self.principal_policies[k] = (principal, policy)
def detach_policy(self, policy_name, target):
# this may raises ResourceNotFoundException
self._get_principal(target)
self.get_policy(policy_name)
k = (target, policy_name)
if k not in self.principal_policies:
raise ResourceNotFoundException()
del self.principal_policies[k]
def detach_principal_policy(self, policy_name, principal_arn): def detach_principal_policy(self, policy_name, principal_arn):
# this may raises ResourceNotFoundException # this may raises ResourceNotFoundException
self._get_principal(principal_arn) self._get_principal(principal_arn)
@ -819,11 +978,187 @@ class IoTBackend(BaseBackend):
self.region_name, self.region_name,
) )
self.jobs[job_id] = job self.jobs[job_id] = job
for thing_arn in targets:
thing_name = thing_arn.split(":")[-1].split("/")[-1]
job_execution = FakeJobExecution(job_id, thing_arn)
self.job_executions[(job_id, thing_name)] = job_execution
return job.job_arn, job_id, description return job.job_arn, job_id, description
def describe_job(self, job_id): def describe_job(self, job_id):
jobs = [_ for _ in self.jobs.values() if _.job_id == job_id]
if len(jobs) == 0:
raise ResourceNotFoundException()
return jobs[0]
def delete_job(self, job_id, force):
job = self.jobs[job_id]
if job.status == "IN_PROGRESS" and force:
del self.jobs[job_id]
elif job.status != "IN_PROGRESS":
del self.jobs[job_id]
else:
raise InvalidStateTransitionException()
def cancel_job(self, job_id, reason_code, comment, force):
job = self.jobs[job_id]
job.reason_code = reason_code if reason_code is not None else job.reason_code
job.comment = comment if comment is not None else job.comment
job.force = force if force is not None and force != job.force else job.force
job.status = "CANCELED"
if job.status == "IN_PROGRESS" and force:
self.jobs[job_id] = job
elif job.status != "IN_PROGRESS":
self.jobs[job_id] = job
else:
raise InvalidStateTransitionException()
return job
def get_job_document(self, job_id):
return self.jobs[job_id] return self.jobs[job_id]
def list_jobs(
self,
status,
target_selection,
max_results,
token,
thing_group_name,
thing_group_id,
):
# TODO: implement filters
all_jobs = [_.to_dict() for _ in self.jobs.values()]
filtered_jobs = all_jobs
if token is None:
jobs = filtered_jobs[0:max_results]
next_token = str(max_results) if len(filtered_jobs) > max_results else None
else:
token = int(token)
jobs = filtered_jobs[token : token + max_results]
next_token = (
str(token + max_results)
if len(filtered_jobs) > token + max_results
else None
)
return jobs, next_token
def describe_job_execution(self, job_id, thing_name, execution_number):
try:
job_execution = self.job_executions[(job_id, thing_name)]
except KeyError:
raise ResourceNotFoundException()
if job_execution is None or (
execution_number is not None
and job_execution.execution_number != execution_number
):
raise ResourceNotFoundException()
return job_execution
def cancel_job_execution(
self, job_id, thing_name, force, expected_version, status_details
):
job_execution = self.job_executions[(job_id, thing_name)]
if job_execution is None:
raise ResourceNotFoundException()
job_execution.force_canceled = (
force if force is not None else job_execution.force_canceled
)
# TODO: implement expected_version and status_details (at most 10 can be specified)
if job_execution.status == "IN_PROGRESS" and force:
job_execution.status = "CANCELED"
self.job_executions[(job_id, thing_name)] = job_execution
elif job_execution.status != "IN_PROGRESS":
job_execution.status = "CANCELED"
self.job_executions[(job_id, thing_name)] = job_execution
else:
raise InvalidStateTransitionException()
def delete_job_execution(self, job_id, thing_name, execution_number, force):
job_execution = self.job_executions[(job_id, thing_name)]
if job_execution.execution_number != execution_number:
raise ResourceNotFoundException()
if job_execution.status == "IN_PROGRESS" and force:
del self.job_executions[(job_id, thing_name)]
elif job_execution.status != "IN_PROGRESS":
del self.job_executions[(job_id, thing_name)]
else:
raise InvalidStateTransitionException()
def list_job_executions_for_job(self, job_id, status, max_results, next_token):
job_executions = [
self.job_executions[je].to_dict()
for je in self.job_executions
if je[0] == job_id
]
if status is not None:
job_executions = list(
filter(
lambda elem: status in elem["status"] and elem["status"] == status,
job_executions,
)
)
token = next_token
if token is None:
job_executions = job_executions[0:max_results]
next_token = str(max_results) if len(job_executions) > max_results else None
else:
token = int(token)
job_executions = job_executions[token : token + max_results]
next_token = (
str(token + max_results)
if len(job_executions) > token + max_results
else None
)
return job_executions, next_token
def list_job_executions_for_thing(
self, thing_name, status, max_results, next_token
):
job_executions = [
self.job_executions[je].to_dict()
for je in self.job_executions
if je[1] == thing_name
]
if status is not None:
job_executions = list(
filter(
lambda elem: status in elem["status"] and elem["status"] == status,
job_executions,
)
)
token = next_token
if token is None:
job_executions = job_executions[0:max_results]
next_token = str(max_results) if len(job_executions) > max_results else None
else:
token = int(token)
job_executions = job_executions[token : token + max_results]
next_token = (
str(token + max_results)
if len(job_executions) > token + max_results
else None
)
return job_executions, next_token
iot_backends = {} iot_backends = {}
for region in Session().get_available_regions("iot"): for region in Session().get_available_regions("iot"):

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import json import json
from six.moves.urllib.parse import unquote
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import iot_backends from .models import iot_backends
@ -141,6 +142,8 @@ class IoTResponse(BaseResponse):
createdAt=job.created_at, createdAt=job.created_at,
description=job.description, description=job.description,
documentParameters=job.document_parameters, documentParameters=job.document_parameters,
forceCanceled=job.force,
reasonCode=job.reason_code,
jobArn=job.job_arn, jobArn=job.job_arn,
jobExecutionsRolloutConfig=job.job_executions_rollout_config, jobExecutionsRolloutConfig=job.job_executions_rollout_config,
jobId=job.job_id, jobId=job.job_id,
@ -154,6 +157,127 @@ class IoTResponse(BaseResponse):
) )
) )
def delete_job(self):
job_id = self._get_param("jobId")
force = self._get_bool_param("force")
self.iot_backend.delete_job(job_id=job_id, force=force)
return json.dumps(dict())
def cancel_job(self):
job_id = self._get_param("jobId")
reason_code = self._get_param("reasonCode")
comment = self._get_param("comment")
force = self._get_bool_param("force")
job = self.iot_backend.cancel_job(
job_id=job_id, reason_code=reason_code, comment=comment, force=force
)
return json.dumps(job.to_dict())
def get_job_document(self):
job = self.iot_backend.get_job_document(job_id=self._get_param("jobId"))
if job.document is not None:
return json.dumps({"document": job.document})
else:
# job.document_source is not None:
# TODO: needs to be implemented to get document_source's content from S3
return json.dumps({"document": ""})
def list_jobs(self):
status = (self._get_param("status"),)
target_selection = (self._get_param("targetSelection"),)
max_results = self._get_int_param(
"maxResults", 50
) # not the default, but makes testing easier
previous_next_token = self._get_param("nextToken")
thing_group_name = (self._get_param("thingGroupName"),)
thing_group_id = self._get_param("thingGroupId")
jobs, next_token = self.iot_backend.list_jobs(
status=status,
target_selection=target_selection,
max_results=max_results,
token=previous_next_token,
thing_group_name=thing_group_name,
thing_group_id=thing_group_id,
)
return json.dumps(dict(jobs=jobs, nextToken=next_token))
def describe_job_execution(self):
job_id = self._get_param("jobId")
thing_name = self._get_param("thingName")
execution_number = self._get_int_param("executionNumber")
job_execution = self.iot_backend.describe_job_execution(
job_id=job_id, thing_name=thing_name, execution_number=execution_number
)
return json.dumps(dict(execution=job_execution.to_get_dict()))
def cancel_job_execution(self):
job_id = self._get_param("jobId")
thing_name = self._get_param("thingName")
force = self._get_bool_param("force")
expected_version = self._get_int_param("expectedVersion")
status_details = self._get_param("statusDetails")
self.iot_backend.cancel_job_execution(
job_id=job_id,
thing_name=thing_name,
force=force,
expected_version=expected_version,
status_details=status_details,
)
return json.dumps(dict())
def delete_job_execution(self):
job_id = self._get_param("jobId")
thing_name = self._get_param("thingName")
execution_number = self._get_int_param("executionNumber")
force = self._get_bool_param("force")
self.iot_backend.delete_job_execution(
job_id=job_id,
thing_name=thing_name,
execution_number=execution_number,
force=force,
)
return json.dumps(dict())
def list_job_executions_for_job(self):
job_id = self._get_param("jobId")
status = self._get_param("status")
max_results = self._get_int_param(
"maxResults", 50
) # not the default, but makes testing easier
next_token = self._get_param("nextToken")
job_executions, next_token = self.iot_backend.list_job_executions_for_job(
job_id=job_id, status=status, max_results=max_results, next_token=next_token
)
return json.dumps(dict(executionSummaries=job_executions, nextToken=next_token))
def list_job_executions_for_thing(self):
thing_name = self._get_param("thingName")
status = self._get_param("status")
max_results = self._get_int_param(
"maxResults", 50
) # not the default, but makes testing easier
next_token = self._get_param("nextToken")
job_executions, next_token = self.iot_backend.list_job_executions_for_thing(
thing_name=thing_name,
status=status,
max_results=max_results,
next_token=next_token,
)
return json.dumps(dict(executionSummaries=job_executions, nextToken=next_token))
def create_keys_and_certificate(self): def create_keys_and_certificate(self):
set_as_active = self._get_bool_param("setAsActive") set_as_active = self._get_bool_param("setAsActive")
cert, key_pair = self.iot_backend.create_keys_and_certificate( cert, key_pair = self.iot_backend.create_keys_and_certificate(
@ -241,12 +365,61 @@ class IoTResponse(BaseResponse):
self.iot_backend.delete_policy(policy_name=policy_name) self.iot_backend.delete_policy(policy_name=policy_name)
return json.dumps(dict()) return json.dumps(dict())
def create_policy_version(self):
policy_name = self._get_param("policyName")
policy_document = self._get_param("policyDocument")
set_as_default = self._get_bool_param("setAsDefault")
policy_version = self.iot_backend.create_policy_version(
policy_name, policy_document, set_as_default
)
return json.dumps(dict(policy_version.to_dict_at_creation()))
def set_default_policy_version(self):
policy_name = self._get_param("policyName")
version_id = self._get_param("policyVersionId")
self.iot_backend.set_default_policy_version(policy_name, version_id)
return json.dumps(dict())
def get_policy_version(self):
policy_name = self._get_param("policyName")
version_id = self._get_param("policyVersionId")
policy_version = self.iot_backend.get_policy_version(policy_name, version_id)
return json.dumps(dict(policy_version.to_get_dict()))
def list_policy_versions(self):
policy_name = self._get_param("policyName")
policiy_versions = self.iot_backend.list_policy_versions(
policy_name=policy_name
)
return json.dumps(dict(policyVersions=[_.to_dict() for _ in policiy_versions]))
def delete_policy_version(self):
policy_name = self._get_param("policyName")
version_id = self._get_param("policyVersionId")
self.iot_backend.delete_policy_version(policy_name, version_id)
return json.dumps(dict())
def attach_policy(self): def attach_policy(self):
policy_name = self._get_param("policyName") policy_name = self._get_param("policyName")
target = self._get_param("target") target = self._get_param("target")
self.iot_backend.attach_policy(policy_name=policy_name, target=target) self.iot_backend.attach_policy(policy_name=policy_name, target=target)
return json.dumps(dict()) return json.dumps(dict())
def list_attached_policies(self):
principal = unquote(self._get_param("target"))
# marker = self._get_param("marker")
# page_size = self._get_int_param("pageSize")
policies = self.iot_backend.list_attached_policies(target=principal)
# TODO: implement pagination in the future
next_marker = None
return json.dumps(
dict(policies=[_.to_dict() for _ in policies], nextMarker=next_marker)
)
def attach_principal_policy(self): def attach_principal_policy(self):
policy_name = self._get_param("policyName") policy_name = self._get_param("policyName")
principal = self.headers.get("x-amzn-iot-principal") principal = self.headers.get("x-amzn-iot-principal")

View File

@ -7,25 +7,42 @@ from datetime import datetime, timedelta
from boto3 import Session from boto3 import Session
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
<<<<<<< HEAD
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.core.utils import iso_8601_datetime_without_milliseconds
from moto.utilities.tagging_service import TaggingService from moto.utilities.tagging_service import TaggingService
=======
from moto.core.utils import unix_time
from moto.iam.models import ACCOUNT_ID
>>>>>>> 100dbd529f174f18d579a1dcc066d55409f2e38f
from .utils import decrypt, encrypt, generate_key_id, generate_master_key from .utils import decrypt, encrypt, generate_key_id, generate_master_key
class Key(BaseModel): class Key(BaseModel):
<<<<<<< HEAD
def __init__(self, policy, key_usage, description, region): def __init__(self, policy, key_usage, description, region):
=======
def __init__(
self, policy, key_usage, customer_master_key_spec, description, tags, region
):
>>>>>>> 100dbd529f174f18d579a1dcc066d55409f2e38f
self.id = generate_key_id() self.id = generate_key_id()
self.creation_date = unix_time()
self.policy = policy self.policy = policy
self.key_usage = key_usage self.key_usage = key_usage
self.key_state = "Enabled" self.key_state = "Enabled"
self.description = description self.description = description
self.enabled = True self.enabled = True
self.region = region self.region = region
self.account_id = "012345678912" self.account_id = ACCOUNT_ID
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date = None
self.key_material = generate_master_key() self.key_material = generate_master_key()
self.origin = "AWS_KMS"
self.key_manager = "CUSTOMER"
self.customer_master_key_spec = customer_master_key_spec or "SYMMETRIC_DEFAULT"
@property @property
def physical_resource_id(self): def physical_resource_id(self):
@ -37,23 +54,55 @@ class Key(BaseModel):
self.region, self.account_id, self.id self.region, self.account_id, self.id
) )
@property
def encryption_algorithms(self):
if self.key_usage == "SIGN_VERIFY":
return None
elif self.customer_master_key_spec == "SYMMETRIC_DEFAULT":
return ["SYMMETRIC_DEFAULT"]
else:
return ["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"]
@property
def signing_algorithms(self):
if self.key_usage == "ENCRYPT_DECRYPT":
return None
elif self.customer_master_key_spec in ["ECC_NIST_P256", "ECC_SECG_P256K1"]:
return ["ECDSA_SHA_256"]
elif self.customer_master_key_spec == "ECC_NIST_P384":
return ["ECDSA_SHA_384"]
elif self.customer_master_key_spec == "ECC_NIST_P521":
return ["ECDSA_SHA_512"]
else:
return [
"RSASSA_PKCS1_V1_5_SHA_256",
"RSASSA_PKCS1_V1_5_SHA_384",
"RSASSA_PKCS1_V1_5_SHA_512",
"RSASSA_PSS_SHA_256",
"RSASSA_PSS_SHA_384",
"RSASSA_PSS_SHA_512",
]
def to_dict(self): def to_dict(self):
key_dict = { key_dict = {
"KeyMetadata": { "KeyMetadata": {
"AWSAccountId": self.account_id, "AWSAccountId": self.account_id,
"Arn": self.arn, "Arn": self.arn,
"CreationDate": iso_8601_datetime_without_milliseconds(datetime.now()), "CreationDate": self.creation_date,
"CustomerMasterKeySpec": self.customer_master_key_spec,
"Description": self.description, "Description": self.description,
"Enabled": self.enabled, "Enabled": self.enabled,
"EncryptionAlgorithms": self.encryption_algorithms,
"KeyId": self.id, "KeyId": self.id,
"KeyManager": self.key_manager,
"KeyUsage": self.key_usage, "KeyUsage": self.key_usage,
"KeyState": self.key_state, "KeyState": self.key_state,
"Origin": self.origin,
"SigningAlgorithms": self.signing_algorithms,
} }
} }
if self.key_state == "PendingDeletion": if self.key_state == "PendingDeletion":
key_dict["KeyMetadata"][ key_dict["KeyMetadata"]["DeletionDate"] = unix_time(self.deletion_date)
"DeletionDate"
] = iso_8601_datetime_without_milliseconds(self.deletion_date)
return key_dict return key_dict
def delete(self, region_name): def delete(self, region_name):
@ -69,6 +118,7 @@ class Key(BaseModel):
key = kms_backend.create_key( key = kms_backend.create_key(
policy=properties["KeyPolicy"], policy=properties["KeyPolicy"],
key_usage="ENCRYPT_DECRYPT", key_usage="ENCRYPT_DECRYPT",
customer_master_key_spec="SYMMETRIC_DEFAULT",
description=properties["Description"], description=properties["Description"],
region=region_name, region=region_name,
) )
@ -92,8 +142,17 @@ class KmsBackend(BaseBackend):
self.key_to_aliases = defaultdict(set) self.key_to_aliases = defaultdict(set)
self.tagger = TaggingService(keyName='TagKey', valueName='TagValue') self.tagger = TaggingService(keyName='TagKey', valueName='TagValue')
<<<<<<< HEAD
def create_key(self, policy, key_usage, description, tags, region): def create_key(self, policy, key_usage, description, tags, region):
key = Key(policy, key_usage, description, region) key = Key(policy, key_usage, description, region)
=======
def create_key(
self, policy, key_usage, customer_master_key_spec, description, tags, region
):
key = Key(
policy, key_usage, customer_master_key_spec, description, tags, region
)
>>>>>>> 100dbd529f174f18d579a1dcc066d55409f2e38f
self.keys[key.id] = key self.keys[key.id] = key
if tags != None and len(tags) > 0: if tags != None and len(tags) > 0:
self.tag_resource(key.id, tags) self.tag_resource(key.id, tags)
@ -211,9 +270,7 @@ class KmsBackend(BaseBackend):
self.keys[key_id].deletion_date = datetime.now() + timedelta( self.keys[key_id].deletion_date = datetime.now() + timedelta(
days=pending_window_in_days days=pending_window_in_days
) )
return iso_8601_datetime_without_milliseconds( return unix_time(self.keys[key_id].deletion_date)
self.keys[key_id].deletion_date
)
def encrypt(self, key_id, plaintext, encryption_context): def encrypt(self, key_id, plaintext, encryption_context):
key_id = self.any_id_to_key_id(key_id) key_id = self.any_id_to_key_id(key_id)

View File

@ -118,11 +118,12 @@ class KmsResponse(BaseResponse):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html"""
policy = self.parameters.get("Policy") policy = self.parameters.get("Policy")
key_usage = self.parameters.get("KeyUsage") key_usage = self.parameters.get("KeyUsage")
customer_master_key_spec = self.parameters.get("CustomerMasterKeySpec")
description = self.parameters.get("Description") description = self.parameters.get("Description")
tags = self.parameters.get("Tags") tags = self.parameters.get("Tags")
key = self.kms_backend.create_key( key = self.kms_backend.create_key(
policy, key_usage, description, tags, self.region policy, key_usage, customer_master_key_spec, description, tags, self.region
) )
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())

View File

@ -103,7 +103,7 @@ class LogsResponse(BaseResponse):
( (
events, events,
next_backward_token, next_backward_token,
next_foward_token, next_forward_token,
) = self.logs_backend.get_log_events( ) = self.logs_backend.get_log_events(
log_group_name, log_group_name,
log_stream_name, log_stream_name,
@ -117,7 +117,7 @@ class LogsResponse(BaseResponse):
{ {
"events": events, "events": events,
"nextBackwardToken": next_backward_token, "nextBackwardToken": next_backward_token,
"nextForwardToken": next_foward_token, "nextForwardToken": next_forward_token,
} }
) )

View File

@ -10,3 +10,13 @@ class InvalidInputException(JsonRESTError):
"InvalidInputException", "InvalidInputException",
"You provided a value that does not match the required pattern.", "You provided a value that does not match the required pattern.",
) )
class DuplicateOrganizationalUnitException(JsonRESTError):
code = 400
def __init__(self):
super(DuplicateOrganizationalUnitException, self).__init__(
"DuplicateOrganizationalUnitException",
"An OU with the same name already exists.",
)

View File

@ -8,7 +8,10 @@ from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.core.utils import unix_time from moto.core.utils import unix_time
from moto.organizations import utils from moto.organizations import utils
from moto.organizations.exceptions import InvalidInputException from moto.organizations.exceptions import (
InvalidInputException,
DuplicateOrganizationalUnitException,
)
class FakeOrganization(BaseModel): class FakeOrganization(BaseModel):
@ -222,6 +225,14 @@ class OrganizationsBackend(BaseBackend):
self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_ou.id) self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_ou.id)
return new_ou.describe() return new_ou.describe()
def update_organizational_unit(self, **kwargs):
for ou in self.ou:
if ou.name == kwargs["Name"]:
raise DuplicateOrganizationalUnitException
ou = self.get_organizational_unit_by_id(kwargs["OrganizationalUnitId"])
ou.name = kwargs["Name"]
return ou.describe()
def get_organizational_unit_by_id(self, ou_id): def get_organizational_unit_by_id(self, ou_id):
ou = next((ou for ou in self.ou if ou.id == ou_id), None) ou = next((ou for ou in self.ou if ou.id == ou_id), None)
if ou is None: if ou is None:

View File

@ -36,6 +36,11 @@ class OrganizationsResponse(BaseResponse):
self.organizations_backend.create_organizational_unit(**self.request_params) self.organizations_backend.create_organizational_unit(**self.request_params)
) )
def update_organizational_unit(self):
return json.dumps(
self.organizations_backend.update_organizational_unit(**self.request_params)
)
def describe_organizational_unit(self): def describe_organizational_unit(self):
return json.dumps( return json.dumps(
self.organizations_backend.describe_organizational_unit( self.organizations_backend.describe_organizational_unit(

View File

@ -130,7 +130,9 @@ class Database(BaseModel):
if not self.option_group_name and self.engine in self.default_option_groups: if not self.option_group_name and self.engine in self.default_option_groups:
self.option_group_name = self.default_option_groups[self.engine] self.option_group_name = self.default_option_groups[self.engine]
self.character_set_name = kwargs.get("character_set_name", None) self.character_set_name = kwargs.get("character_set_name", None)
self.iam_database_authentication_enabled = False self.enable_iam_database_authentication = kwargs.get(
"enable_iam_database_authentication", False
)
self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U" self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U"
self.tags = kwargs.get("tags", []) self.tags = kwargs.get("tags", [])
@ -214,7 +216,7 @@ class Database(BaseModel):
<ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier> <ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier>
{% endif %} {% endif %}
<Engine>{{ database.engine }}</Engine> <Engine>{{ database.engine }}</Engine>
<IAMDatabaseAuthenticationEnabled>{{database.iam_database_authentication_enabled }}</IAMDatabaseAuthenticationEnabled> <IAMDatabaseAuthenticationEnabled>{{database.enable_iam_database_authentication|lower }}</IAMDatabaseAuthenticationEnabled>
<LicenseModel>{{ database.license_model }}</LicenseModel> <LicenseModel>{{ database.license_model }}</LicenseModel>
<EngineVersion>{{ database.engine_version }}</EngineVersion> <EngineVersion>{{ database.engine_version }}</EngineVersion>
<OptionGroupMemberships> <OptionGroupMemberships>
@ -542,7 +544,7 @@ class Snapshot(BaseModel):
<KmsKeyId>{{ database.kms_key_id }}</KmsKeyId> <KmsKeyId>{{ database.kms_key_id }}</KmsKeyId>
<DBSnapshotArn>{{ snapshot.snapshot_arn }}</DBSnapshotArn> <DBSnapshotArn>{{ snapshot.snapshot_arn }}</DBSnapshotArn>
<Timezone></Timezone> <Timezone></Timezone>
<IAMDatabaseAuthenticationEnabled>false</IAMDatabaseAuthenticationEnabled> <IAMDatabaseAuthenticationEnabled>{{ database.enable_iam_database_authentication|lower }}</IAMDatabaseAuthenticationEnabled>
</DBSnapshot>""" </DBSnapshot>"""
) )
return template.render(snapshot=self, database=self.database) return template.render(snapshot=self, database=self.database)
@ -986,7 +988,7 @@ class RDS2Backend(BaseBackend):
) )
if option_group_kwargs["engine_name"] not in valid_option_group_engines.keys(): if option_group_kwargs["engine_name"] not in valid_option_group_engines.keys():
raise RDSClientError( raise RDSClientError(
"InvalidParameterValue", "Invalid DB engine: non-existant" "InvalidParameterValue", "Invalid DB engine: non-existent"
) )
if ( if (
option_group_kwargs["major_engine_version"] option_group_kwargs["major_engine_version"]

View File

@ -27,6 +27,9 @@ class RDS2Response(BaseResponse):
"db_subnet_group_name": self._get_param("DBSubnetGroupName"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"),
"engine": self._get_param("Engine"), "engine": self._get_param("Engine"),
"engine_version": self._get_param("EngineVersion"), "engine_version": self._get_param("EngineVersion"),
"enable_iam_database_authentication": self._get_bool_param(
"EnableIAMDatabaseAuthentication"
),
"license_model": self._get_param("LicenseModel"), "license_model": self._get_param("LicenseModel"),
"iops": self._get_int_param("Iops"), "iops": self._get_int_param("Iops"),
"kms_key_id": self._get_param("KmsKeyId"), "kms_key_id": self._get_param("KmsKeyId"),
@ -367,14 +370,14 @@ class RDS2Response(BaseResponse):
def modify_db_parameter_group(self): def modify_db_parameter_group(self):
db_parameter_group_name = self._get_param("DBParameterGroupName") db_parameter_group_name = self._get_param("DBParameterGroupName")
db_parameter_group_parameters = self._get_db_parameter_group_paramters() db_parameter_group_parameters = self._get_db_parameter_group_parameters()
db_parameter_group = self.backend.modify_db_parameter_group( db_parameter_group = self.backend.modify_db_parameter_group(
db_parameter_group_name, db_parameter_group_parameters db_parameter_group_name, db_parameter_group_parameters
) )
template = self.response_template(MODIFY_DB_PARAMETER_GROUP_TEMPLATE) template = self.response_template(MODIFY_DB_PARAMETER_GROUP_TEMPLATE)
return template.render(db_parameter_group=db_parameter_group) return template.render(db_parameter_group=db_parameter_group)
def _get_db_parameter_group_paramters(self): def _get_db_parameter_group_parameters(self):
parameter_group_parameters = defaultdict(dict) parameter_group_parameters = defaultdict(dict)
for param_name, value in self.querystring.items(): for param_name, value in self.querystring.items():
if not param_name.startswith("Parameters.Parameter"): if not param_name.startswith("Parameters.Parameter"):

View File

@ -271,6 +271,7 @@ LIST_RRSET_RESPONSE = """<ListResourceRecordSetsResponse xmlns="https://route53.
{{ record_set.to_xml() }} {{ record_set.to_xml() }}
{% endfor %} {% endfor %}
</ResourceRecordSets> </ResourceRecordSets>
<IsTruncated>false</IsTruncated>
</ListResourceRecordSetsResponse>""" </ListResourceRecordSetsResponse>"""
CHANGE_RRSET_RESPONSE = """<ChangeResourceRecordSetsResponse xmlns="https://route53.amazonaws.com/doc/2012-12-12/"> CHANGE_RRSET_RESPONSE = """<ChangeResourceRecordSetsResponse xmlns="https://route53.amazonaws.com/doc/2012-12-12/">

View File

@ -1,8 +1,13 @@
import datetime
import json import json
import time
from boto3 import Session
from moto.core.exceptions import InvalidNextTokenException from moto.core.exceptions import InvalidNextTokenException
from moto.core.models import ConfigQueryModel from moto.core.models import ConfigQueryModel
from moto.s3 import s3_backends from moto.s3 import s3_backends
from moto.s3.models import get_moto_s3_account_id
class S3ConfigQuery(ConfigQueryModel): class S3ConfigQuery(ConfigQueryModel):
@ -118,4 +123,146 @@ class S3ConfigQuery(ConfigQueryModel):
return config_data return config_data
class S3AccountPublicAccessBlockConfigQuery(ConfigQueryModel):
def list_config_service_resources(
self,
resource_ids,
resource_name,
limit,
next_token,
backend_region=None,
resource_region=None,
):
# For the Account Public Access Block, they are the same for all regions. The resource ID is the AWS account ID
# There is no resource name -- it should be a blank string "" if provided.
# The resource name can only ever be None or an empty string:
if resource_name is not None and resource_name != "":
return [], None
pab = None
account_id = get_moto_s3_account_id()
regions = [region for region in Session().get_available_regions("config")]
# If a resource ID was passed in, then filter accordingly:
if resource_ids:
for id in resource_ids:
if account_id == id:
pab = self.backends["global"].account_public_access_block
break
# Otherwise, just grab the one from the backend:
if not resource_ids:
pab = self.backends["global"].account_public_access_block
# If it's not present, then return nothing
if not pab:
return [], None
# Filter on regions (and paginate on them as well):
if backend_region:
pab_list = [backend_region]
elif resource_region:
# Invalid region?
if resource_region not in regions:
return [], None
pab_list = [resource_region]
# Aggregated query where no regions were supplied so return them all:
else:
pab_list = regions
# Pagination logic:
sorted_regions = sorted(pab_list)
new_token = None
# Get the start:
if not next_token:
start = 0
else:
# Tokens for this moto feature is just the region-name:
# For OTHER non-global resource types, it's the region concatenated with the resource ID.
if next_token not in sorted_regions:
raise InvalidNextTokenException()
start = sorted_regions.index(next_token)
# Get the list of items to collect:
pab_list = sorted_regions[start : (start + limit)]
if len(sorted_regions) > (start + limit):
new_token = sorted_regions[start + limit]
return (
[
{
"type": "AWS::S3::AccountPublicAccessBlock",
"id": account_id,
"region": region,
}
for region in pab_list
],
new_token,
)
def get_config_resource(
self, resource_id, resource_name=None, backend_region=None, resource_region=None
):
# Do we even have this defined?
if not self.backends["global"].account_public_access_block:
return None
# Resource name can only ever be "" if it's supplied:
if resource_name is not None and resource_name != "":
return None
# Are we filtering based on region?
account_id = get_moto_s3_account_id()
regions = [region for region in Session().get_available_regions("config")]
# Is the resource ID correct?:
if account_id == resource_id:
if backend_region:
pab_region = backend_region
# Invalid region?
elif resource_region not in regions:
return None
else:
pab_region = resource_region
else:
return None
# Format the PAB to the AWS Config format:
creation_time = datetime.datetime.utcnow()
config_data = {
"version": "1.3",
"accountId": account_id,
"configurationItemCaptureTime": str(creation_time),
"configurationItemStatus": "OK",
"configurationStateId": str(
int(time.mktime(creation_time.timetuple()))
), # PY2 and 3 compatible
"resourceType": "AWS::S3::AccountPublicAccessBlock",
"resourceId": account_id,
"awsRegion": pab_region,
"availabilityZone": "Not Applicable",
"configuration": self.backends[
"global"
].account_public_access_block.to_config_dict(),
"supplementaryConfiguration": {},
}
# The 'configuration' field is also a JSON string:
config_data["configuration"] = json.dumps(config_data["configuration"])
return config_data
s3_config_query = S3ConfigQuery(s3_backends) s3_config_query = S3ConfigQuery(s3_backends)
s3_account_public_access_block_query = S3AccountPublicAccessBlockConfigQuery(
s3_backends
)

View File

@ -127,6 +127,18 @@ class InvalidRequest(S3ClientError):
) )
class IllegalLocationConstraintException(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
super(IllegalLocationConstraintException, self).__init__(
"IllegalLocationConstraintException",
"The unspecified location constraint is incompatible for the region specific endpoint this request was sent to.",
*args,
**kwargs
)
class MalformedXML(S3ClientError): class MalformedXML(S3ClientError):
code = 400 code = 400
@ -347,3 +359,12 @@ class InvalidPublicAccessBlockConfiguration(S3ClientError):
*args, *args,
**kwargs **kwargs
) )
class WrongPublicAccessBlockAccountIdError(S3ClientError):
code = 403
def __init__(self):
super(WrongPublicAccessBlockAccountIdError, self).__init__(
"AccessDenied", "Access Denied"
)

View File

@ -19,7 +19,7 @@ import uuid
import six import six
from bisect import insort from bisect import insort
from moto.core import BaseBackend, BaseModel from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds, rfc_1123_datetime from moto.core.utils import iso_8601_datetime_with_milliseconds, rfc_1123_datetime
from .exceptions import ( from .exceptions import (
BucketAlreadyExists, BucketAlreadyExists,
@ -37,6 +37,7 @@ from .exceptions import (
CrossLocationLoggingProhibitted, CrossLocationLoggingProhibitted,
NoSuchPublicAccessBlockConfiguration, NoSuchPublicAccessBlockConfiguration,
InvalidPublicAccessBlockConfiguration, InvalidPublicAccessBlockConfiguration,
WrongPublicAccessBlockAccountIdError,
) )
from .utils import clean_key_name, _VersionedKeyStore from .utils import clean_key_name, _VersionedKeyStore
@ -58,6 +59,13 @@ DEFAULT_TEXT_ENCODING = sys.getdefaultencoding()
OWNER = "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a" OWNER = "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a"
def get_moto_s3_account_id():
"""This makes it easy for mocking AWS Account IDs when using AWS Config
-- Simply mock.patch the ACCOUNT_ID here, and Config gets it for free.
"""
return ACCOUNT_ID
class FakeDeleteMarker(BaseModel): class FakeDeleteMarker(BaseModel):
def __init__(self, key): def __init__(self, key):
self.key = key self.key = key
@ -1163,6 +1171,7 @@ class FakeBucket(BaseModel):
class S3Backend(BaseBackend): class S3Backend(BaseBackend):
def __init__(self): def __init__(self):
self.buckets = {} self.buckets = {}
self.account_public_access_block = None
def create_bucket(self, bucket_name, region_name): def create_bucket(self, bucket_name, region_name):
if bucket_name in self.buckets: if bucket_name in self.buckets:
@ -1264,6 +1273,16 @@ class S3Backend(BaseBackend):
return bucket.public_access_block return bucket.public_access_block
def get_account_public_access_block(self, account_id):
# The account ID should equal the account id that is set for Moto:
if account_id != ACCOUNT_ID:
raise WrongPublicAccessBlockAccountIdError()
if not self.account_public_access_block:
raise NoSuchPublicAccessBlockConfiguration()
return self.account_public_access_block
def set_key( def set_key(
self, bucket_name, key_name, value, storage=None, etag=None, multipart=None self, bucket_name, key_name, value, storage=None, etag=None, multipart=None
): ):
@ -1356,6 +1375,13 @@ class S3Backend(BaseBackend):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
bucket.public_access_block = None bucket.public_access_block = None
def delete_account_public_access_block(self, account_id):
# The account ID should equal the account id that is set for Moto:
if account_id != ACCOUNT_ID:
raise WrongPublicAccessBlockAccountIdError()
self.account_public_access_block = None
def put_bucket_notification_configuration(self, bucket_name, notification_config): def put_bucket_notification_configuration(self, bucket_name, notification_config):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
bucket.set_notification_configuration(notification_config) bucket.set_notification_configuration(notification_config)
@ -1384,6 +1410,21 @@ class S3Backend(BaseBackend):
pub_block_config.get("RestrictPublicBuckets"), pub_block_config.get("RestrictPublicBuckets"),
) )
def put_account_public_access_block(self, account_id, pub_block_config):
# The account ID should equal the account id that is set for Moto:
if account_id != ACCOUNT_ID:
raise WrongPublicAccessBlockAccountIdError()
if not pub_block_config:
raise InvalidPublicAccessBlockConfiguration()
self.account_public_access_block = PublicAccessBlock(
pub_block_config.get("BlockPublicAcls"),
pub_block_config.get("IgnorePublicAcls"),
pub_block_config.get("BlockPublicPolicy"),
pub_block_config.get("RestrictPublicBuckets"),
)
def initiate_multipart(self, bucket_name, key_name, metadata): def initiate_multipart(self, bucket_name, key_name, metadata):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
new_multipart = FakeMultipart(key_name, metadata) new_multipart = FakeMultipart(key_name, metadata)

View File

@ -4,6 +4,7 @@ import re
import sys import sys
import six import six
from botocore.awsrequest import AWSPreparedRequest
from moto.core.utils import str_to_rfc_1123_datetime, py2_strip_unicode_keys from moto.core.utils import str_to_rfc_1123_datetime, py2_strip_unicode_keys
from six.moves.urllib.parse import parse_qs, urlparse, unquote from six.moves.urllib.parse import parse_qs, urlparse, unquote
@ -29,6 +30,7 @@ from .exceptions import (
InvalidPartOrder, InvalidPartOrder,
MalformedXML, MalformedXML,
MalformedACLError, MalformedACLError,
IllegalLocationConstraintException,
InvalidNotificationARN, InvalidNotificationARN,
InvalidNotificationEvent, InvalidNotificationEvent,
ObjectNotInActiveTierError, ObjectNotInActiveTierError,
@ -122,6 +124,11 @@ ACTION_MAP = {
"uploadId": "PutObject", "uploadId": "PutObject",
}, },
}, },
"CONTROL": {
"GET": {"publicAccessBlock": "GetPublicAccessBlock"},
"PUT": {"publicAccessBlock": "PutPublicAccessBlock"},
"DELETE": {"publicAccessBlock": "DeletePublicAccessBlock"},
},
} }
@ -167,7 +174,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
or host.startswith("localhost") or host.startswith("localhost")
or host.startswith("localstack") or host.startswith("localstack")
or re.match(r"^[^.]+$", host) or re.match(r"^[^.]+$", host)
or re.match(r"^.*\.svc\.cluster\.local$", host) or re.match(r"^.*\.svc\.cluster\.local:?\d*$", host)
): ):
# Default to path-based buckets for (1) localhost, (2) localstack hosts (e.g. localstack.dev), # Default to path-based buckets for (1) localhost, (2) localstack hosts (e.g. localstack.dev),
# (3) local host names that do not contain a "." (e.g., Docker container host names), or # (3) local host names that do not contain a "." (e.g., Docker container host names), or
@ -219,7 +226,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# Depending on which calling format the client is using, we don't know # Depending on which calling format the client is using, we don't know
# if this is a bucket or key request so we have to check # if this is a bucket or key request so we have to check
if self.subdomain_based_buckets(request): if self.subdomain_based_buckets(request):
return self.key_response(request, full_url, headers) return self.key_or_control_response(request, full_url, headers)
else: else:
# Using path-based buckets # Using path-based buckets
return self.bucket_response(request, full_url, headers) return self.bucket_response(request, full_url, headers)
@ -286,7 +293,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return self._bucket_response_post(request, body, bucket_name) return self._bucket_response_post(request, body, bucket_name)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Method {0} has not been impelemented in the S3 backend yet".format( "Method {0} has not been implemented in the S3 backend yet".format(
method method
) )
) )
@ -585,6 +592,29 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
next_continuation_token = None next_continuation_token = None
return result_keys, is_truncated, next_continuation_token return result_keys, is_truncated, next_continuation_token
def _body_contains_location_constraint(self, body):
if body:
try:
xmltodict.parse(body)["CreateBucketConfiguration"]["LocationConstraint"]
return True
except KeyError:
pass
return False
def _parse_pab_config(self, body):
parsed_xml = xmltodict.parse(body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
# If Python 2, fix the unicode strings:
if sys.version_info[0] < 3:
parsed_xml = {
"PublicAccessBlockConfiguration": py2_strip_unicode_keys(
dict(parsed_xml["PublicAccessBlockConfiguration"])
)
}
return parsed_xml
def _bucket_response_put( def _bucket_response_put(
self, request, body, region_name, bucket_name, querystring self, request, body, region_name, bucket_name, querystring
): ):
@ -663,27 +693,23 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
raise e raise e
elif "publicAccessBlock" in querystring: elif "publicAccessBlock" in querystring:
parsed_xml = xmltodict.parse(body) pab_config = self._parse_pab_config(body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
# If Python 2, fix the unicode strings:
if sys.version_info[0] < 3:
parsed_xml = {
"PublicAccessBlockConfiguration": py2_strip_unicode_keys(
dict(parsed_xml["PublicAccessBlockConfiguration"])
)
}
self.backend.put_bucket_public_access_block( self.backend.put_bucket_public_access_block(
bucket_name, parsed_xml["PublicAccessBlockConfiguration"] bucket_name, pab_config["PublicAccessBlockConfiguration"]
) )
return "" return ""
else: else:
# us-east-1, the default AWS region behaves a bit differently
# - you should not use it as a location constraint --> it fails
# - querying the location constraint returns None
# - LocationConstraint has to be specified if outside us-east-1
if (
region_name != DEFAULT_REGION_NAME
and not self._body_contains_location_constraint(body)
):
raise IllegalLocationConstraintException()
if body: if body:
# us-east-1, the default AWS region behaves a bit differently
# - you should not use it as a location constraint --> it fails
# - querying the location constraint returns None
try: try:
forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][ forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][
"LocationConstraint" "LocationConstraint"
@ -854,15 +880,21 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
return 206, response_headers, response_content[begin : end + 1] return 206, response_headers, response_content[begin : end + 1]
def key_response(self, request, full_url, headers): def key_or_control_response(self, request, full_url, headers):
# Key and Control are lumped in because splitting out the regex is too much of a pain :/
self.method = request.method self.method = request.method
self.path = self._get_path(request) self.path = self._get_path(request)
self.headers = request.headers self.headers = request.headers
if "host" not in self.headers: if "host" not in self.headers:
self.headers["host"] = urlparse(full_url).netloc self.headers["host"] = urlparse(full_url).netloc
response_headers = {} response_headers = {}
try: try:
response = self._key_response(request, full_url, headers) # Is this an S3 control response?
if isinstance(request, AWSPreparedRequest) and "s3-control" in request.url:
response = self._control_response(request, full_url, headers)
else:
response = self._key_response(request, full_url, headers)
except S3ClientError as s3error: except S3ClientError as s3error:
response = s3error.code, {}, s3error.description response = s3error.code, {}, s3error.description
@ -878,6 +910,94 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
return status_code, response_headers, response_content return status_code, response_headers, response_content
def _control_response(self, request, full_url, headers):
parsed_url = urlparse(full_url)
query = parse_qs(parsed_url.query, keep_blank_values=True)
method = request.method
if hasattr(request, "body"):
# Boto
body = request.body
if hasattr(body, "read"):
body = body.read()
else:
# Flask server
body = request.data
if body is None:
body = b""
if method == "GET":
return self._control_response_get(request, query, headers)
elif method == "PUT":
return self._control_response_put(request, body, query, headers)
elif method == "DELETE":
return self._control_response_delete(request, query, headers)
else:
raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(
method
)
)
def _control_response_get(self, request, query, headers):
action = self.path.split("?")[0].split("/")[
-1
] # Gets the action out of the URL sans query params.
self._set_action("CONTROL", "GET", action)
self._authenticate_and_authorize_s3_action()
response_headers = {}
if "publicAccessBlock" in action:
public_block_config = self.backend.get_account_public_access_block(
headers["x-amz-account-id"]
)
template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION)
return (
200,
response_headers,
template.render(public_block_config=public_block_config),
)
raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(action)
)
def _control_response_put(self, request, body, query, headers):
action = self.path.split("?")[0].split("/")[
-1
] # Gets the action out of the URL sans query params.
self._set_action("CONTROL", "PUT", action)
self._authenticate_and_authorize_s3_action()
response_headers = {}
if "publicAccessBlock" in action:
pab_config = self._parse_pab_config(body)
self.backend.put_account_public_access_block(
headers["x-amz-account-id"],
pab_config["PublicAccessBlockConfiguration"],
)
return 200, response_headers, ""
raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(action)
)
def _control_response_delete(self, request, query, headers):
action = self.path.split("?")[0].split("/")[
-1
] # Gets the action out of the URL sans query params.
self._set_action("CONTROL", "DELETE", action)
self._authenticate_and_authorize_s3_action()
response_headers = {}
if "publicAccessBlock" in action:
self.backend.delete_account_public_access_block(headers["x-amz-account-id"])
return 200, response_headers, ""
raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(action)
)
def _key_response(self, request, full_url, headers): def _key_response(self, request, full_url, headers):
parsed_url = urlparse(full_url) parsed_url = urlparse(full_url)
query = parse_qs(parsed_url.query, keep_blank_values=True) query = parse_qs(parsed_url.query, keep_blank_values=True)
@ -1082,6 +1202,10 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if mdirective is not None and mdirective == "REPLACE": if mdirective is not None and mdirective == "REPLACE":
metadata = metadata_from_headers(request.headers) metadata = metadata_from_headers(request.headers)
new_key.set_metadata(metadata, replace=True) new_key.set_metadata(metadata, replace=True)
tdirective = request.headers.get("x-amz-tagging-directive")
if tdirective == "REPLACE":
tagging = self._tagging_from_headers(request.headers)
new_key.set_tagging(tagging)
template = self.response_template(S3_OBJECT_COPY_RESPONSE) template = self.response_template(S3_OBJECT_COPY_RESPONSE)
response_headers.update(new_key.response_dict) response_headers.update(new_key.response_dict)
return 200, response_headers, template.render(key=new_key) return 200, response_headers, template.render(key=new_key)
@ -1482,7 +1606,7 @@ S3_ALL_BUCKETS = """<ListAllMyBucketsResult xmlns="http://s3.amazonaws.com/doc/2
{% for bucket in buckets %} {% for bucket in buckets %}
<Bucket> <Bucket>
<Name>{{ bucket.name }}</Name> <Name>{{ bucket.name }}</Name>
<CreationDate>{{ bucket.creation_date }}</CreationDate> <CreationDate>{{ bucket.creation_date.isoformat() }}</CreationDate>
</Bucket> </Bucket>
{% endfor %} {% endfor %}
</Buckets> </Buckets>
@ -1869,7 +1993,6 @@ S3_MULTIPART_LIST_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ID>75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a</ID> <ID>75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a</ID>
<DisplayName>webfile</DisplayName> <DisplayName>webfile</DisplayName>
</Owner> </Owner>
<StorageClass>STANDARD</StorageClass>
<PartNumberMarker>1</PartNumberMarker> <PartNumberMarker>1</PartNumberMarker>
<NextPartNumberMarker>{{ count }}</NextPartNumberMarker> <NextPartNumberMarker>{{ count }}</NextPartNumberMarker>
<MaxParts>{{ count }}</MaxParts> <MaxParts>{{ count }}</MaxParts>

View File

@ -13,7 +13,7 @@ url_paths = {
# subdomain key of path-based bucket # subdomain key of path-based bucket
"{0}/(?P<key_or_bucket_name>[^/]+)/?$": S3ResponseInstance.ambiguous_response, "{0}/(?P<key_or_bucket_name>[^/]+)/?$": S3ResponseInstance.ambiguous_response,
# path-based bucket + key # path-based bucket + key
"{0}/(?P<bucket_name_path>[^/]+)/(?P<key_name>.+)": S3ResponseInstance.key_response, "{0}/(?P<bucket_name_path>[^/]+)/(?P<key_name>.+)": S3ResponseInstance.key_or_control_response,
# subdomain bucket + key with empty first part of path # subdomain bucket + key with empty first part of path
"{0}//(?P<key_name>.*)$": S3ResponseInstance.key_response, "{0}//(?P<key_name>.*)$": S3ResponseInstance.key_or_control_response,
} }

View File

@ -37,7 +37,7 @@ def bucket_name_from_url(url):
REGION_URL_REGEX = re.compile( REGION_URL_REGEX = re.compile(
r"^https?://(s3[-\.](?P<region1>.+)\.amazonaws\.com/(.+)|" r"^https?://(s3[-\.](?P<region1>.+)\.amazonaws\.com/(.+)|"
r"(.+)\.s3-(?P<region2>.+)\.amazonaws\.com)/?" r"(.+)\.s3[-\.](?P<region2>.+)\.amazonaws\.com)/?"
) )

View File

@ -148,11 +148,15 @@ class SESBackend(BaseBackend):
def __type_of_message__(self, destinations): def __type_of_message__(self, destinations):
"""Checks the destination for any special address that could indicate delivery, """Checks the destination for any special address that could indicate delivery,
complaint or bounce like in SES simulator""" complaint or bounce like in SES simulator"""
alladdress = ( if isinstance(destinations, list):
destinations.get("ToAddresses", []) alladdress = destinations
+ destinations.get("CcAddresses", []) else:
+ destinations.get("BccAddresses", []) alladdress = (
) destinations.get("ToAddresses", [])
+ destinations.get("CcAddresses", [])
+ destinations.get("BccAddresses", [])
)
for addr in alladdress: for addr in alladdress:
if SESFeedback.SUCCESS_ADDR in addr: if SESFeedback.SUCCESS_ADDR in addr:
return SESFeedback.DELIVERY return SESFeedback.DELIVERY

View File

@ -99,3 +99,28 @@ class InvalidAttributeName(RESTError):
super(InvalidAttributeName, self).__init__( super(InvalidAttributeName, self).__init__(
"InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name) "InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name)
) )
class InvalidParameterValue(RESTError):
code = 400
def __init__(self, message):
super(InvalidParameterValue, self).__init__("InvalidParameterValue", message)
class MissingParameter(RESTError):
code = 400
def __init__(self):
super(MissingParameter, self).__init__(
"MissingParameter", "The request must contain the parameter Actions."
)
class OverLimit(RESTError):
code = 403
def __init__(self, count):
super(OverLimit, self).__init__(
"OverLimit", "{} Actions were found, maximum allowed is 7.".format(count)
)

View File

@ -30,6 +30,9 @@ from .exceptions import (
BatchEntryIdsNotDistinct, BatchEntryIdsNotDistinct,
TooManyEntriesInBatchRequest, TooManyEntriesInBatchRequest,
InvalidAttributeName, InvalidAttributeName,
InvalidParameterValue,
MissingParameter,
OverLimit,
) )
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
@ -183,6 +186,8 @@ class Queue(BaseModel):
"MaximumMessageSize", "MaximumMessageSize",
"MessageRetentionPeriod", "MessageRetentionPeriod",
"QueueArn", "QueueArn",
"Policy",
"RedrivePolicy",
"ReceiveMessageWaitTimeSeconds", "ReceiveMessageWaitTimeSeconds",
"VisibilityTimeout", "VisibilityTimeout",
] ]
@ -194,6 +199,8 @@ class Queue(BaseModel):
"DeleteMessage", "DeleteMessage",
"GetQueueAttributes", "GetQueueAttributes",
"GetQueueUrl", "GetQueueUrl",
"ListDeadLetterSourceQueues",
"PurgeQueue",
"ReceiveMessage", "ReceiveMessage",
"SendMessage", "SendMessage",
) )
@ -272,7 +279,7 @@ class Queue(BaseModel):
if key in bool_fields: if key in bool_fields:
value = value == "true" value = value == "true"
if key == "RedrivePolicy" and value is not None: if key in ["Policy", "RedrivePolicy"] and value is not None:
continue continue
setattr(self, camelcase_to_underscores(key), value) setattr(self, camelcase_to_underscores(key), value)
@ -280,6 +287,9 @@ class Queue(BaseModel):
if attributes.get("RedrivePolicy", None): if attributes.get("RedrivePolicy", None):
self._setup_dlq(attributes["RedrivePolicy"]) self._setup_dlq(attributes["RedrivePolicy"])
if attributes.get("Policy"):
self.policy = attributes["Policy"]
self.last_modified_timestamp = now self.last_modified_timestamp = now
def _setup_dlq(self, policy): def _setup_dlq(self, policy):
@ -471,6 +481,24 @@ class Queue(BaseModel):
return self.name return self.name
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@property
def policy(self):
if self._policy_json.get("Statement"):
return json.dumps(self._policy_json)
else:
return None
@policy.setter
def policy(self, policy):
if policy:
self._policy_json = json.loads(policy)
else:
self._policy_json = {
"Version": "2012-10-17",
"Id": "{}/SQSDefaultPolicy".format(self.queue_arn),
"Statement": [],
}
class SQSBackend(BaseBackend): class SQSBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name):
@ -539,7 +567,7 @@ class SQSBackend(BaseBackend):
for name, q in self.queues.items(): for name, q in self.queues.items():
if prefix_re.search(name): if prefix_re.search(name):
qs.append(q) qs.append(q)
return qs return qs[:1000]
def get_queue(self, queue_name): def get_queue(self, queue_name):
queue = self.queues.get(queue_name) queue = self.queues.get(queue_name)
@ -801,25 +829,75 @@ class SQSBackend(BaseBackend):
def add_permission(self, queue_name, actions, account_ids, label): def add_permission(self, queue_name, actions, account_ids, label):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if actions is None or len(actions) == 0: if not actions:
raise RESTError("InvalidParameterValue", "Need at least one Action") raise MissingParameter()
if account_ids is None or len(account_ids) == 0:
raise RESTError("InvalidParameterValue", "Need at least one Account ID")
if not all([item in Queue.ALLOWED_PERMISSIONS for item in actions]): if not account_ids:
raise RESTError("InvalidParameterValue", "Invalid permissions") raise InvalidParameterValue(
"Value [] for parameter PrincipalId is invalid. Reason: Unable to verify."
)
queue.permissions[label] = (account_ids, actions) count = len(actions)
if count > 7:
raise OverLimit(count)
invalid_action = next(
(action for action in actions if action not in Queue.ALLOWED_PERMISSIONS),
None,
)
if invalid_action:
raise InvalidParameterValue(
"Value SQS:{} for parameter ActionName is invalid. "
"Reason: Only the queue owner is allowed to invoke this action.".format(
invalid_action
)
)
policy = queue._policy_json
statement = next(
(
statement
for statement in policy["Statement"]
if statement["Sid"] == label
),
None,
)
if statement:
raise InvalidParameterValue(
"Value {} for parameter Label is invalid. "
"Reason: Already exists.".format(label)
)
principals = [
"arn:aws:iam::{}:root".format(account_id) for account_id in account_ids
]
actions = ["SQS:{}".format(action) for action in actions]
statement = {
"Sid": label,
"Effect": "Allow",
"Principal": {"AWS": principals[0] if len(principals) == 1 else principals},
"Action": actions[0] if len(actions) == 1 else actions,
"Resource": queue.queue_arn,
}
queue._policy_json["Statement"].append(statement)
def remove_permission(self, queue_name, label): def remove_permission(self, queue_name, label):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)
if label not in queue.permissions: statements = queue._policy_json["Statement"]
raise RESTError( statements_new = [
"InvalidParameterValue", "Permission doesnt exist for the given label" statement for statement in statements if statement["Sid"] != label
]
if len(statements) == len(statements_new):
raise InvalidParameterValue(
"Value {} for parameter Label is invalid. "
"Reason: can't find label on existing policy.".format(label)
) )
del queue.permissions[label] queue._policy_json["Statement"] = statements_new
def tag_queue(self, queue_name, tags): def tag_queue(self, queue_name, tags):
queue = self.get_queue(queue_name) queue = self.get_queue(queue_name)

View File

@ -127,6 +127,10 @@ class WorkflowExecution(BaseModel):
"executionInfo": self.to_medium_dict(), "executionInfo": self.to_medium_dict(),
"executionConfiguration": {"taskList": {"name": self.task_list}}, "executionConfiguration": {"taskList": {"name": self.task_list}},
} }
# info
if self.execution_status == "CLOSED":
hsh["executionInfo"]["closeStatus"] = self.close_status
hsh["executionInfo"]["closeTimestamp"] = self.close_timestamp
# configuration # configuration
for key in self._configuration_keys: for key in self._configuration_keys:
attr = camelcase_to_underscores(key) attr = camelcase_to_underscores(key)

View File

@ -1,5 +1,5 @@
class TaggingService: class TaggingService:
def __init__(self, tagName='Tags', keyName='Key', valueName='Value'): def __init__(self, tagName="Tags", keyName="Key", valueName="Value"):
self.tagName = tagName self.tagName = tagName
self.keyName = keyName self.keyName = keyName
self.valueName = valueName self.valueName = valueName
@ -12,6 +12,12 @@ class TaggingService:
result.append({self.keyName: k, self.valueName: v}) result.append({self.keyName: k, self.valueName: v})
return {self.tagName: result} return {self.tagName: result}
def delete_all_tags_for_resource(self, arn):
del self.tags[arn]
def has_tags(self, arn):
return arn in self.tags
def tag_resource(self, arn, tags): def tag_resource(self, arn, tags):
if arn not in self.tags: if arn not in self.tags:
self.tags[arn] = {} self.tags[arn] = {}

View File

@ -20,8 +20,8 @@ import jinja2
from prompt_toolkit import ( from prompt_toolkit import (
prompt prompt
) )
from prompt_toolkit.contrib.completers import WordCompleter from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.shortcuts import print_tokens from prompt_toolkit.shortcuts import print_formatted_text
from botocore import xform_name from botocore import xform_name
from botocore.session import Session from botocore.session import Session
@ -149,12 +149,12 @@ def append_mock_dict_to_backends_py(service):
with open(path) as f: with open(path) as f:
lines = [_.replace('\n', '') for _ in f.readlines()] lines = [_.replace('\n', '') for _ in f.readlines()]
if any(_ for _ in lines if re.match(".*'{}': {}_backends.*".format(service, service), _)): if any(_ for _ in lines if re.match(".*\"{}\": {}_backends.*".format(service, service), _)):
return return
filtered_lines = [_ for _ in lines if re.match(".*'.*':.*_backends.*", _)] filtered_lines = [_ for _ in lines if re.match(".*\".*\":.*_backends.*", _)]
last_elem_line_index = lines.index(filtered_lines[-1]) last_elem_line_index = lines.index(filtered_lines[-1])
new_line = " '{}': {}_backends,".format(service, get_escaped_service(service)) new_line = " \"{}\": {}_backends,".format(service, get_escaped_service(service))
prev_line = lines[last_elem_line_index] prev_line = lines[last_elem_line_index]
if not prev_line.endswith('{') and not prev_line.endswith(','): if not prev_line.endswith('{') and not prev_line.endswith(','):
lines[last_elem_line_index] += ',' lines[last_elem_line_index] += ','

View File

@ -39,11 +39,11 @@ install_requires = [
"werkzeug", "werkzeug",
"PyYAML>=5.1", "PyYAML>=5.1",
"pytz", "pytz",
"python-dateutil<2.8.1,>=2.1", "python-dateutil<3.0.0,>=2.1",
"python-jose<4.0.0", "python-jose<4.0.0",
"mock", "mock",
"docker>=2.5.1", "docker>=2.5.1",
"jsondiff==1.1.2", "jsondiff>=1.1.2",
"aws-xray-sdk!=0.96,>=0.93", "aws-xray-sdk!=0.96,>=0.93",
"responses>=0.9.0", "responses>=0.9.0",
"idna<2.9,>=2.5", "idna<2.9,>=2.5",

View File

@ -26,7 +26,14 @@ def test_create_and_get_rest_api():
response.pop("ResponseMetadata") response.pop("ResponseMetadata")
response.pop("createdDate") response.pop("createdDate")
response.should.equal( response.should.equal(
{"id": api_id, "name": "my_api", "description": "this is my api"} {
"id": api_id,
"name": "my_api",
"description": "this is my api",
"apiKeySource": "HEADER",
"endpointConfiguration": {"types": ["EDGE"]},
"tags": {},
}
) )
@ -47,6 +54,114 @@ def test_list_and_delete_apis():
len(response["items"]).should.equal(1) len(response["items"]).should.equal(1)
@mock_apigateway
def test_create_rest_api_with_tags():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(
name="my_api", description="this is my api", tags={"MY_TAG1": "MY_VALUE1"}
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
assert "tags" in response
response["tags"].should.equal({"MY_TAG1": "MY_VALUE1"})
@mock_apigateway
def test_create_rest_api_invalid_apikeysource():
client = boto3.client("apigateway", region_name="us-west-2")
with assert_raises(ClientError) as ex:
client.create_rest_api(
name="my_api",
description="this is my api",
apiKeySource="not a valid api key source",
)
ex.exception.response["Error"]["Code"].should.equal("ValidationException")
@mock_apigateway
def test_create_rest_api_valid_apikeysources():
client = boto3.client("apigateway", region_name="us-west-2")
# 1. test creating rest api with HEADER apiKeySource
response = client.create_rest_api(
name="my_api", description="this is my api", apiKeySource="HEADER",
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["apiKeySource"].should.equal("HEADER")
# 2. test creating rest api with AUTHORIZER apiKeySource
response = client.create_rest_api(
name="my_api2", description="this is my api", apiKeySource="AUTHORIZER",
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["apiKeySource"].should.equal("AUTHORIZER")
@mock_apigateway
def test_create_rest_api_invalid_endpointconfiguration():
client = boto3.client("apigateway", region_name="us-west-2")
with assert_raises(ClientError) as ex:
client.create_rest_api(
name="my_api",
description="this is my api",
endpointConfiguration={"types": ["INVALID"]},
)
ex.exception.response["Error"]["Code"].should.equal("ValidationException")
@mock_apigateway
def test_create_rest_api_valid_endpointconfigurations():
client = boto3.client("apigateway", region_name="us-west-2")
# 1. test creating rest api with PRIVATE endpointConfiguration
response = client.create_rest_api(
name="my_api",
description="this is my api",
endpointConfiguration={"types": ["PRIVATE"]},
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["endpointConfiguration"].should.equal(
{"types": ["PRIVATE"],}
)
# 2. test creating rest api with REGIONAL endpointConfiguration
response = client.create_rest_api(
name="my_api2",
description="this is my api",
endpointConfiguration={"types": ["REGIONAL"]},
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["endpointConfiguration"].should.equal(
{"types": ["REGIONAL"],}
)
# 3. test creating rest api with EDGE endpointConfiguration
response = client.create_rest_api(
name="my_api3",
description="this is my api",
endpointConfiguration={"types": ["EDGE"]},
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
response["endpointConfiguration"].should.equal(
{"types": ["EDGE"],}
)
@mock_apigateway @mock_apigateway
def test_create_resource__validate_name(): def test_create_resource__validate_name():
client = boto3.client("apigateway", region_name="us-west-2") client = boto3.client("apigateway", region_name="us-west-2")
@ -58,15 +173,15 @@ def test_create_resource__validate_name():
0 0
]["id"] ]["id"]
invalid_names = ["/users", "users/", "users/{user_id}", "us{er"] invalid_names = ["/users", "users/", "users/{user_id}", "us{er", "us+er"]
valid_names = ["users", "{user_id}", "user_09", "good-dog"] valid_names = ["users", "{user_id}", "{proxy+}", "user_09", "good-dog"]
# All invalid names should throw an exception # All invalid names should throw an exception
for name in invalid_names: for name in invalid_names:
with assert_raises(ClientError) as ex: with assert_raises(ClientError) as ex:
client.create_resource(restApiId=api_id, parentId=root_id, pathPart=name) client.create_resource(restApiId=api_id, parentId=root_id, pathPart=name)
ex.exception.response["Error"]["Code"].should.equal("BadRequestException") ex.exception.response["Error"]["Code"].should.equal("BadRequestException")
ex.exception.response["Error"]["Message"].should.equal( ex.exception.response["Error"]["Message"].should.equal(
"Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end." "Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end and an optional plus sign before the closing brace."
) )
# All valid names should go through # All valid names should go through
for name in valid_names: for name in valid_names:
@ -89,12 +204,7 @@ def test_create_resource():
root_resource["ResponseMetadata"].pop("HTTPHeaders", None) root_resource["ResponseMetadata"].pop("HTTPHeaders", None)
root_resource["ResponseMetadata"].pop("RetryAttempts", None) root_resource["ResponseMetadata"].pop("RetryAttempts", None)
root_resource.should.equal( root_resource.should.equal(
{ {"path": "/", "id": root_id, "ResponseMetadata": {"HTTPStatusCode": 200},}
"path": "/",
"id": root_id,
"ResponseMetadata": {"HTTPStatusCode": 200},
"resourceMethods": {"GET": {}},
}
) )
client.create_resource(restApiId=api_id, parentId=root_id, pathPart="users") client.create_resource(restApiId=api_id, parentId=root_id, pathPart="users")
@ -142,7 +252,6 @@ def test_child_resource():
"parentId": users_id, "parentId": users_id,
"id": tags_id, "id": tags_id,
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
"resourceMethods": {"GET": {}},
} }
) )
@ -171,6 +280,41 @@ def test_create_method():
{ {
"httpMethod": "GET", "httpMethod": "GET",
"authorizationType": "none", "authorizationType": "none",
"apiKeyRequired": False,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
)
@mock_apigateway
def test_create_method_apikeyrequired():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
resources = client.get_resources(restApiId=api_id)
root_id = [resource for resource in resources["items"] if resource["path"] == "/"][
0
]["id"]
client.put_method(
restApiId=api_id,
resourceId=root_id,
httpMethod="GET",
authorizationType="none",
apiKeyRequired=True,
)
response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET")
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"httpMethod": "GET",
"authorizationType": "none",
"apiKeyRequired": True,
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
} }
) )

View File

@ -706,14 +706,14 @@ def test_create_autoscaling_group_boto3():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
}, },
{ {
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "not-propogated-tag-key", "Key": "not-propogated-tag-key",
"Value": "not-propogate-tag-value", "Value": "not-propagate-tag-value",
"PropagateAtLaunch": False, "PropagateAtLaunch": False,
}, },
], ],
@ -744,14 +744,14 @@ def test_create_autoscaling_group_from_instance():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
}, },
{ {
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "not-propogated-tag-key", "Key": "not-propogated-tag-key",
"Value": "not-propogate-tag-value", "Value": "not-propagate-tag-value",
"PropagateAtLaunch": False, "PropagateAtLaunch": False,
}, },
], ],
@ -1062,7 +1062,7 @@ def test_detach_one_instance_decrement():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
} }
], ],
@ -1116,7 +1116,7 @@ def test_detach_one_instance():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
} }
], ],
@ -1169,7 +1169,7 @@ def test_attach_one_instance():
"ResourceId": "test_asg", "ResourceId": "test_asg",
"ResourceType": "auto-scaling-group", "ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key", "Key": "propogated-tag-key",
"Value": "propogate-tag-value", "Value": "propagate-tag-value",
"PropagateAtLaunch": True, "PropagateAtLaunch": True,
} }
], ],

View File

@ -58,8 +58,7 @@ def lambda_handler(event, context):
volume_id = event.get('volume_id') volume_id = event.get('volume_id')
vol = ec2.Volume(volume_id) vol = ec2.Volume(volume_id)
print('get volume details for %s\\nVolume - %s state=%s, size=%s' % (volume_id, volume_id, vol.state, vol.size)) return {{'id': vol.id, 'state': vol.state, 'size': vol.size}}
return event
""".format( """.format(
base_url="motoserver:5000" base_url="motoserver:5000"
if settings.TEST_SERVER_MODE if settings.TEST_SERVER_MODE
@ -79,7 +78,7 @@ def lambda_handler(event, context):
def get_test_zip_file4(): def get_test_zip_file4():
pfunc = """ pfunc = """
def lambda_handler(event, context): def lambda_handler(event, context):
raise Exception('I failed!') raise Exception('I failed!')
""" """
return _process_lambda(pfunc) return _process_lambda(pfunc)
@ -87,14 +86,14 @@ def lambda_handler(event, context):
@mock_lambda @mock_lambda
def test_list_functions(): def test_list_functions():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
result = conn.list_functions() result = conn.list_functions()
result["Functions"].should.have.length_of(0) result["Functions"].should.have.length_of(0)
@mock_lambda @mock_lambda
def test_invoke_requestresponse_function(): def test_invoke_requestresponse_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python2.7", Runtime="python2.7",
@ -114,7 +113,44 @@ def test_invoke_requestresponse_function():
Payload=json.dumps(in_data), Payload=json.dumps(in_data),
) )
success_result["StatusCode"].should.equal(202) success_result["StatusCode"].should.equal(200)
result_obj = json.loads(
base64.b64decode(success_result["LogResult"]).decode("utf-8")
)
result_obj.should.equal(in_data)
payload = success_result["Payload"].read().decode("utf-8")
json.loads(payload).should.equal(in_data)
@mock_lambda
def test_invoke_requestresponse_function_with_arn():
from moto.awslambda.models import ACCOUNT_ID
conn = boto3.client("lambda", "us-west-2")
conn.create_function(
FunctionName="testFunction",
Runtime="python2.7",
Role=get_role_name(),
Handler="lambda_function.lambda_handler",
Code={"ZipFile": get_test_zip_file1()},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
in_data = {"msg": "So long and thanks for all the fish"}
success_result = conn.invoke(
FunctionName="arn:aws:lambda:us-west-2:{}:function:testFunction".format(
ACCOUNT_ID
),
InvocationType="RequestResponse",
Payload=json.dumps(in_data),
)
success_result["StatusCode"].should.equal(200)
result_obj = json.loads( result_obj = json.loads(
base64.b64decode(success_result["LogResult"]).decode("utf-8") base64.b64decode(success_result["LogResult"]).decode("utf-8")
) )
@ -127,7 +163,7 @@ def test_invoke_requestresponse_function():
@mock_lambda @mock_lambda
def test_invoke_event_function(): def test_invoke_event_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python2.7", Runtime="python2.7",
@ -149,7 +185,35 @@ def test_invoke_event_function():
FunctionName="testFunction", InvocationType="Event", Payload=json.dumps(in_data) FunctionName="testFunction", InvocationType="Event", Payload=json.dumps(in_data)
) )
success_result["StatusCode"].should.equal(202) success_result["StatusCode"].should.equal(202)
json.loads(success_result["Payload"].read().decode("utf-8")).should.equal({}) json.loads(success_result["Payload"].read().decode("utf-8")).should.equal(in_data)
@mock_lambda
def test_invoke_dryrun_function():
conn = boto3.client("lambda", _lambda_region)
conn.create_function(
FunctionName="testFunction",
Runtime="python2.7",
Role=get_role_name(),
Handler="lambda_function.lambda_handler",
Code={"ZipFile": get_test_zip_file1(),},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
conn.invoke.when.called_with(
FunctionName="notAFunction", InvocationType="Event", Payload="{}"
).should.throw(botocore.client.ClientError)
in_data = {"msg": "So long and thanks for all the fish"}
success_result = conn.invoke(
FunctionName="testFunction",
InvocationType="DryRun",
Payload=json.dumps(in_data),
)
success_result["StatusCode"].should.equal(204)
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
@ -157,11 +221,11 @@ if settings.TEST_SERVER_MODE:
@mock_ec2 @mock_ec2
@mock_lambda @mock_lambda
def test_invoke_function_get_ec2_volume(): def test_invoke_function_get_ec2_volume():
conn = boto3.resource("ec2", "us-west-2") conn = boto3.resource("ec2", _lambda_region)
vol = conn.create_volume(Size=99, AvailabilityZone="us-west-2") vol = conn.create_volume(Size=99, AvailabilityZone=_lambda_region)
vol = conn.Volume(vol.id) vol = conn.Volume(vol.id)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python3.7", Runtime="python3.7",
@ -180,28 +244,10 @@ if settings.TEST_SERVER_MODE:
InvocationType="RequestResponse", InvocationType="RequestResponse",
Payload=json.dumps(in_data), Payload=json.dumps(in_data),
) )
result["StatusCode"].should.equal(202) result["StatusCode"].should.equal(200)
msg = "get volume details for %s\nVolume - %s state=%s, size=%s\n%s" % ( actual_payload = json.loads(result["Payload"].read().decode("utf-8"))
vol.id, expected_payload = {"id": vol.id, "state": vol.state, "size": vol.size}
vol.id, actual_payload.should.equal(expected_payload)
vol.state,
vol.size,
json.dumps(in_data).replace(
" ", ""
), # Makes the tests pass as the result is missing the whitespace
)
log_result = base64.b64decode(result["LogResult"]).decode("utf-8")
# The Docker lambda invocation will return an additional '\n', so need to replace it:
log_result = log_result.replace("\n\n", "\n")
log_result.should.equal(msg)
payload = result["Payload"].read().decode("utf-8")
# The Docker lambda invocation will return an additional '\n', so need to replace it:
payload = payload.replace("\n\n", "\n")
payload.should.equal(msg)
@mock_logs @mock_logs
@ -209,14 +255,14 @@ if settings.TEST_SERVER_MODE:
@mock_ec2 @mock_ec2
@mock_lambda @mock_lambda
def test_invoke_function_from_sns(): def test_invoke_function_from_sns():
logs_conn = boto3.client("logs", region_name="us-west-2") logs_conn = boto3.client("logs", region_name=_lambda_region)
sns_conn = boto3.client("sns", region_name="us-west-2") sns_conn = boto3.client("sns", region_name=_lambda_region)
sns_conn.create_topic(Name="some-topic") sns_conn.create_topic(Name="some-topic")
topics_json = sns_conn.list_topics() topics_json = sns_conn.list_topics()
topics = topics_json["Topics"] topics = topics_json["Topics"]
topic_arn = topics[0]["TopicArn"] topic_arn = topics[0]["TopicArn"]
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
result = conn.create_function( result = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python2.7", Runtime="python2.7",
@ -259,7 +305,7 @@ def test_invoke_function_from_sns():
@mock_lambda @mock_lambda
def test_create_based_on_s3_with_missing_bucket(): def test_create_based_on_s3_with_missing_bucket():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function.when.called_with( conn.create_function.when.called_with(
FunctionName="testFunction", FunctionName="testFunction",
@ -279,12 +325,15 @@ def test_create_based_on_s3_with_missing_bucket():
@mock_s3 @mock_s3
@freeze_time("2015-01-01 00:00:00") @freeze_time("2015-01-01 00:00:00")
def test_create_function_from_aws_bucket(): def test_create_function_from_aws_bucket():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
result = conn.create_function( result = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -324,6 +373,7 @@ def test_create_function_from_aws_bucket():
"VpcId": "vpc-123abc", "VpcId": "vpc-123abc",
}, },
"ResponseMetadata": {"HTTPStatusCode": 201}, "ResponseMetadata": {"HTTPStatusCode": 201},
"State": "Active",
} }
) )
@ -331,7 +381,7 @@ def test_create_function_from_aws_bucket():
@mock_lambda @mock_lambda
@freeze_time("2015-01-01 00:00:00") @freeze_time("2015-01-01 00:00:00")
def test_create_function_from_zipfile(): def test_create_function_from_zipfile():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
result = conn.create_function( result = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -367,6 +417,7 @@ def test_create_function_from_zipfile():
"Version": "1", "Version": "1",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"ResponseMetadata": {"HTTPStatusCode": 201}, "ResponseMetadata": {"HTTPStatusCode": 201},
"State": "Active",
} }
) )
@ -375,12 +426,15 @@ def test_create_function_from_zipfile():
@mock_s3 @mock_s3
@freeze_time("2015-01-01 00:00:00") @freeze_time("2015-01-01 00:00:00")
def test_get_function(): def test_get_function():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -435,7 +489,7 @@ def test_get_function():
) )
# Test get function when can't find function name # Test get function when can't find function name
with assert_raises(ClientError): with assert_raises(conn.exceptions.ResourceNotFoundException):
conn.get_function(FunctionName="junk", Qualifier="$LATEST") conn.get_function(FunctionName="junk", Qualifier="$LATEST")
@ -444,7 +498,10 @@ def test_get_function():
def test_get_function_by_arn(): def test_get_function_by_arn():
bucket_name = "test-bucket" bucket_name = "test-bucket"
s3_conn = boto3.client("s3", "us-east-1") s3_conn = boto3.client("s3", "us-east-1")
s3_conn.create_bucket(Bucket=bucket_name) s3_conn.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content)
@ -469,12 +526,15 @@ def test_get_function_by_arn():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_delete_function(): def test_delete_function():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -505,7 +565,10 @@ def test_delete_function():
def test_delete_function_by_arn(): def test_delete_function_by_arn():
bucket_name = "test-bucket" bucket_name = "test-bucket"
s3_conn = boto3.client("s3", "us-east-1") s3_conn = boto3.client("s3", "us-east-1")
s3_conn.create_bucket(Bucket=bucket_name) s3_conn.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content)
@ -530,7 +593,7 @@ def test_delete_function_by_arn():
@mock_lambda @mock_lambda
def test_delete_unknown_function(): def test_delete_unknown_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.delete_function.when.called_with( conn.delete_function.when.called_with(
FunctionName="testFunctionThatDoesntExist" FunctionName="testFunctionThatDoesntExist"
).should.throw(botocore.client.ClientError) ).should.throw(botocore.client.ClientError)
@ -539,12 +602,15 @@ def test_delete_unknown_function():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_publish(): def test_publish():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -589,12 +655,15 @@ def test_list_create_list_get_delete_list():
test `list -> create -> list -> get -> delete -> list` integration test `list -> create -> list -> get -> delete -> list` integration
""" """
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.list_functions()["Functions"].should.have.length_of(0) conn.list_functions()["Functions"].should.have.length_of(0)
@ -631,6 +700,7 @@ def test_list_create_list_get_delete_list():
"Timeout": 3, "Timeout": 3,
"Version": "$LATEST", "Version": "$LATEST",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
}, },
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
} }
@ -690,12 +760,15 @@ def test_tags():
""" """
test list_tags -> tag_resource -> list_tags -> tag_resource -> list_tags -> untag_resource -> list_tags integration test list_tags -> tag_resource -> list_tags -> tag_resource -> list_tags -> untag_resource -> list_tags integration
""" """
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
function = conn.create_function( function = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -747,7 +820,7 @@ def test_tags_not_found():
""" """
Test list_tags and tag_resource when the lambda with the given arn does not exist Test list_tags and tag_resource when the lambda with the given arn does not exist
""" """
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.list_tags.when.called_with( conn.list_tags.when.called_with(
Resource="arn:aws:lambda:{}:function:not-found".format(ACCOUNT_ID) Resource="arn:aws:lambda:{}:function:not-found".format(ACCOUNT_ID)
).should.throw(botocore.client.ClientError) ).should.throw(botocore.client.ClientError)
@ -765,7 +838,7 @@ def test_tags_not_found():
@mock_lambda @mock_lambda
def test_invoke_async_function(): def test_invoke_async_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
Runtime="python2.7", Runtime="python2.7",
@ -788,7 +861,7 @@ def test_invoke_async_function():
@mock_lambda @mock_lambda
@freeze_time("2015-01-01 00:00:00") @freeze_time("2015-01-01 00:00:00")
def test_get_function_created_with_zipfile(): def test_get_function_created_with_zipfile():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
result = conn.create_function( result = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -827,13 +900,14 @@ def test_get_function_created_with_zipfile():
"Timeout": 3, "Timeout": 3,
"Version": "$LATEST", "Version": "$LATEST",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@mock_lambda @mock_lambda
def test_add_function_permission(): def test_add_function_permission():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -864,7 +938,7 @@ def test_add_function_permission():
@mock_lambda @mock_lambda
def test_get_function_policy(): def test_get_function_policy():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -899,12 +973,15 @@ def test_get_function_policy():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_list_versions_by_function(): def test_list_versions_by_function():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -955,12 +1032,15 @@ def test_list_versions_by_function():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_create_function_with_already_exists(): def test_create_function_with_already_exists():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
conn.create_function( conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -992,7 +1072,7 @@ def test_create_function_with_already_exists():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_list_versions_by_function_for_nonexistent_function(): def test_list_versions_by_function_for_nonexistent_function():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
versions = conn.list_versions_by_function(FunctionName="testFunction") versions = conn.list_versions_by_function(FunctionName="testFunction")
assert len(versions["Versions"]) == 0 assert len(versions["Versions"]) == 0
@ -1341,12 +1421,15 @@ def test_delete_event_source_mapping():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_update_configuration(): def test_update_configuration():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file2() zip_content = get_test_zip_file2()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
fxn = conn.create_function( fxn = conn.create_function(
FunctionName="testFunction", FunctionName="testFunction",
@ -1389,7 +1472,7 @@ def test_update_configuration():
@mock_lambda @mock_lambda
def test_update_function_zip(): def test_update_function_zip():
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content_one = get_test_zip_file1() zip_content_one = get_test_zip_file1()
@ -1436,6 +1519,7 @@ def test_update_function_zip():
"Timeout": 3, "Timeout": 3,
"Version": "2", "Version": "2",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@ -1443,13 +1527,16 @@ def test_update_function_zip():
@mock_lambda @mock_lambda
@mock_s3 @mock_s3
def test_update_function_s3(): def test_update_function_s3():
s3_conn = boto3.client("s3", "us-west-2") s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(Bucket="test-bucket") s3_conn.create_bucket(
Bucket="test-bucket",
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
fxn = conn.create_function( fxn = conn.create_function(
FunctionName="testFunctionS3", FunctionName="testFunctionS3",
@ -1498,6 +1585,7 @@ def test_update_function_s3():
"Timeout": 3, "Timeout": 3,
"Version": "2", "Version": "2",
"VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []},
"State": "Active",
} }
) )
@ -1529,7 +1617,7 @@ def test_create_function_with_unknown_arn():
def create_invalid_lambda(role): def create_invalid_lambda(role):
conn = boto3.client("lambda", "us-west-2") conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1() zip_content = get_test_zip_file1()
with assert_raises(ClientError) as err: with assert_raises(ClientError) as err:
conn.create_function( conn.create_function(
@ -1548,7 +1636,7 @@ def create_invalid_lambda(role):
def get_role_name(): def get_role_name():
with mock_iam(): with mock_iam():
iam = boto3.client("iam", region_name="us-west-2") iam = boto3.client("iam", region_name=_lambda_region)
try: try:
return iam.get_role(RoleName="my-role")["Role"]["Arn"] return iam.get_role(RoleName="my-role")["Role"]["Arn"]
except ClientError: except ClientError:

View File

@ -94,7 +94,7 @@ def test_lambda_can_be_deleted_by_cloudformation():
# Verify function was deleted # Verify function was deleted
with assert_raises(ClientError) as e: with assert_raises(ClientError) as e:
lmbda.get_function(FunctionName=created_fn_name) lmbda.get_function(FunctionName=created_fn_name)
e.exception.response["Error"]["Code"].should.equal("404") e.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException")
def create_stack(cf, s3): def create_stack(cf, s3):

View File

@ -0,0 +1,49 @@
from __future__ import unicode_literals
import json
import sure
from moto.awslambda.policy import Policy
class MockLambdaFunction:
def __init__(self, arn):
self.function_arn = arn
self.policy = None
def test_policy():
policy = Policy(MockLambdaFunction("arn"))
statement = {
"StatementId": "statement0",
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": "events.amazonaws.com",
"SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name",
"SourceAccount": "111111111111",
}
expected = {
"Action": "lambda:InvokeFunction",
"FunctionName": "function_name",
"Principal": {"Service": "events.amazonaws.com"},
"Effect": "Allow",
"Resource": "arn:$LATEST",
"Sid": "statement0",
"Condition": {
"ArnLike": {
"AWS:SourceArn": "arn:aws:events:us-east-1:111111111111:rule/rule_name",
},
"StringEquals": {"AWS:SourceAccount": "111111111111"},
},
}
policy.add_statement(json.dumps(statement))
expected.should.be.equal(policy.statements[0])
sid = statement.get("StatementId", None)
if sid == None:
raise "TestCase.statement does not contain StatementId"
policy.del_statement(sid)
[].should.be.equal(policy.statements)

View File

@ -1,5 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
template = { template = {
"Resources": {"VPCEIP": {"Type": "AWS::EC2::EIP", "Properties": {"Domain": "vpc"}}} "Resources": {"VPCEIP": {"Type": "AWS::EC2::EIP", "Properties": {"Domain": "vpc"}}}
} }

View File

@ -1,276 +1,276 @@
from __future__ import unicode_literals from __future__ import unicode_literals
template = { template = {
"Description": "AWS CloudFormation Sample Template vpc_single_instance_in_subnet.template: Sample template showing how to create a VPC and add an EC2 instance with an Elastic IP address and a security group. **WARNING** This template creates an Amazon EC2 instance. You will be billed for the AWS resources used if you create a stack from this template.", "Description": "AWS CloudFormation Sample Template vpc_single_instance_in_subnet.template: Sample template showing how to create a VPC and add an EC2 instance with an Elastic IP address and a security group. **WARNING** This template creates an Amazon EC2 instance. You will be billed for the AWS resources used if you create a stack from this template.",
"Parameters": { "Parameters": {
"SSHLocation": { "SSHLocation": {
"ConstraintDescription": "must be a valid IP CIDR range of the form x.x.x.x/x.", "ConstraintDescription": "must be a valid IP CIDR range of the form x.x.x.x/x.",
"Description": " The IP address range that can be used to SSH to the EC2 instances", "Description": " The IP address range that can be used to SSH to the EC2 instances",
"Default": "0.0.0.0/0", "Default": "0.0.0.0/0",
"MinLength": "9", "MinLength": "9",
"AllowedPattern": "(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})/(\\d{1,2})", "AllowedPattern": "(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})/(\\d{1,2})",
"MaxLength": "18", "MaxLength": "18",
"Type": "String", "Type": "String",
}, },
"KeyName": { "KeyName": {
"Type": "String", "Type": "String",
"Description": "Name of an existing EC2 KeyPair to enable SSH access to the instance", "Description": "Name of an existing EC2 KeyPair to enable SSH access to the instance",
"MinLength": "1", "MinLength": "1",
"AllowedPattern": "[\\x20-\\x7E]*", "AllowedPattern": "[\\x20-\\x7E]*",
"MaxLength": "255", "MaxLength": "255",
"ConstraintDescription": "can contain only ASCII characters.", "ConstraintDescription": "can contain only ASCII characters.",
}, },
"InstanceType": { "InstanceType": {
"Default": "m1.small", "Default": "m1.small",
"ConstraintDescription": "must be a valid EC2 instance type.", "ConstraintDescription": "must be a valid EC2 instance type.",
"Type": "String", "Type": "String",
"Description": "WebServer EC2 instance type", "Description": "WebServer EC2 instance type",
"AllowedValues": [ "AllowedValues": [
"t1.micro", "t1.micro",
"m1.small", "m1.small",
"m1.medium", "m1.medium",
"m1.large", "m1.large",
"m1.xlarge", "m1.xlarge",
"m2.xlarge", "m2.xlarge",
"m2.2xlarge", "m2.2xlarge",
"m2.4xlarge", "m2.4xlarge",
"m3.xlarge", "m3.xlarge",
"m3.2xlarge", "m3.2xlarge",
"c1.medium", "c1.medium",
"c1.xlarge", "c1.xlarge",
"cc1.4xlarge", "cc1.4xlarge",
"cc2.8xlarge", "cc2.8xlarge",
"cg1.4xlarge", "cg1.4xlarge",
], ],
}, },
}, },
"AWSTemplateFormatVersion": "2010-09-09", "AWSTemplateFormatVersion": "2010-09-09",
"Outputs": { "Outputs": {
"URL": { "URL": {
"Description": "Newly created application URL", "Description": "Newly created application URL",
"Value": { "Value": {
"Fn::Join": [ "Fn::Join": [
"", "",
["http://", {"Fn::GetAtt": ["WebServerInstance", "PublicIp"]}], ["http://", {"Fn::GetAtt": ["WebServerInstance", "PublicIp"]}],
] ]
}, },
} }
}, },
"Resources": { "Resources": {
"Subnet": { "Subnet": {
"Type": "AWS::EC2::Subnet", "Type": "AWS::EC2::Subnet",
"Properties": { "Properties": {
"VpcId": {"Ref": "VPC"}, "VpcId": {"Ref": "VPC"},
"CidrBlock": "10.0.0.0/24", "CidrBlock": "10.0.0.0/24",
"Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}], "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}],
}, },
}, },
"WebServerWaitHandle": {"Type": "AWS::CloudFormation::WaitConditionHandle"}, "WebServerWaitHandle": {"Type": "AWS::CloudFormation::WaitConditionHandle"},
"Route": { "Route": {
"Type": "AWS::EC2::Route", "Type": "AWS::EC2::Route",
"Properties": { "Properties": {
"GatewayId": {"Ref": "InternetGateway"}, "GatewayId": {"Ref": "InternetGateway"},
"DestinationCidrBlock": "0.0.0.0/0", "DestinationCidrBlock": "0.0.0.0/0",
"RouteTableId": {"Ref": "RouteTable"}, "RouteTableId": {"Ref": "RouteTable"},
}, },
"DependsOn": "AttachGateway", "DependsOn": "AttachGateway",
}, },
"SubnetRouteTableAssociation": { "SubnetRouteTableAssociation": {
"Type": "AWS::EC2::SubnetRouteTableAssociation", "Type": "AWS::EC2::SubnetRouteTableAssociation",
"Properties": { "Properties": {
"SubnetId": {"Ref": "Subnet"}, "SubnetId": {"Ref": "Subnet"},
"RouteTableId": {"Ref": "RouteTable"}, "RouteTableId": {"Ref": "RouteTable"},
}, },
}, },
"InternetGateway": { "InternetGateway": {
"Type": "AWS::EC2::InternetGateway", "Type": "AWS::EC2::InternetGateway",
"Properties": { "Properties": {
"Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}] "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}]
}, },
}, },
"RouteTable": { "RouteTable": {
"Type": "AWS::EC2::RouteTable", "Type": "AWS::EC2::RouteTable",
"Properties": { "Properties": {
"VpcId": {"Ref": "VPC"}, "VpcId": {"Ref": "VPC"},
"Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}], "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}],
}, },
}, },
"WebServerWaitCondition": { "WebServerWaitCondition": {
"Type": "AWS::CloudFormation::WaitCondition", "Type": "AWS::CloudFormation::WaitCondition",
"Properties": {"Handle": {"Ref": "WebServerWaitHandle"}, "Timeout": "300"}, "Properties": {"Handle": {"Ref": "WebServerWaitHandle"}, "Timeout": "300"},
"DependsOn": "WebServerInstance", "DependsOn": "WebServerInstance",
}, },
"VPC": { "VPC": {
"Type": "AWS::EC2::VPC", "Type": "AWS::EC2::VPC",
"Properties": { "Properties": {
"CidrBlock": "10.0.0.0/16", "CidrBlock": "10.0.0.0/16",
"Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}], "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}],
}, },
}, },
"InstanceSecurityGroup": { "InstanceSecurityGroup": {
"Type": "AWS::EC2::SecurityGroup", "Type": "AWS::EC2::SecurityGroup",
"Properties": { "Properties": {
"SecurityGroupIngress": [ "SecurityGroupIngress": [
{ {
"ToPort": "22", "ToPort": "22",
"IpProtocol": "tcp", "IpProtocol": "tcp",
"CidrIp": {"Ref": "SSHLocation"}, "CidrIp": {"Ref": "SSHLocation"},
"FromPort": "22", "FromPort": "22",
}, },
{ {
"ToPort": "80", "ToPort": "80",
"IpProtocol": "tcp", "IpProtocol": "tcp",
"CidrIp": "0.0.0.0/0", "CidrIp": "0.0.0.0/0",
"FromPort": "80", "FromPort": "80",
}, },
], ],
"VpcId": {"Ref": "VPC"}, "VpcId": {"Ref": "VPC"},
"GroupDescription": "Enable SSH access via port 22", "GroupDescription": "Enable SSH access via port 22",
}, },
}, },
"WebServerInstance": { "WebServerInstance": {
"Type": "AWS::EC2::Instance", "Type": "AWS::EC2::Instance",
"Properties": { "Properties": {
"UserData": { "UserData": {
"Fn::Base64": { "Fn::Base64": {
"Fn::Join": [ "Fn::Join": [
"", "",
[ [
"#!/bin/bash\n", "#!/bin/bash\n",
"yum update -y aws-cfn-bootstrap\n", "yum update -y aws-cfn-bootstrap\n",
"# Helper function\n", "# Helper function\n",
"function error_exit\n", "function error_exit\n",
"{\n", "{\n",
' /opt/aws/bin/cfn-signal -e 1 -r "$1" \'', ' /opt/aws/bin/cfn-signal -e 1 -r "$1" \'',
{"Ref": "WebServerWaitHandle"}, {"Ref": "WebServerWaitHandle"},
"'\n", "'\n",
" exit 1\n", " exit 1\n",
"}\n", "}\n",
"# Install the simple web page\n", "# Install the simple web page\n",
"/opt/aws/bin/cfn-init -s ", "/opt/aws/bin/cfn-init -s ",
{"Ref": "AWS::StackId"}, {"Ref": "AWS::StackId"},
" -r WebServerInstance ", " -r WebServerInstance ",
" --region ", " --region ",
{"Ref": "AWS::Region"}, {"Ref": "AWS::Region"},
" || error_exit 'Failed to run cfn-init'\n", " || error_exit 'Failed to run cfn-init'\n",
"# Start up the cfn-hup daemon to listen for changes to the Web Server metadata\n", "# Start up the cfn-hup daemon to listen for changes to the Web Server metadata\n",
"/opt/aws/bin/cfn-hup || error_exit 'Failed to start cfn-hup'\n", "/opt/aws/bin/cfn-hup || error_exit 'Failed to start cfn-hup'\n",
"# All done so signal success\n", "# All done so signal success\n",
'/opt/aws/bin/cfn-signal -e 0 -r "WebServer setup complete" \'', '/opt/aws/bin/cfn-signal -e 0 -r "WebServer setup complete" \'',
{"Ref": "WebServerWaitHandle"}, {"Ref": "WebServerWaitHandle"},
"'\n", "'\n",
], ],
] ]
} }
}, },
"Tags": [ "Tags": [
{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}, {"Value": {"Ref": "AWS::StackId"}, "Key": "Application"},
{"Value": "Bar", "Key": "Foo"}, {"Value": "Bar", "Key": "Foo"},
], ],
"SecurityGroupIds": [{"Ref": "InstanceSecurityGroup"}], "SecurityGroupIds": [{"Ref": "InstanceSecurityGroup"}],
"KeyName": {"Ref": "KeyName"}, "KeyName": {"Ref": "KeyName"},
"SubnetId": {"Ref": "Subnet"}, "SubnetId": {"Ref": "Subnet"},
"ImageId": { "ImageId": {
"Fn::FindInMap": ["RegionMap", {"Ref": "AWS::Region"}, "AMI"] "Fn::FindInMap": ["RegionMap", {"Ref": "AWS::Region"}, "AMI"]
}, },
"InstanceType": {"Ref": "InstanceType"}, "InstanceType": {"Ref": "InstanceType"},
}, },
"Metadata": { "Metadata": {
"Comment": "Install a simple PHP application", "Comment": "Install a simple PHP application",
"AWS::CloudFormation::Init": { "AWS::CloudFormation::Init": {
"config": { "config": {
"files": { "files": {
"/etc/cfn/cfn-hup.conf": { "/etc/cfn/cfn-hup.conf": {
"content": { "content": {
"Fn::Join": [ "Fn::Join": [
"", "",
[ [
"[main]\n", "[main]\n",
"stack=", "stack=",
{"Ref": "AWS::StackId"}, {"Ref": "AWS::StackId"},
"\n", "\n",
"region=", "region=",
{"Ref": "AWS::Region"}, {"Ref": "AWS::Region"},
"\n", "\n",
], ],
] ]
}, },
"owner": "root", "owner": "root",
"group": "root", "group": "root",
"mode": "000400", "mode": "000400",
}, },
"/etc/cfn/hooks.d/cfn-auto-reloader.conf": { "/etc/cfn/hooks.d/cfn-auto-reloader.conf": {
"content": { "content": {
"Fn::Join": [ "Fn::Join": [
"", "",
[ [
"[cfn-auto-reloader-hook]\n", "[cfn-auto-reloader-hook]\n",
"triggers=post.update\n", "triggers=post.update\n",
"path=Resources.WebServerInstance.Metadata.AWS::CloudFormation::Init\n", "path=Resources.WebServerInstance.Metadata.AWS::CloudFormation::Init\n",
"action=/opt/aws/bin/cfn-init -s ", "action=/opt/aws/bin/cfn-init -s ",
{"Ref": "AWS::StackId"}, {"Ref": "AWS::StackId"},
" -r WebServerInstance ", " -r WebServerInstance ",
" --region ", " --region ",
{"Ref": "AWS::Region"}, {"Ref": "AWS::Region"},
"\n", "\n",
"runas=root\n", "runas=root\n",
], ],
] ]
} }
}, },
"/var/www/html/index.php": { "/var/www/html/index.php": {
"content": { "content": {
"Fn::Join": [ "Fn::Join": [
"", "",
[ [
"<?php\n", "<?php\n",
"echo '<h1>AWS CloudFormation sample PHP application</h1>';\n", "echo '<h1>AWS CloudFormation sample PHP application</h1>';\n",
"?>\n", "?>\n",
], ],
] ]
}, },
"owner": "apache", "owner": "apache",
"group": "apache", "group": "apache",
"mode": "000644", "mode": "000644",
}, },
}, },
"services": { "services": {
"sysvinit": { "sysvinit": {
"httpd": {"ensureRunning": "true", "enabled": "true"}, "httpd": {"ensureRunning": "true", "enabled": "true"},
"sendmail": { "sendmail": {
"ensureRunning": "false", "ensureRunning": "false",
"enabled": "false", "enabled": "false",
}, },
} }
}, },
"packages": {"yum": {"httpd": [], "php": []}}, "packages": {"yum": {"httpd": [], "php": []}},
} }
}, },
}, },
}, },
"IPAddress": { "IPAddress": {
"Type": "AWS::EC2::EIP", "Type": "AWS::EC2::EIP",
"Properties": {"InstanceId": {"Ref": "WebServerInstance"}, "Domain": "vpc"}, "Properties": {"InstanceId": {"Ref": "WebServerInstance"}, "Domain": "vpc"},
"DependsOn": "AttachGateway", "DependsOn": "AttachGateway",
}, },
"AttachGateway": { "AttachGateway": {
"Type": "AWS::EC2::VPCGatewayAttachment", "Type": "AWS::EC2::VPCGatewayAttachment",
"Properties": { "Properties": {
"VpcId": {"Ref": "VPC"}, "VpcId": {"Ref": "VPC"},
"InternetGatewayId": {"Ref": "InternetGateway"}, "InternetGatewayId": {"Ref": "InternetGateway"},
}, },
}, },
}, },
"Mappings": { "Mappings": {
"RegionMap": { "RegionMap": {
"ap-southeast-1": {"AMI": "ami-74dda626"}, "ap-southeast-1": {"AMI": "ami-74dda626"},
"ap-southeast-2": {"AMI": "ami-b3990e89"}, "ap-southeast-2": {"AMI": "ami-b3990e89"},
"us-west-2": {"AMI": "ami-16fd7026"}, "us-west-2": {"AMI": "ami-16fd7026"},
"us-east-1": {"AMI": "ami-7f418316"}, "us-east-1": {"AMI": "ami-7f418316"},
"ap-northeast-1": {"AMI": "ami-dcfa4edd"}, "ap-northeast-1": {"AMI": "ami-dcfa4edd"},
"us-west-1": {"AMI": "ami-951945d0"}, "us-west-1": {"AMI": "ami-951945d0"},
"eu-west-1": {"AMI": "ami-24506250"}, "eu-west-1": {"AMI": "ami-24506250"},
"sa-east-1": {"AMI": "ami-3e3be423"}, "sa-east-1": {"AMI": "ami-3e3be423"},
} }
}, },
} }

View File

@ -143,7 +143,7 @@ def test_create_stack_with_notification_arn():
@mock_s3_deprecated @mock_s3_deprecated
def test_create_stack_from_s3_url(): def test_create_stack_from_s3_url():
s3_conn = boto.s3.connect_to_region("us-west-1") s3_conn = boto.s3.connect_to_region("us-west-1")
bucket = s3_conn.create_bucket("foobar") bucket = s3_conn.create_bucket("foobar", location="us-west-1")
key = boto.s3.key.Key(bucket) key = boto.s3.key.Key(bucket)
key.key = "template-key" key.key = "template-key"
key.set_contents_from_string(dummy_template_json) key.set_contents_from_string(dummy_template_json)

View File

@ -522,6 +522,13 @@ def test_boto3_list_stack_set_operations():
list_operation["Summaries"][-1]["Action"].should.equal("UPDATE") list_operation["Summaries"][-1]["Action"].should.equal("UPDATE")
@mock_cloudformation
def test_boto3_bad_list_stack_resources():
cf_conn = boto3.client("cloudformation", region_name="us-east-1")
with assert_raises(ClientError):
cf_conn.list_stack_resources(StackName="test_stack_set")
@mock_cloudformation @mock_cloudformation
def test_boto3_delete_stack_set(): def test_boto3_delete_stack_set():
cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn = boto3.client("cloudformation", region_name="us-east-1")

View File

@ -1,117 +1,117 @@
import boto import boto
from boto.ec2.cloudwatch.alarm import MetricAlarm from boto.ec2.cloudwatch.alarm import MetricAlarm
import sure # noqa import sure # noqa
from moto import mock_cloudwatch_deprecated from moto import mock_cloudwatch_deprecated
def alarm_fixture(name="tester", action=None): def alarm_fixture(name="tester", action=None):
action = action or ["arn:alarm"] action = action or ["arn:alarm"]
return MetricAlarm( return MetricAlarm(
name=name, name=name,
namespace="{0}_namespace".format(name), namespace="{0}_namespace".format(name),
metric="{0}_metric".format(name), metric="{0}_metric".format(name),
comparison=">=", comparison=">=",
threshold=2.0, threshold=2.0,
period=60, period=60,
evaluation_periods=5, evaluation_periods=5,
statistic="Average", statistic="Average",
description="A test", description="A test",
dimensions={"InstanceId": ["i-0123456,i-0123457"]}, dimensions={"InstanceId": ["i-0123456,i-0123457"]},
alarm_actions=action, alarm_actions=action,
ok_actions=["arn:ok"], ok_actions=["arn:ok"],
insufficient_data_actions=["arn:insufficient"], insufficient_data_actions=["arn:insufficient"],
unit="Seconds", unit="Seconds",
) )
@mock_cloudwatch_deprecated @mock_cloudwatch_deprecated
def test_create_alarm(): def test_create_alarm():
conn = boto.connect_cloudwatch() conn = boto.connect_cloudwatch()
alarm = alarm_fixture() alarm = alarm_fixture()
conn.create_alarm(alarm) conn.create_alarm(alarm)
alarms = conn.describe_alarms() alarms = conn.describe_alarms()
alarms.should.have.length_of(1) alarms.should.have.length_of(1)
alarm = alarms[0] alarm = alarms[0]
alarm.name.should.equal("tester") alarm.name.should.equal("tester")
alarm.namespace.should.equal("tester_namespace") alarm.namespace.should.equal("tester_namespace")
alarm.metric.should.equal("tester_metric") alarm.metric.should.equal("tester_metric")
alarm.comparison.should.equal(">=") alarm.comparison.should.equal(">=")
alarm.threshold.should.equal(2.0) alarm.threshold.should.equal(2.0)
alarm.period.should.equal(60) alarm.period.should.equal(60)
alarm.evaluation_periods.should.equal(5) alarm.evaluation_periods.should.equal(5)
alarm.statistic.should.equal("Average") alarm.statistic.should.equal("Average")
alarm.description.should.equal("A test") alarm.description.should.equal("A test")
dict(alarm.dimensions).should.equal({"InstanceId": ["i-0123456,i-0123457"]}) dict(alarm.dimensions).should.equal({"InstanceId": ["i-0123456,i-0123457"]})
list(alarm.alarm_actions).should.equal(["arn:alarm"]) list(alarm.alarm_actions).should.equal(["arn:alarm"])
list(alarm.ok_actions).should.equal(["arn:ok"]) list(alarm.ok_actions).should.equal(["arn:ok"])
list(alarm.insufficient_data_actions).should.equal(["arn:insufficient"]) list(alarm.insufficient_data_actions).should.equal(["arn:insufficient"])
alarm.unit.should.equal("Seconds") alarm.unit.should.equal("Seconds")
@mock_cloudwatch_deprecated @mock_cloudwatch_deprecated
def test_delete_alarm(): def test_delete_alarm():
conn = boto.connect_cloudwatch() conn = boto.connect_cloudwatch()
alarms = conn.describe_alarms() alarms = conn.describe_alarms()
alarms.should.have.length_of(0) alarms.should.have.length_of(0)
alarm = alarm_fixture() alarm = alarm_fixture()
conn.create_alarm(alarm) conn.create_alarm(alarm)
alarms = conn.describe_alarms() alarms = conn.describe_alarms()
alarms.should.have.length_of(1) alarms.should.have.length_of(1)
alarms[0].delete() alarms[0].delete()
alarms = conn.describe_alarms() alarms = conn.describe_alarms()
alarms.should.have.length_of(0) alarms.should.have.length_of(0)
@mock_cloudwatch_deprecated @mock_cloudwatch_deprecated
def test_put_metric_data(): def test_put_metric_data():
conn = boto.connect_cloudwatch() conn = boto.connect_cloudwatch()
conn.put_metric_data( conn.put_metric_data(
namespace="tester", namespace="tester",
name="metric", name="metric",
value=1.5, value=1.5,
dimensions={"InstanceId": ["i-0123456,i-0123457"]}, dimensions={"InstanceId": ["i-0123456,i-0123457"]},
) )
metrics = conn.list_metrics() metrics = conn.list_metrics()
metrics.should.have.length_of(1) metrics.should.have.length_of(1)
metric = metrics[0] metric = metrics[0]
metric.namespace.should.equal("tester") metric.namespace.should.equal("tester")
metric.name.should.equal("metric") metric.name.should.equal("metric")
dict(metric.dimensions).should.equal({"InstanceId": ["i-0123456,i-0123457"]}) dict(metric.dimensions).should.equal({"InstanceId": ["i-0123456,i-0123457"]})
@mock_cloudwatch_deprecated @mock_cloudwatch_deprecated
def test_describe_alarms(): def test_describe_alarms():
conn = boto.connect_cloudwatch() conn = boto.connect_cloudwatch()
alarms = conn.describe_alarms() alarms = conn.describe_alarms()
alarms.should.have.length_of(0) alarms.should.have.length_of(0)
conn.create_alarm(alarm_fixture(name="nfoobar", action="afoobar")) conn.create_alarm(alarm_fixture(name="nfoobar", action="afoobar"))
conn.create_alarm(alarm_fixture(name="nfoobaz", action="afoobaz")) conn.create_alarm(alarm_fixture(name="nfoobaz", action="afoobaz"))
conn.create_alarm(alarm_fixture(name="nbarfoo", action="abarfoo")) conn.create_alarm(alarm_fixture(name="nbarfoo", action="abarfoo"))
conn.create_alarm(alarm_fixture(name="nbazfoo", action="abazfoo")) conn.create_alarm(alarm_fixture(name="nbazfoo", action="abazfoo"))
alarms = conn.describe_alarms() alarms = conn.describe_alarms()
alarms.should.have.length_of(4) alarms.should.have.length_of(4)
alarms = conn.describe_alarms(alarm_name_prefix="nfoo") alarms = conn.describe_alarms(alarm_name_prefix="nfoo")
alarms.should.have.length_of(2) alarms.should.have.length_of(2)
alarms = conn.describe_alarms(alarm_names=["nfoobar", "nbarfoo", "nbazfoo"]) alarms = conn.describe_alarms(alarm_names=["nfoobar", "nbarfoo", "nbazfoo"])
alarms.should.have.length_of(3) alarms.should.have.length_of(3)
alarms = conn.describe_alarms(action_prefix="afoo") alarms = conn.describe_alarms(action_prefix="afoo")
alarms.should.have.length_of(2) alarms.should.have.length_of(2)
for alarm in conn.describe_alarms(): for alarm in conn.describe_alarms():
alarm.delete() alarm.delete()
alarms = conn.describe_alarms() alarms = conn.describe_alarms()
alarms.should.have.length_of(0) alarms.should.have.length_of(0)

View File

@ -27,6 +27,11 @@ def test_create_user_pool():
result["UserPool"]["Id"].should_not.be.none result["UserPool"]["Id"].should_not.be.none
result["UserPool"]["Id"].should.match(r"[\w-]+_[0-9a-zA-Z]+") result["UserPool"]["Id"].should.match(r"[\w-]+_[0-9a-zA-Z]+")
result["UserPool"]["Arn"].should.equal(
"arn:aws:cognito-idp:us-west-2:{}:userpool/{}".format(
ACCOUNT_ID, result["UserPool"]["Id"]
)
)
result["UserPool"]["Name"].should.equal(name) result["UserPool"]["Name"].should.equal(name)
result["UserPool"]["LambdaConfig"]["PreSignUp"].should.equal(value) result["UserPool"]["LambdaConfig"]["PreSignUp"].should.equal(value)
@ -911,6 +916,55 @@ def test_admin_create_existing_user():
caught.should.be.true caught.should.be.true
@mock_cognitoidp
def test_admin_resend_invitation_existing_user():
conn = boto3.client("cognito-idp", "us-west-2")
username = str(uuid.uuid4())
value = str(uuid.uuid4())
user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"]
conn.admin_create_user(
UserPoolId=user_pool_id,
Username=username,
UserAttributes=[{"Name": "thing", "Value": value}],
)
caught = False
try:
conn.admin_create_user(
UserPoolId=user_pool_id,
Username=username,
UserAttributes=[{"Name": "thing", "Value": value}],
MessageAction="RESEND",
)
except conn.exceptions.UsernameExistsException:
caught = True
caught.should.be.false
@mock_cognitoidp
def test_admin_resend_invitation_missing_user():
conn = boto3.client("cognito-idp", "us-west-2")
username = str(uuid.uuid4())
value = str(uuid.uuid4())
user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"]
caught = False
try:
conn.admin_create_user(
UserPoolId=user_pool_id,
Username=username,
UserAttributes=[{"Name": "thing", "Value": value}],
MessageAction="RESEND",
)
except conn.exceptions.UserNotFoundException:
caught = True
caught.should.be.true
@mock_cognitoidp @mock_cognitoidp
def test_admin_get_user(): def test_admin_get_user():
conn = boto3.client("cognito-idp", "us-west-2") conn = boto3.client("cognito-idp", "us-west-2")
@ -958,6 +1012,18 @@ def test_list_users():
result["Users"].should.have.length_of(1) result["Users"].should.have.length_of(1)
result["Users"][0]["Username"].should.equal(username) result["Users"][0]["Username"].should.equal(username)
username_bis = str(uuid.uuid4())
conn.admin_create_user(
UserPoolId=user_pool_id,
Username=username_bis,
UserAttributes=[{"Name": "phone_number", "Value": "+33666666666"}],
)
result = conn.list_users(
UserPoolId=user_pool_id, Filter='phone_number="+33666666666'
)
result["Users"].should.have.length_of(1)
result["Users"][0]["Username"].should.equal(username_bis)
@mock_cognitoidp @mock_cognitoidp
def test_list_users_returns_limit_items(): def test_list_users_returns_limit_items():
@ -1142,11 +1208,13 @@ def test_token_legitimacy():
id_claims = json.loads(jws.verify(id_token, json_web_key, "RS256")) id_claims = json.loads(jws.verify(id_token, json_web_key, "RS256"))
id_claims["iss"].should.equal(issuer) id_claims["iss"].should.equal(issuer)
id_claims["aud"].should.equal(client_id) id_claims["aud"].should.equal(client_id)
id_claims["token_use"].should.equal("id")
for k, v in outputs["additional_fields"].items():
id_claims[k].should.equal(v)
access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256")) access_claims = json.loads(jws.verify(access_token, json_web_key, "RS256"))
access_claims["iss"].should.equal(issuer) access_claims["iss"].should.equal(issuer)
access_claims["aud"].should.equal(client_id) access_claims["aud"].should.equal(client_id)
for k, v in outputs["additional_fields"].items(): access_claims["token_use"].should.equal("access")
access_claims[k].should.equal(v)
@mock_cognitoidp @mock_cognitoidp

View File

@ -46,4 +46,4 @@ def test_domain_dispatched_with_service():
dispatcher = DomainDispatcherApplication(create_backend_app, service="s3") dispatcher = DomainDispatcherApplication(create_backend_app, service="s3")
backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"}) backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"})
keys = set(backend_app.view_functions.keys()) keys = set(backend_app.view_functions.keys())
keys.should.contain("ResponseObject.key_response") keys.should.contain("ResponseObject.key_or_control_response")

View File

@ -1,182 +1,182 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto.datapipeline import boto.datapipeline
import sure # noqa import sure # noqa
from moto import mock_datapipeline_deprecated from moto import mock_datapipeline_deprecated
from moto.datapipeline.utils import remove_capitalization_of_dict_keys from moto.datapipeline.utils import remove_capitalization_of_dict_keys
def get_value_from_fields(key, fields): def get_value_from_fields(key, fields):
for field in fields: for field in fields:
if field["key"] == key: if field["key"] == key:
return field["stringValue"] return field["stringValue"]
@mock_datapipeline_deprecated @mock_datapipeline_deprecated
def test_create_pipeline(): def test_create_pipeline():
conn = boto.datapipeline.connect_to_region("us-west-2") conn = boto.datapipeline.connect_to_region("us-west-2")
res = conn.create_pipeline("mypipeline", "some-unique-id") res = conn.create_pipeline("mypipeline", "some-unique-id")
pipeline_id = res["pipelineId"] pipeline_id = res["pipelineId"]
pipeline_descriptions = conn.describe_pipelines([pipeline_id])[ pipeline_descriptions = conn.describe_pipelines([pipeline_id])[
"pipelineDescriptionList" "pipelineDescriptionList"
] ]
pipeline_descriptions.should.have.length_of(1) pipeline_descriptions.should.have.length_of(1)
pipeline_description = pipeline_descriptions[0] pipeline_description = pipeline_descriptions[0]
pipeline_description["name"].should.equal("mypipeline") pipeline_description["name"].should.equal("mypipeline")
pipeline_description["pipelineId"].should.equal(pipeline_id) pipeline_description["pipelineId"].should.equal(pipeline_id)
fields = pipeline_description["fields"] fields = pipeline_description["fields"]
get_value_from_fields("@pipelineState", fields).should.equal("PENDING") get_value_from_fields("@pipelineState", fields).should.equal("PENDING")
get_value_from_fields("uniqueId", fields).should.equal("some-unique-id") get_value_from_fields("uniqueId", fields).should.equal("some-unique-id")
PIPELINE_OBJECTS = [ PIPELINE_OBJECTS = [
{ {
"id": "Default", "id": "Default",
"name": "Default", "name": "Default",
"fields": [{"key": "workerGroup", "stringValue": "workerGroup"}], "fields": [{"key": "workerGroup", "stringValue": "workerGroup"}],
}, },
{ {
"id": "Schedule", "id": "Schedule",
"name": "Schedule", "name": "Schedule",
"fields": [ "fields": [
{"key": "startDateTime", "stringValue": "2012-12-12T00:00:00"}, {"key": "startDateTime", "stringValue": "2012-12-12T00:00:00"},
{"key": "type", "stringValue": "Schedule"}, {"key": "type", "stringValue": "Schedule"},
{"key": "period", "stringValue": "1 hour"}, {"key": "period", "stringValue": "1 hour"},
{"key": "endDateTime", "stringValue": "2012-12-21T18:00:00"}, {"key": "endDateTime", "stringValue": "2012-12-21T18:00:00"},
], ],
}, },
{ {
"id": "SayHello", "id": "SayHello",
"name": "SayHello", "name": "SayHello",
"fields": [ "fields": [
{"key": "type", "stringValue": "ShellCommandActivity"}, {"key": "type", "stringValue": "ShellCommandActivity"},
{"key": "command", "stringValue": "echo hello"}, {"key": "command", "stringValue": "echo hello"},
{"key": "parent", "refValue": "Default"}, {"key": "parent", "refValue": "Default"},
{"key": "schedule", "refValue": "Schedule"}, {"key": "schedule", "refValue": "Schedule"},
], ],
}, },
] ]
@mock_datapipeline_deprecated @mock_datapipeline_deprecated
def test_creating_pipeline_definition(): def test_creating_pipeline_definition():
conn = boto.datapipeline.connect_to_region("us-west-2") conn = boto.datapipeline.connect_to_region("us-west-2")
res = conn.create_pipeline("mypipeline", "some-unique-id") res = conn.create_pipeline("mypipeline", "some-unique-id")
pipeline_id = res["pipelineId"] pipeline_id = res["pipelineId"]
conn.put_pipeline_definition(PIPELINE_OBJECTS, pipeline_id) conn.put_pipeline_definition(PIPELINE_OBJECTS, pipeline_id)
pipeline_definition = conn.get_pipeline_definition(pipeline_id) pipeline_definition = conn.get_pipeline_definition(pipeline_id)
pipeline_definition["pipelineObjects"].should.have.length_of(3) pipeline_definition["pipelineObjects"].should.have.length_of(3)
default_object = pipeline_definition["pipelineObjects"][0] default_object = pipeline_definition["pipelineObjects"][0]
default_object["name"].should.equal("Default") default_object["name"].should.equal("Default")
default_object["id"].should.equal("Default") default_object["id"].should.equal("Default")
default_object["fields"].should.equal( default_object["fields"].should.equal(
[{"key": "workerGroup", "stringValue": "workerGroup"}] [{"key": "workerGroup", "stringValue": "workerGroup"}]
) )
@mock_datapipeline_deprecated @mock_datapipeline_deprecated
def test_describing_pipeline_objects(): def test_describing_pipeline_objects():
conn = boto.datapipeline.connect_to_region("us-west-2") conn = boto.datapipeline.connect_to_region("us-west-2")
res = conn.create_pipeline("mypipeline", "some-unique-id") res = conn.create_pipeline("mypipeline", "some-unique-id")
pipeline_id = res["pipelineId"] pipeline_id = res["pipelineId"]
conn.put_pipeline_definition(PIPELINE_OBJECTS, pipeline_id) conn.put_pipeline_definition(PIPELINE_OBJECTS, pipeline_id)
objects = conn.describe_objects(["Schedule", "Default"], pipeline_id)[ objects = conn.describe_objects(["Schedule", "Default"], pipeline_id)[
"pipelineObjects" "pipelineObjects"
] ]
objects.should.have.length_of(2) objects.should.have.length_of(2)
default_object = [x for x in objects if x["id"] == "Default"][0] default_object = [x for x in objects if x["id"] == "Default"][0]
default_object["name"].should.equal("Default") default_object["name"].should.equal("Default")
default_object["fields"].should.equal( default_object["fields"].should.equal(
[{"key": "workerGroup", "stringValue": "workerGroup"}] [{"key": "workerGroup", "stringValue": "workerGroup"}]
) )
@mock_datapipeline_deprecated @mock_datapipeline_deprecated
def test_activate_pipeline(): def test_activate_pipeline():
conn = boto.datapipeline.connect_to_region("us-west-2") conn = boto.datapipeline.connect_to_region("us-west-2")
res = conn.create_pipeline("mypipeline", "some-unique-id") res = conn.create_pipeline("mypipeline", "some-unique-id")
pipeline_id = res["pipelineId"] pipeline_id = res["pipelineId"]
conn.activate_pipeline(pipeline_id) conn.activate_pipeline(pipeline_id)
pipeline_descriptions = conn.describe_pipelines([pipeline_id])[ pipeline_descriptions = conn.describe_pipelines([pipeline_id])[
"pipelineDescriptionList" "pipelineDescriptionList"
] ]
pipeline_descriptions.should.have.length_of(1) pipeline_descriptions.should.have.length_of(1)
pipeline_description = pipeline_descriptions[0] pipeline_description = pipeline_descriptions[0]
fields = pipeline_description["fields"] fields = pipeline_description["fields"]
get_value_from_fields("@pipelineState", fields).should.equal("SCHEDULED") get_value_from_fields("@pipelineState", fields).should.equal("SCHEDULED")
@mock_datapipeline_deprecated @mock_datapipeline_deprecated
def test_delete_pipeline(): def test_delete_pipeline():
conn = boto.datapipeline.connect_to_region("us-west-2") conn = boto.datapipeline.connect_to_region("us-west-2")
res = conn.create_pipeline("mypipeline", "some-unique-id") res = conn.create_pipeline("mypipeline", "some-unique-id")
pipeline_id = res["pipelineId"] pipeline_id = res["pipelineId"]
conn.delete_pipeline(pipeline_id) conn.delete_pipeline(pipeline_id)
response = conn.list_pipelines() response = conn.list_pipelines()
response["pipelineIdList"].should.have.length_of(0) response["pipelineIdList"].should.have.length_of(0)
@mock_datapipeline_deprecated @mock_datapipeline_deprecated
def test_listing_pipelines(): def test_listing_pipelines():
conn = boto.datapipeline.connect_to_region("us-west-2") conn = boto.datapipeline.connect_to_region("us-west-2")
res1 = conn.create_pipeline("mypipeline1", "some-unique-id1") res1 = conn.create_pipeline("mypipeline1", "some-unique-id1")
res2 = conn.create_pipeline("mypipeline2", "some-unique-id2") res2 = conn.create_pipeline("mypipeline2", "some-unique-id2")
response = conn.list_pipelines() response = conn.list_pipelines()
response["hasMoreResults"].should.be(False) response["hasMoreResults"].should.be(False)
response["marker"].should.be.none response["marker"].should.be.none
response["pipelineIdList"].should.have.length_of(2) response["pipelineIdList"].should.have.length_of(2)
response["pipelineIdList"].should.contain( response["pipelineIdList"].should.contain(
{"id": res1["pipelineId"], "name": "mypipeline1"} {"id": res1["pipelineId"], "name": "mypipeline1"}
) )
response["pipelineIdList"].should.contain( response["pipelineIdList"].should.contain(
{"id": res2["pipelineId"], "name": "mypipeline2"} {"id": res2["pipelineId"], "name": "mypipeline2"}
) )
@mock_datapipeline_deprecated @mock_datapipeline_deprecated
def test_listing_paginated_pipelines(): def test_listing_paginated_pipelines():
conn = boto.datapipeline.connect_to_region("us-west-2") conn = boto.datapipeline.connect_to_region("us-west-2")
for i in range(100): for i in range(100):
conn.create_pipeline("mypipeline%d" % i, "some-unique-id%d" % i) conn.create_pipeline("mypipeline%d" % i, "some-unique-id%d" % i)
response = conn.list_pipelines() response = conn.list_pipelines()
response["hasMoreResults"].should.be(True) response["hasMoreResults"].should.be(True)
response["marker"].should.equal(response["pipelineIdList"][-1]["id"]) response["marker"].should.equal(response["pipelineIdList"][-1]["id"])
response["pipelineIdList"].should.have.length_of(50) response["pipelineIdList"].should.have.length_of(50)
# testing a helper function # testing a helper function
def test_remove_capitalization_of_dict_keys(): def test_remove_capitalization_of_dict_keys():
result = remove_capitalization_of_dict_keys( result = remove_capitalization_of_dict_keys(
{ {
"Id": "IdValue", "Id": "IdValue",
"Fields": [{"Key": "KeyValue", "StringValue": "StringValueValue"}], "Fields": [{"Key": "KeyValue", "StringValue": "StringValueValue"}],
} }
) )
result.should.equal( result.should.equal(
{ {
"id": "IdValue", "id": "IdValue",
"fields": [{"key": "KeyValue", "stringValue": "StringValueValue"}], "fields": [{"key": "KeyValue", "stringValue": "StringValueValue"}],
} }
) )

View File

@ -1,470 +1,470 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from freezegun import freeze_time from freezegun import freeze_time
from moto import mock_dynamodb_deprecated from moto import mock_dynamodb_deprecated
from boto.dynamodb import condition from boto.dynamodb import condition
from boto.dynamodb.exceptions import DynamoDBKeyNotFoundError, DynamoDBValidationError from boto.dynamodb.exceptions import DynamoDBKeyNotFoundError, DynamoDBValidationError
from boto.exception import DynamoDBResponseError from boto.exception import DynamoDBResponseError
def create_table(conn): def create_table(conn):
message_table_schema = conn.create_schema( message_table_schema = conn.create_schema(
hash_key_name="forum_name", hash_key_name="forum_name",
hash_key_proto_value=str, hash_key_proto_value=str,
range_key_name="subject", range_key_name="subject",
range_key_proto_value=str, range_key_proto_value=str,
) )
table = conn.create_table( table = conn.create_table(
name="messages", schema=message_table_schema, read_units=10, write_units=10 name="messages", schema=message_table_schema, read_units=10, write_units=10
) )
return table return table
@freeze_time("2012-01-14") @freeze_time("2012-01-14")
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_create_table(): def test_create_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
create_table(conn) create_table(conn)
expected = { expected = {
"Table": { "Table": {
"CreationDateTime": 1326499200.0, "CreationDateTime": 1326499200.0,
"ItemCount": 0, "ItemCount": 0,
"KeySchema": { "KeySchema": {
"HashKeyElement": {"AttributeName": "forum_name", "AttributeType": "S"}, "HashKeyElement": {"AttributeName": "forum_name", "AttributeType": "S"},
"RangeKeyElement": {"AttributeName": "subject", "AttributeType": "S"}, "RangeKeyElement": {"AttributeName": "subject", "AttributeType": "S"},
}, },
"ProvisionedThroughput": { "ProvisionedThroughput": {
"ReadCapacityUnits": 10, "ReadCapacityUnits": 10,
"WriteCapacityUnits": 10, "WriteCapacityUnits": 10,
}, },
"TableName": "messages", "TableName": "messages",
"TableSizeBytes": 0, "TableSizeBytes": 0,
"TableStatus": "ACTIVE", "TableStatus": "ACTIVE",
} }
} }
conn.describe_table("messages").should.equal(expected) conn.describe_table("messages").should.equal(expected)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_delete_table(): def test_delete_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
create_table(conn) create_table(conn)
conn.list_tables().should.have.length_of(1) conn.list_tables().should.have.length_of(1)
conn.layer1.delete_table("messages") conn.layer1.delete_table("messages")
conn.list_tables().should.have.length_of(0) conn.list_tables().should.have.length_of(0)
conn.layer1.delete_table.when.called_with("messages").should.throw( conn.layer1.delete_table.when.called_with("messages").should.throw(
DynamoDBResponseError DynamoDBResponseError
) )
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_update_table_throughput(): def test_update_table_throughput():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
table.read_units.should.equal(10) table.read_units.should.equal(10)
table.write_units.should.equal(10) table.write_units.should.equal(10)
table.update_throughput(5, 6) table.update_throughput(5, 6)
table.refresh() table.refresh()
table.read_units.should.equal(5) table.read_units.should.equal(5)
table.write_units.should.equal(6) table.write_units.should.equal(6)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_item_add_and_describe_and_update(): def test_item_add_and_describe_and_update():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item( item = table.new_item(
hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data
) )
item.put() item.put()
table.has_item("LOLCat Forum", "Check this out!").should.equal(True) table.has_item("LOLCat Forum", "Check this out!").should.equal(True)
returned_item = table.get_item( returned_item = table.get_item(
hash_key="LOLCat Forum", hash_key="LOLCat Forum",
range_key="Check this out!", range_key="Check this out!",
attributes_to_get=["Body", "SentBy"], attributes_to_get=["Body", "SentBy"],
) )
dict(returned_item).should.equal( dict(returned_item).should.equal(
{ {
"forum_name": "LOLCat Forum", "forum_name": "LOLCat Forum",
"subject": "Check this out!", "subject": "Check this out!",
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
} }
) )
item["SentBy"] = "User B" item["SentBy"] = "User B"
item.put() item.put()
returned_item = table.get_item( returned_item = table.get_item(
hash_key="LOLCat Forum", hash_key="LOLCat Forum",
range_key="Check this out!", range_key="Check this out!",
attributes_to_get=["Body", "SentBy"], attributes_to_get=["Body", "SentBy"],
) )
dict(returned_item).should.equal( dict(returned_item).should.equal(
{ {
"forum_name": "LOLCat Forum", "forum_name": "LOLCat Forum",
"subject": "Check this out!", "subject": "Check this out!",
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User B", "SentBy": "User B",
} }
) )
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_item_put_without_table(): def test_item_put_without_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.put_item.when.called_with( conn.layer1.put_item.when.called_with(
table_name="undeclared-table", table_name="undeclared-table",
item=dict(hash_key="LOLCat Forum", range_key="Check this out!"), item=dict(hash_key="LOLCat Forum", range_key="Check this out!"),
).should.throw(DynamoDBResponseError) ).should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_get_missing_item(): def test_get_missing_item():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
table.get_item.when.called_with(hash_key="tester", range_key="other").should.throw( table.get_item.when.called_with(hash_key="tester", range_key="other").should.throw(
DynamoDBKeyNotFoundError DynamoDBKeyNotFoundError
) )
table.has_item("foobar", "more").should.equal(False) table.has_item("foobar", "more").should.equal(False)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_get_item_with_undeclared_table(): def test_get_item_with_undeclared_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.get_item.when.called_with( conn.layer1.get_item.when.called_with(
table_name="undeclared-table", table_name="undeclared-table",
key={"HashKeyElement": {"S": "tester"}, "RangeKeyElement": {"S": "test-range"}}, key={"HashKeyElement": {"S": "tester"}, "RangeKeyElement": {"S": "test-range"}},
).should.throw(DynamoDBKeyNotFoundError) ).should.throw(DynamoDBKeyNotFoundError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_get_item_without_range_key(): def test_get_item_without_range_key():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
message_table_schema = conn.create_schema( message_table_schema = conn.create_schema(
hash_key_name="test_hash", hash_key_name="test_hash",
hash_key_proto_value=int, hash_key_proto_value=int,
range_key_name="test_range", range_key_name="test_range",
range_key_proto_value=int, range_key_proto_value=int,
) )
table = conn.create_table( table = conn.create_table(
name="messages", schema=message_table_schema, read_units=10, write_units=10 name="messages", schema=message_table_schema, read_units=10, write_units=10
) )
hash_key = 3241526475 hash_key = 3241526475
range_key = 1234567890987 range_key = 1234567890987
new_item = table.new_item(hash_key=hash_key, range_key=range_key) new_item = table.new_item(hash_key=hash_key, range_key=range_key)
new_item.put() new_item.put()
table.get_item.when.called_with(hash_key=hash_key).should.throw( table.get_item.when.called_with(hash_key=hash_key).should.throw(
DynamoDBValidationError DynamoDBValidationError
) )
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_delete_item(): def test_delete_item():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item( item = table.new_item(
hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data
) )
item.put() item.put()
table.refresh() table.refresh()
table.item_count.should.equal(1) table.item_count.should.equal(1)
response = item.delete() response = item.delete()
response.should.equal({"Attributes": [], "ConsumedCapacityUnits": 0.5}) response.should.equal({"Attributes": [], "ConsumedCapacityUnits": 0.5})
table.refresh() table.refresh()
table.item_count.should.equal(0) table.item_count.should.equal(0)
item.delete.when.called_with().should.throw(DynamoDBResponseError) item.delete.when.called_with().should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_delete_item_with_attribute_response(): def test_delete_item_with_attribute_response():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item( item = table.new_item(
hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data
) )
item.put() item.put()
table.refresh() table.refresh()
table.item_count.should.equal(1) table.item_count.should.equal(1)
response = item.delete(return_values="ALL_OLD") response = item.delete(return_values="ALL_OLD")
response.should.equal( response.should.equal(
{ {
"Attributes": { "Attributes": {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"forum_name": "LOLCat Forum", "forum_name": "LOLCat Forum",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
"SentBy": "User A", "SentBy": "User A",
"subject": "Check this out!", "subject": "Check this out!",
}, },
"ConsumedCapacityUnits": 0.5, "ConsumedCapacityUnits": 0.5,
} }
) )
table.refresh() table.refresh()
table.item_count.should.equal(0) table.item_count.should.equal(0)
item.delete.when.called_with().should.throw(DynamoDBResponseError) item.delete.when.called_with().should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_delete_item_with_undeclared_table(): def test_delete_item_with_undeclared_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.delete_item.when.called_with( conn.layer1.delete_item.when.called_with(
table_name="undeclared-table", table_name="undeclared-table",
key={"HashKeyElement": {"S": "tester"}, "RangeKeyElement": {"S": "test-range"}}, key={"HashKeyElement": {"S": "tester"}, "RangeKeyElement": {"S": "test-range"}},
).should.throw(DynamoDBResponseError) ).should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_query(): def test_query():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data) item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data)
item.put() item.put()
item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data) item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data)
item.put() item.put()
item = table.new_item(hash_key="the-key", range_key="789", attrs=item_data) item = table.new_item(hash_key="the-key", range_key="789", attrs=item_data)
item.put() item.put()
results = table.query(hash_key="the-key", range_key_condition=condition.GT("1")) results = table.query(hash_key="the-key", range_key_condition=condition.GT("1"))
results.response["Items"].should.have.length_of(3) results.response["Items"].should.have.length_of(3)
results = table.query(hash_key="the-key", range_key_condition=condition.GT("234")) results = table.query(hash_key="the-key", range_key_condition=condition.GT("234"))
results.response["Items"].should.have.length_of(2) results.response["Items"].should.have.length_of(2)
results = table.query(hash_key="the-key", range_key_condition=condition.GT("9999")) results = table.query(hash_key="the-key", range_key_condition=condition.GT("9999"))
results.response["Items"].should.have.length_of(0) results.response["Items"].should.have.length_of(0)
results = table.query( results = table.query(
hash_key="the-key", range_key_condition=condition.CONTAINS("12") hash_key="the-key", range_key_condition=condition.CONTAINS("12")
) )
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
results = table.query( results = table.query(
hash_key="the-key", range_key_condition=condition.BEGINS_WITH("7") hash_key="the-key", range_key_condition=condition.BEGINS_WITH("7")
) )
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
results = table.query( results = table.query(
hash_key="the-key", range_key_condition=condition.BETWEEN("567", "890") hash_key="the-key", range_key_condition=condition.BETWEEN("567", "890")
) )
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_query_with_undeclared_table(): def test_query_with_undeclared_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.query.when.called_with( conn.layer1.query.when.called_with(
table_name="undeclared-table", table_name="undeclared-table",
hash_key_value={"S": "the-key"}, hash_key_value={"S": "the-key"},
range_key_conditions={ range_key_conditions={
"AttributeValueList": [{"S": "User B"}], "AttributeValueList": [{"S": "User B"}],
"ComparisonOperator": "EQ", "ComparisonOperator": "EQ",
}, },
).should.throw(DynamoDBResponseError) ).should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_scan(): def test_scan():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data) item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data)
item.put() item.put()
item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data) item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data)
item.put() item.put()
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User B", "SentBy": "User B",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
"Ids": set([1, 2, 3]), "Ids": set([1, 2, 3]),
"PK": 7, "PK": 7,
} }
item = table.new_item(hash_key="the-key", range_key="789", attrs=item_data) item = table.new_item(hash_key="the-key", range_key="789", attrs=item_data)
item.put() item.put()
results = table.scan() results = table.scan()
results.response["Items"].should.have.length_of(3) results.response["Items"].should.have.length_of(3)
results = table.scan(scan_filter={"SentBy": condition.EQ("User B")}) results = table.scan(scan_filter={"SentBy": condition.EQ("User B")})
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
results = table.scan(scan_filter={"Body": condition.BEGINS_WITH("http")}) results = table.scan(scan_filter={"Body": condition.BEGINS_WITH("http")})
results.response["Items"].should.have.length_of(3) results.response["Items"].should.have.length_of(3)
results = table.scan(scan_filter={"Ids": condition.CONTAINS(2)}) results = table.scan(scan_filter={"Ids": condition.CONTAINS(2)})
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
results = table.scan(scan_filter={"Ids": condition.NOT_NULL()}) results = table.scan(scan_filter={"Ids": condition.NOT_NULL()})
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
results = table.scan(scan_filter={"Ids": condition.NULL()}) results = table.scan(scan_filter={"Ids": condition.NULL()})
results.response["Items"].should.have.length_of(2) results.response["Items"].should.have.length_of(2)
results = table.scan(scan_filter={"PK": condition.BETWEEN(8, 9)}) results = table.scan(scan_filter={"PK": condition.BETWEEN(8, 9)})
results.response["Items"].should.have.length_of(0) results.response["Items"].should.have.length_of(0)
results = table.scan(scan_filter={"PK": condition.BETWEEN(5, 8)}) results = table.scan(scan_filter={"PK": condition.BETWEEN(5, 8)})
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_scan_with_undeclared_table(): def test_scan_with_undeclared_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.scan.when.called_with( conn.layer1.scan.when.called_with(
table_name="undeclared-table", table_name="undeclared-table",
scan_filter={ scan_filter={
"SentBy": { "SentBy": {
"AttributeValueList": [{"S": "User B"}], "AttributeValueList": [{"S": "User B"}],
"ComparisonOperator": "EQ", "ComparisonOperator": "EQ",
} }
}, },
).should.throw(DynamoDBResponseError) ).should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_scan_after_has_item(): def test_scan_after_has_item():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
list(table.scan()).should.equal([]) list(table.scan()).should.equal([])
table.has_item(hash_key="the-key", range_key="123") table.has_item(hash_key="the-key", range_key="123")
list(table.scan()).should.equal([]) list(table.scan()).should.equal([])
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_write_batch(): def test_write_batch():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
batch_list = conn.new_batch_write_list() batch_list = conn.new_batch_write_list()
items = [] items = []
items.append( items.append(
table.new_item( table.new_item(
hash_key="the-key", hash_key="the-key",
range_key="123", range_key="123",
attrs={ attrs={
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
}, },
) )
) )
items.append( items.append(
table.new_item( table.new_item(
hash_key="the-key", hash_key="the-key",
range_key="789", range_key="789",
attrs={ attrs={
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User B", "SentBy": "User B",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
"Ids": set([1, 2, 3]), "Ids": set([1, 2, 3]),
"PK": 7, "PK": 7,
}, },
) )
) )
batch_list.add_batch(table, puts=items) batch_list.add_batch(table, puts=items)
conn.batch_write_item(batch_list) conn.batch_write_item(batch_list)
table.refresh() table.refresh()
table.item_count.should.equal(2) table.item_count.should.equal(2)
batch_list = conn.new_batch_write_list() batch_list = conn.new_batch_write_list()
batch_list.add_batch(table, deletes=[("the-key", "789")]) batch_list.add_batch(table, deletes=[("the-key", "789")])
conn.batch_write_item(batch_list) conn.batch_write_item(batch_list)
table.refresh() table.refresh()
table.item_count.should.equal(1) table.item_count.should.equal(1)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_batch_read(): def test_batch_read():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data) item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data)
item.put() item.put()
item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data) item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data)
item.put() item.put()
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User B", "SentBy": "User B",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
"Ids": set([1, 2, 3]), "Ids": set([1, 2, 3]),
"PK": 7, "PK": 7,
} }
item = table.new_item(hash_key="another-key", range_key="789", attrs=item_data) item = table.new_item(hash_key="another-key", range_key="789", attrs=item_data)
item.put() item.put()
items = table.batch_get_item([("the-key", "123"), ("another-key", "789")]) items = table.batch_get_item([("the-key", "123"), ("another-key", "789")])
# Iterate through so that batch_item gets called # Iterate through so that batch_item gets called
count = len([x for x in items]) count = len([x for x in items])
count.should.equal(2) count.should.equal(2)

View File

@ -1,390 +1,390 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from freezegun import freeze_time from freezegun import freeze_time
from moto import mock_dynamodb_deprecated from moto import mock_dynamodb_deprecated
from boto.dynamodb import condition from boto.dynamodb import condition
from boto.dynamodb.exceptions import DynamoDBKeyNotFoundError from boto.dynamodb.exceptions import DynamoDBKeyNotFoundError
from boto.exception import DynamoDBResponseError from boto.exception import DynamoDBResponseError
def create_table(conn): def create_table(conn):
message_table_schema = conn.create_schema( message_table_schema = conn.create_schema(
hash_key_name="forum_name", hash_key_proto_value=str hash_key_name="forum_name", hash_key_proto_value=str
) )
table = conn.create_table( table = conn.create_table(
name="messages", schema=message_table_schema, read_units=10, write_units=10 name="messages", schema=message_table_schema, read_units=10, write_units=10
) )
return table return table
@freeze_time("2012-01-14") @freeze_time("2012-01-14")
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_create_table(): def test_create_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
create_table(conn) create_table(conn)
expected = { expected = {
"Table": { "Table": {
"CreationDateTime": 1326499200.0, "CreationDateTime": 1326499200.0,
"ItemCount": 0, "ItemCount": 0,
"KeySchema": { "KeySchema": {
"HashKeyElement": {"AttributeName": "forum_name", "AttributeType": "S"} "HashKeyElement": {"AttributeName": "forum_name", "AttributeType": "S"}
}, },
"ProvisionedThroughput": { "ProvisionedThroughput": {
"ReadCapacityUnits": 10, "ReadCapacityUnits": 10,
"WriteCapacityUnits": 10, "WriteCapacityUnits": 10,
}, },
"TableName": "messages", "TableName": "messages",
"TableSizeBytes": 0, "TableSizeBytes": 0,
"TableStatus": "ACTIVE", "TableStatus": "ACTIVE",
} }
} }
conn.describe_table("messages").should.equal(expected) conn.describe_table("messages").should.equal(expected)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_delete_table(): def test_delete_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
create_table(conn) create_table(conn)
conn.list_tables().should.have.length_of(1) conn.list_tables().should.have.length_of(1)
conn.layer1.delete_table("messages") conn.layer1.delete_table("messages")
conn.list_tables().should.have.length_of(0) conn.list_tables().should.have.length_of(0)
conn.layer1.delete_table.when.called_with("messages").should.throw( conn.layer1.delete_table.when.called_with("messages").should.throw(
DynamoDBResponseError DynamoDBResponseError
) )
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_update_table_throughput(): def test_update_table_throughput():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
table.read_units.should.equal(10) table.read_units.should.equal(10)
table.write_units.should.equal(10) table.write_units.should.equal(10)
table.update_throughput(5, 6) table.update_throughput(5, 6)
table.refresh() table.refresh()
table.read_units.should.equal(5) table.read_units.should.equal(5)
table.write_units.should.equal(6) table.write_units.should.equal(6)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_item_add_and_describe_and_update(): def test_item_add_and_describe_and_update():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item(hash_key="LOLCat Forum", attrs=item_data) item = table.new_item(hash_key="LOLCat Forum", attrs=item_data)
item.put() item.put()
returned_item = table.get_item( returned_item = table.get_item(
hash_key="LOLCat Forum", attributes_to_get=["Body", "SentBy"] hash_key="LOLCat Forum", attributes_to_get=["Body", "SentBy"]
) )
dict(returned_item).should.equal( dict(returned_item).should.equal(
{ {
"forum_name": "LOLCat Forum", "forum_name": "LOLCat Forum",
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
} }
) )
item["SentBy"] = "User B" item["SentBy"] = "User B"
item.put() item.put()
returned_item = table.get_item( returned_item = table.get_item(
hash_key="LOLCat Forum", attributes_to_get=["Body", "SentBy"] hash_key="LOLCat Forum", attributes_to_get=["Body", "SentBy"]
) )
dict(returned_item).should.equal( dict(returned_item).should.equal(
{ {
"forum_name": "LOLCat Forum", "forum_name": "LOLCat Forum",
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User B", "SentBy": "User B",
} }
) )
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_item_put_without_table(): def test_item_put_without_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.put_item.when.called_with( conn.layer1.put_item.when.called_with(
table_name="undeclared-table", item=dict(hash_key="LOLCat Forum") table_name="undeclared-table", item=dict(hash_key="LOLCat Forum")
).should.throw(DynamoDBResponseError) ).should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_get_missing_item(): def test_get_missing_item():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
table.get_item.when.called_with(hash_key="tester").should.throw( table.get_item.when.called_with(hash_key="tester").should.throw(
DynamoDBKeyNotFoundError DynamoDBKeyNotFoundError
) )
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_get_item_with_undeclared_table(): def test_get_item_with_undeclared_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.get_item.when.called_with( conn.layer1.get_item.when.called_with(
table_name="undeclared-table", key={"HashKeyElement": {"S": "tester"}} table_name="undeclared-table", key={"HashKeyElement": {"S": "tester"}}
).should.throw(DynamoDBKeyNotFoundError) ).should.throw(DynamoDBKeyNotFoundError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_delete_item(): def test_delete_item():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item(hash_key="LOLCat Forum", attrs=item_data) item = table.new_item(hash_key="LOLCat Forum", attrs=item_data)
item.put() item.put()
table.refresh() table.refresh()
table.item_count.should.equal(1) table.item_count.should.equal(1)
response = item.delete() response = item.delete()
response.should.equal({"Attributes": [], "ConsumedCapacityUnits": 0.5}) response.should.equal({"Attributes": [], "ConsumedCapacityUnits": 0.5})
table.refresh() table.refresh()
table.item_count.should.equal(0) table.item_count.should.equal(0)
item.delete.when.called_with().should.throw(DynamoDBResponseError) item.delete.when.called_with().should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_delete_item_with_attribute_response(): def test_delete_item_with_attribute_response():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item(hash_key="LOLCat Forum", attrs=item_data) item = table.new_item(hash_key="LOLCat Forum", attrs=item_data)
item.put() item.put()
table.refresh() table.refresh()
table.item_count.should.equal(1) table.item_count.should.equal(1)
response = item.delete(return_values="ALL_OLD") response = item.delete(return_values="ALL_OLD")
response.should.equal( response.should.equal(
{ {
"Attributes": { "Attributes": {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"forum_name": "LOLCat Forum", "forum_name": "LOLCat Forum",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
"SentBy": "User A", "SentBy": "User A",
}, },
"ConsumedCapacityUnits": 0.5, "ConsumedCapacityUnits": 0.5,
} }
) )
table.refresh() table.refresh()
table.item_count.should.equal(0) table.item_count.should.equal(0)
item.delete.when.called_with().should.throw(DynamoDBResponseError) item.delete.when.called_with().should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_delete_item_with_undeclared_table(): def test_delete_item_with_undeclared_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.delete_item.when.called_with( conn.layer1.delete_item.when.called_with(
table_name="undeclared-table", key={"HashKeyElement": {"S": "tester"}} table_name="undeclared-table", key={"HashKeyElement": {"S": "tester"}}
).should.throw(DynamoDBResponseError) ).should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_query(): def test_query():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item(hash_key="the-key", attrs=item_data) item = table.new_item(hash_key="the-key", attrs=item_data)
item.put() item.put()
results = table.query(hash_key="the-key") results = table.query(hash_key="the-key")
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_query_with_undeclared_table(): def test_query_with_undeclared_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.query.when.called_with( conn.layer1.query.when.called_with(
table_name="undeclared-table", hash_key_value={"S": "the-key"} table_name="undeclared-table", hash_key_value={"S": "the-key"}
).should.throw(DynamoDBResponseError) ).should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_scan(): def test_scan():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item(hash_key="the-key", attrs=item_data) item = table.new_item(hash_key="the-key", attrs=item_data)
item.put() item.put()
item = table.new_item(hash_key="the-key2", attrs=item_data) item = table.new_item(hash_key="the-key2", attrs=item_data)
item.put() item.put()
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User B", "SentBy": "User B",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
"Ids": set([1, 2, 3]), "Ids": set([1, 2, 3]),
"PK": 7, "PK": 7,
} }
item = table.new_item(hash_key="the-key3", attrs=item_data) item = table.new_item(hash_key="the-key3", attrs=item_data)
item.put() item.put()
results = table.scan() results = table.scan()
results.response["Items"].should.have.length_of(3) results.response["Items"].should.have.length_of(3)
results = table.scan(scan_filter={"SentBy": condition.EQ("User B")}) results = table.scan(scan_filter={"SentBy": condition.EQ("User B")})
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
results = table.scan(scan_filter={"Body": condition.BEGINS_WITH("http")}) results = table.scan(scan_filter={"Body": condition.BEGINS_WITH("http")})
results.response["Items"].should.have.length_of(3) results.response["Items"].should.have.length_of(3)
results = table.scan(scan_filter={"Ids": condition.CONTAINS(2)}) results = table.scan(scan_filter={"Ids": condition.CONTAINS(2)})
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
results = table.scan(scan_filter={"Ids": condition.NOT_NULL()}) results = table.scan(scan_filter={"Ids": condition.NOT_NULL()})
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
results = table.scan(scan_filter={"Ids": condition.NULL()}) results = table.scan(scan_filter={"Ids": condition.NULL()})
results.response["Items"].should.have.length_of(2) results.response["Items"].should.have.length_of(2)
results = table.scan(scan_filter={"PK": condition.BETWEEN(8, 9)}) results = table.scan(scan_filter={"PK": condition.BETWEEN(8, 9)})
results.response["Items"].should.have.length_of(0) results.response["Items"].should.have.length_of(0)
results = table.scan(scan_filter={"PK": condition.BETWEEN(5, 8)}) results = table.scan(scan_filter={"PK": condition.BETWEEN(5, 8)})
results.response["Items"].should.have.length_of(1) results.response["Items"].should.have.length_of(1)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_scan_with_undeclared_table(): def test_scan_with_undeclared_table():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
conn.layer1.scan.when.called_with( conn.layer1.scan.when.called_with(
table_name="undeclared-table", table_name="undeclared-table",
scan_filter={ scan_filter={
"SentBy": { "SentBy": {
"AttributeValueList": [{"S": "User B"}], "AttributeValueList": [{"S": "User B"}],
"ComparisonOperator": "EQ", "ComparisonOperator": "EQ",
} }
}, },
).should.throw(DynamoDBResponseError) ).should.throw(DynamoDBResponseError)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_scan_after_has_item(): def test_scan_after_has_item():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
list(table.scan()).should.equal([]) list(table.scan()).should.equal([])
table.has_item("the-key") table.has_item("the-key")
list(table.scan()).should.equal([]) list(table.scan()).should.equal([])
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_write_batch(): def test_write_batch():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
batch_list = conn.new_batch_write_list() batch_list = conn.new_batch_write_list()
items = [] items = []
items.append( items.append(
table.new_item( table.new_item(
hash_key="the-key", hash_key="the-key",
attrs={ attrs={
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
}, },
) )
) )
items.append( items.append(
table.new_item( table.new_item(
hash_key="the-key2", hash_key="the-key2",
attrs={ attrs={
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User B", "SentBy": "User B",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
"Ids": set([1, 2, 3]), "Ids": set([1, 2, 3]),
"PK": 7, "PK": 7,
}, },
) )
) )
batch_list.add_batch(table, puts=items) batch_list.add_batch(table, puts=items)
conn.batch_write_item(batch_list) conn.batch_write_item(batch_list)
table.refresh() table.refresh()
table.item_count.should.equal(2) table.item_count.should.equal(2)
batch_list = conn.new_batch_write_list() batch_list = conn.new_batch_write_list()
batch_list.add_batch(table, deletes=[("the-key")]) batch_list.add_batch(table, deletes=[("the-key")])
conn.batch_write_item(batch_list) conn.batch_write_item(batch_list)
table.refresh() table.refresh()
table.item_count.should.equal(1) table.item_count.should.equal(1)
@mock_dynamodb_deprecated @mock_dynamodb_deprecated
def test_batch_read(): def test_batch_read():
conn = boto.connect_dynamodb() conn = boto.connect_dynamodb()
table = create_table(conn) table = create_table(conn)
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User A", "SentBy": "User A",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
} }
item = table.new_item(hash_key="the-key1", attrs=item_data) item = table.new_item(hash_key="the-key1", attrs=item_data)
item.put() item.put()
item = table.new_item(hash_key="the-key2", attrs=item_data) item = table.new_item(hash_key="the-key2", attrs=item_data)
item.put() item.put()
item_data = { item_data = {
"Body": "http://url_to_lolcat.gif", "Body": "http://url_to_lolcat.gif",
"SentBy": "User B", "SentBy": "User B",
"ReceivedTime": "12/9/2011 11:36:03 PM", "ReceivedTime": "12/9/2011 11:36:03 PM",
"Ids": set([1, 2, 3]), "Ids": set([1, 2, 3]),
"PK": 7, "PK": 7,
} }
item = table.new_item(hash_key="another-key", attrs=item_data) item = table.new_item(hash_key="another-key", attrs=item_data)
item.put() item.put()
items = table.batch_get_item([("the-key1"), ("another-key")]) items = table.batch_get_item([("the-key1"), ("another-key")])
# Iterate through so that batch_item gets called # Iterate through so that batch_item gets called
count = len([x for x in items]) count = len([x for x in items])
count.should.have.equal(2) count.should.have.equal(2)

View File

@ -1719,6 +1719,32 @@ def test_scan_filter4():
assert response["Count"] == 0 assert response["Count"] == 0
@mock_dynamodb2
def test_scan_filter_should_not_return_non_existing_attributes():
table_name = "my-table"
item = {"partitionKey": "pk-2", "my-attr": 42}
# Create table
res = boto3.resource("dynamodb", region_name="us-east-1")
res.create_table(
TableName=table_name,
KeySchema=[{"AttributeName": "partitionKey", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "partitionKey", "AttributeType": "S"}],
BillingMode="PAY_PER_REQUEST",
)
table = res.Table(table_name)
# Insert items
table.put_item(Item={"partitionKey": "pk-1"})
table.put_item(Item=item)
# Verify a few operations
# Assert we only find the item that has this attribute
table.scan(FilterExpression=Attr("my-attr").lt(43))["Items"].should.equal([item])
table.scan(FilterExpression=Attr("my-attr").lte(42))["Items"].should.equal([item])
table.scan(FilterExpression=Attr("my-attr").gte(42))["Items"].should.equal([item])
table.scan(FilterExpression=Attr("my-attr").gt(41))["Items"].should.equal([item])
# Sanity check that we can't find the item if the FE is wrong
table.scan(FilterExpression=Attr("my-attr").gt(43))["Items"].should.equal([])
@mock_dynamodb2 @mock_dynamodb2
def test_bad_scan_filter(): def test_bad_scan_filter():
client = boto3.client("dynamodb", region_name="us-east-1") client = boto3.client("dynamodb", region_name="us-east-1")
@ -2505,6 +2531,48 @@ def test_condition_expressions():
) )
@mock_dynamodb2
def test_condition_expression_numerical_attribute():
dynamodb = boto3.resource("dynamodb", region_name="us-east-1")
dynamodb.create_table(
TableName="my-table",
KeySchema=[{"AttributeName": "partitionKey", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "partitionKey", "AttributeType": "S"}],
)
table = dynamodb.Table("my-table")
table.put_item(Item={"partitionKey": "pk-pos", "myAttr": 5})
table.put_item(Item={"partitionKey": "pk-neg", "myAttr": -5})
# try to update the item we put in the table using numerical condition expression
# Specifically, verify that we can compare with a zero-value
# First verify that > and >= work on positive numbers
update_numerical_con_expr(
key="pk-pos", con_expr="myAttr > :zero", res="6", table=table
)
update_numerical_con_expr(
key="pk-pos", con_expr="myAttr >= :zero", res="7", table=table
)
# Second verify that < and <= work on negative numbers
update_numerical_con_expr(
key="pk-neg", con_expr="myAttr < :zero", res="-4", table=table
)
update_numerical_con_expr(
key="pk-neg", con_expr="myAttr <= :zero", res="-3", table=table
)
def update_numerical_con_expr(key, con_expr, res, table):
table.update_item(
Key={"partitionKey": key},
UpdateExpression="ADD myAttr :one",
ExpressionAttributeValues={":zero": 0, ":one": 1},
ConditionExpression=con_expr,
)
table.get_item(Key={"partitionKey": key})["Item"]["myAttr"].should.equal(
Decimal(res)
)
@mock_dynamodb2 @mock_dynamodb2
def test_condition_expression__attr_doesnt_exist(): def test_condition_expression__attr_doesnt_exist():
client = boto3.client("dynamodb", region_name="us-east-1") client = boto3.client("dynamodb", region_name="us-east-1")
@ -3489,6 +3557,83 @@ def test_update_supports_nested_list_append_onto_another_list():
) )
@mock_dynamodb2
def test_update_supports_list_append_maps():
client = boto3.client("dynamodb", region_name="us-west-1")
client.create_table(
AttributeDefinitions=[
{"AttributeName": "id", "AttributeType": "S"},
{"AttributeName": "rid", "AttributeType": "S"},
],
TableName="TestTable",
KeySchema=[
{"AttributeName": "id", "KeyType": "HASH"},
{"AttributeName": "rid", "KeyType": "RANGE"},
],
ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
)
client.put_item(
TableName="TestTable",
Item={
"id": {"S": "nested_list_append"},
"rid": {"S": "range_key"},
"a": {"L": [{"M": {"b": {"S": "bar1"}}}]},
},
)
# Update item using list_append expression
client.update_item(
TableName="TestTable",
Key={"id": {"S": "nested_list_append"}, "rid": {"S": "range_key"}},
UpdateExpression="SET a = list_append(a, :i)",
ExpressionAttributeValues={":i": {"L": [{"M": {"b": {"S": "bar2"}}}]}},
)
# Verify item is appended to the existing list
result = client.query(
TableName="TestTable",
KeyConditionExpression="id = :i AND begins_with(rid, :r)",
ExpressionAttributeValues={
":i": {"S": "nested_list_append"},
":r": {"S": "range_key"},
},
)["Items"]
result.should.equal(
[
{
"a": {"L": [{"M": {"b": {"S": "bar1"}}}, {"M": {"b": {"S": "bar2"}}}]},
"rid": {"S": "range_key"},
"id": {"S": "nested_list_append"},
}
]
)
@mock_dynamodb2
def test_update_supports_list_append_with_nested_if_not_exists_operation():
dynamo = boto3.resource("dynamodb", region_name="us-west-1")
table_name = "test"
dynamo.create_table(
TableName=table_name,
AttributeDefinitions=[{"AttributeName": "Id", "AttributeType": "S"}],
KeySchema=[{"AttributeName": "Id", "KeyType": "HASH"}],
ProvisionedThroughput={"ReadCapacityUnits": 20, "WriteCapacityUnits": 20},
)
table = dynamo.Table(table_name)
table.put_item(Item={"Id": "item-id", "nest1": {"nest2": {}}})
table.update_item(
Key={"Id": "item-id"},
UpdateExpression="SET nest1.nest2.event_history = list_append(if_not_exists(nest1.nest2.event_history, :empty_list), :new_value)",
ExpressionAttributeValues={":empty_list": [], ":new_value": ["some_value"]},
)
table.get_item(Key={"Id": "item-id"})["Item"].should.equal(
{"Id": "item-id", "nest1": {"nest2": {"event_history": ["some_value"]}}}
)
@mock_dynamodb2 @mock_dynamodb2
def test_update_catches_invalid_list_append_operation(): def test_update_catches_invalid_list_append_operation():
client = boto3.client("dynamodb", region_name="us-east-1") client = boto3.client("dynamodb", region_name="us-east-1")
@ -3601,3 +3746,24 @@ def test_allow_update_to_item_with_different_type():
table.get_item(Key={"job_id": "b"})["Item"]["job_details"][ table.get_item(Key={"job_id": "b"})["Item"]["job_details"][
"job_name" "job_name"
].should.be.equal({"nested": "yes"}) ].should.be.equal({"nested": "yes"})
@mock_dynamodb2
def test_query_catches_when_no_filters():
dynamo = boto3.resource("dynamodb", region_name="eu-central-1")
dynamo.create_table(
AttributeDefinitions=[{"AttributeName": "job_id", "AttributeType": "S"}],
TableName="origin-rbu-dev",
KeySchema=[{"AttributeName": "job_id", "KeyType": "HASH"}],
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
)
table = dynamo.Table("origin-rbu-dev")
with assert_raises(ClientError) as ex:
table.query(TableName="original-rbu-dev")
ex.exception.response["Error"]["Code"].should.equal("ValidationException")
ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
ex.exception.response["Error"]["Message"].should.equal(
"Either KeyConditions or QueryFilter should be present"
)

View File

@ -1,37 +1,37 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto3 import boto3
from moto import mock_ec2 from moto import mock_ec2
import sure # noqa import sure # noqa
@mock_ec2 @mock_ec2
def test_describe_account_attributes(): def test_describe_account_attributes():
conn = boto3.client("ec2", region_name="us-east-1") conn = boto3.client("ec2", region_name="us-east-1")
response = conn.describe_account_attributes() response = conn.describe_account_attributes()
expected_attribute_values = [ expected_attribute_values = [
{ {
"AttributeValues": [{"AttributeValue": "5"}], "AttributeValues": [{"AttributeValue": "5"}],
"AttributeName": "vpc-max-security-groups-per-interface", "AttributeName": "vpc-max-security-groups-per-interface",
}, },
{ {
"AttributeValues": [{"AttributeValue": "20"}], "AttributeValues": [{"AttributeValue": "20"}],
"AttributeName": "max-instances", "AttributeName": "max-instances",
}, },
{ {
"AttributeValues": [{"AttributeValue": "EC2"}, {"AttributeValue": "VPC"}], "AttributeValues": [{"AttributeValue": "EC2"}, {"AttributeValue": "VPC"}],
"AttributeName": "supported-platforms", "AttributeName": "supported-platforms",
}, },
{ {
"AttributeValues": [{"AttributeValue": "none"}], "AttributeValues": [{"AttributeValue": "none"}],
"AttributeName": "default-vpc", "AttributeName": "default-vpc",
}, },
{ {
"AttributeValues": [{"AttributeValue": "5"}], "AttributeValues": [{"AttributeValue": "5"}],
"AttributeName": "max-elastic-ips", "AttributeName": "max-elastic-ips",
}, },
{ {
"AttributeValues": [{"AttributeValue": "5"}], "AttributeValues": [{"AttributeValue": "5"}],
"AttributeName": "vpc-max-elastic-ips", "AttributeName": "vpc-max-elastic-ips",
}, },
] ]
response["AccountAttributes"].should.equal(expected_attribute_values) response["AccountAttributes"].should.equal(expected_attribute_values)

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from moto import mock_ec2 from moto import mock_ec2
@mock_ec2 @mock_ec2
def test_amazon_dev_pay(): def test_amazon_dev_pay():
pass pass

View File

@ -12,6 +12,7 @@ import sure # noqa
from moto import mock_ec2_deprecated, mock_ec2 from moto import mock_ec2_deprecated, mock_ec2
from moto.ec2.models import AMIS, OWNER_ID from moto.ec2.models import AMIS, OWNER_ID
from moto.iam.models import ACCOUNT_ID
from tests.helpers import requires_boto_gte from tests.helpers import requires_boto_gte
@ -251,6 +252,19 @@ def test_ami_pulls_attributes_from_instance():
image.kernel_id.should.equal("test-kernel") image.kernel_id.should.equal("test-kernel")
@mock_ec2_deprecated
def test_ami_uses_account_id_if_valid_access_key_is_supplied():
access_key = "AKIAXXXXXXXXXXXXXXXX"
conn = boto.connect_ec2(access_key, "the_secret")
reservation = conn.run_instances("ami-1234abcd")
instance = reservation.instances[0]
instance.modify_attribute("kernel", "test-kernel")
image_id = conn.create_image(instance.id, "test-ami", "this is a test ami")
images = conn.get_all_images(owners=["self"])
[(ami.id, ami.owner_id) for ami in images].should.equal([(image_id, ACCOUNT_ID)])
@mock_ec2_deprecated @mock_ec2_deprecated
def test_ami_filters(): def test_ami_filters():
conn = boto.connect_ec2("the_key", "the_secret") conn = boto.connect_ec2("the_key", "the_secret")
@ -773,7 +787,7 @@ def test_ami_filter_wildcard():
instance.create_image(Name="not-matching-image") instance.create_image(Name="not-matching-image")
my_images = ec2_client.describe_images( my_images = ec2_client.describe_images(
Owners=["111122223333"], Filters=[{"Name": "name", "Values": ["test*"]}] Owners=[ACCOUNT_ID], Filters=[{"Name": "name", "Values": ["test*"]}]
)["Images"] )["Images"]
my_images.should.have.length_of(1) my_images.should.have.length_of(1)

View File

@ -1 +1 @@
from __future__ import unicode_literals from __future__ import unicode_literals

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from moto import mock_ec2 from moto import mock_ec2
@mock_ec2 @mock_ec2
def test_ip_addresses(): def test_ip_addresses():
pass pass

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from moto import mock_ec2 from moto import mock_ec2
@mock_ec2 @mock_ec2
def test_monitoring(): def test_monitoring():
pass pass

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from moto import mock_ec2 from moto import mock_ec2
@mock_ec2 @mock_ec2
def test_placement_groups(): def test_placement_groups():
pass pass

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from moto import mock_ec2 from moto import mock_ec2
@mock_ec2 @mock_ec2
def test_reserved_instances(): def test_reserved_instances():
pass pass

View File

@ -236,8 +236,8 @@ def test_route_table_associations():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_route_table_replace_route_table_association(): def test_route_table_replace_route_table_association():
""" """
Note: Boto has deprecated replace_route_table_assocation (which returns status) Note: Boto has deprecated replace_route_table_association (which returns status)
and now uses replace_route_table_assocation_with_assoc (which returns association ID). and now uses replace_route_table_association_with_assoc (which returns association ID).
""" """
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpc = conn.create_vpc("10.0.0.0/16") vpc = conn.create_vpc("10.0.0.0/16")

View File

@ -1,96 +1,96 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from moto import mock_ec2_deprecated from moto import mock_ec2_deprecated
@mock_ec2_deprecated @mock_ec2_deprecated
def test_virtual_private_gateways(): def test_virtual_private_gateways():
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a")
vpn_gateway.should_not.be.none vpn_gateway.should_not.be.none
vpn_gateway.id.should.match(r"vgw-\w+") vpn_gateway.id.should.match(r"vgw-\w+")
vpn_gateway.type.should.equal("ipsec.1") vpn_gateway.type.should.equal("ipsec.1")
vpn_gateway.state.should.equal("available") vpn_gateway.state.should.equal("available")
vpn_gateway.availability_zone.should.equal("us-east-1a") vpn_gateway.availability_zone.should.equal("us-east-1a")
@mock_ec2_deprecated @mock_ec2_deprecated
def test_describe_vpn_gateway(): def test_describe_vpn_gateway():
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a")
vgws = conn.get_all_vpn_gateways() vgws = conn.get_all_vpn_gateways()
vgws.should.have.length_of(1) vgws.should.have.length_of(1)
gateway = vgws[0] gateway = vgws[0]
gateway.id.should.match(r"vgw-\w+") gateway.id.should.match(r"vgw-\w+")
gateway.id.should.equal(vpn_gateway.id) gateway.id.should.equal(vpn_gateway.id)
vpn_gateway.type.should.equal("ipsec.1") vpn_gateway.type.should.equal("ipsec.1")
vpn_gateway.state.should.equal("available") vpn_gateway.state.should.equal("available")
vpn_gateway.availability_zone.should.equal("us-east-1a") vpn_gateway.availability_zone.should.equal("us-east-1a")
@mock_ec2_deprecated @mock_ec2_deprecated
def test_vpn_gateway_vpc_attachment(): def test_vpn_gateway_vpc_attachment():
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpc = conn.create_vpc("10.0.0.0/16") vpc = conn.create_vpc("10.0.0.0/16")
vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a")
conn.attach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id) conn.attach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id)
gateway = conn.get_all_vpn_gateways()[0] gateway = conn.get_all_vpn_gateways()[0]
attachments = gateway.attachments attachments = gateway.attachments
attachments.should.have.length_of(1) attachments.should.have.length_of(1)
attachments[0].vpc_id.should.equal(vpc.id) attachments[0].vpc_id.should.equal(vpc.id)
attachments[0].state.should.equal("attached") attachments[0].state.should.equal("attached")
@mock_ec2_deprecated @mock_ec2_deprecated
def test_delete_vpn_gateway(): def test_delete_vpn_gateway():
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a")
conn.delete_vpn_gateway(vpn_gateway.id) conn.delete_vpn_gateway(vpn_gateway.id)
vgws = conn.get_all_vpn_gateways() vgws = conn.get_all_vpn_gateways()
vgws.should.have.length_of(0) vgws.should.have.length_of(0)
@mock_ec2_deprecated @mock_ec2_deprecated
def test_vpn_gateway_tagging(): def test_vpn_gateway_tagging():
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a")
vpn_gateway.add_tag("a key", "some value") vpn_gateway.add_tag("a key", "some value")
tag = conn.get_all_tags()[0] tag = conn.get_all_tags()[0]
tag.name.should.equal("a key") tag.name.should.equal("a key")
tag.value.should.equal("some value") tag.value.should.equal("some value")
# Refresh the subnet # Refresh the subnet
vpn_gateway = conn.get_all_vpn_gateways()[0] vpn_gateway = conn.get_all_vpn_gateways()[0]
vpn_gateway.tags.should.have.length_of(1) vpn_gateway.tags.should.have.length_of(1)
vpn_gateway.tags["a key"].should.equal("some value") vpn_gateway.tags["a key"].should.equal("some value")
@mock_ec2_deprecated @mock_ec2_deprecated
def test_detach_vpn_gateway(): def test_detach_vpn_gateway():
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpc = conn.create_vpc("10.0.0.0/16") vpc = conn.create_vpc("10.0.0.0/16")
vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a")
conn.attach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id) conn.attach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id)
gateway = conn.get_all_vpn_gateways()[0] gateway = conn.get_all_vpn_gateways()[0]
attachments = gateway.attachments attachments = gateway.attachments
attachments.should.have.length_of(1) attachments.should.have.length_of(1)
attachments[0].vpc_id.should.equal(vpc.id) attachments[0].vpc_id.should.equal(vpc.id)
attachments[0].state.should.equal("attached") attachments[0].state.should.equal("attached")
conn.detach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id) conn.detach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id)
gateway = conn.get_all_vpn_gateways()[0] gateway = conn.get_all_vpn_gateways()[0]
attachments = gateway.attachments attachments = gateway.attachments
attachments.should.have.length_of(0) attachments.should.have.length_of(0)

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from moto import mock_ec2 from moto import mock_ec2
@mock_ec2 @mock_ec2
def test_vm_export(): def test_vm_export():
pass pass

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from moto import mock_ec2 from moto import mock_ec2
@mock_ec2 @mock_ec2
def test_vm_import(): def test_vm_import():
pass pass

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto import boto
import sure # noqa import sure # noqa
from moto import mock_ec2 from moto import mock_ec2
@mock_ec2 @mock_ec2
def test_windows(): def test_windows():
pass pass

View File

@ -77,7 +77,7 @@ def test_describe_repositories():
response = client.describe_repositories() response = client.describe_repositories()
len(response["repositories"]).should.equal(2) len(response["repositories"]).should.equal(2)
respository_arns = [ repository_arns = [
"arn:aws:ecr:us-east-1:012345678910:repository/test_repository1", "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1",
"arn:aws:ecr:us-east-1:012345678910:repository/test_repository0", "arn:aws:ecr:us-east-1:012345678910:repository/test_repository0",
] ]
@ -86,9 +86,9 @@ def test_describe_repositories():
response["repositories"][0]["repositoryArn"], response["repositories"][0]["repositoryArn"],
response["repositories"][1]["repositoryArn"], response["repositories"][1]["repositoryArn"],
] ]
).should.equal(set(respository_arns)) ).should.equal(set(repository_arns))
respository_uris = [ repository_uris = [
"012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1", "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1",
"012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0", "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0",
] ]
@ -97,7 +97,7 @@ def test_describe_repositories():
response["repositories"][0]["repositoryUri"], response["repositories"][0]["repositoryUri"],
response["repositories"][1]["repositoryUri"], response["repositories"][1]["repositoryUri"],
] ]
).should.equal(set(respository_uris)) ).should.equal(set(repository_uris))
@mock_ecr @mock_ecr
@ -108,7 +108,7 @@ def test_describe_repositories_1():
response = client.describe_repositories(registryId="012345678910") response = client.describe_repositories(registryId="012345678910")
len(response["repositories"]).should.equal(2) len(response["repositories"]).should.equal(2)
respository_arns = [ repository_arns = [
"arn:aws:ecr:us-east-1:012345678910:repository/test_repository1", "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1",
"arn:aws:ecr:us-east-1:012345678910:repository/test_repository0", "arn:aws:ecr:us-east-1:012345678910:repository/test_repository0",
] ]
@ -117,9 +117,9 @@ def test_describe_repositories_1():
response["repositories"][0]["repositoryArn"], response["repositories"][0]["repositoryArn"],
response["repositories"][1]["repositoryArn"], response["repositories"][1]["repositoryArn"],
] ]
).should.equal(set(respository_arns)) ).should.equal(set(repository_arns))
respository_uris = [ repository_uris = [
"012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1", "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1",
"012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0", "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0",
] ]
@ -128,7 +128,7 @@ def test_describe_repositories_1():
response["repositories"][0]["repositoryUri"], response["repositories"][0]["repositoryUri"],
response["repositories"][1]["repositoryUri"], response["repositories"][1]["repositoryUri"],
] ]
).should.equal(set(respository_uris)) ).should.equal(set(repository_uris))
@mock_ecr @mock_ecr
@ -147,11 +147,11 @@ def test_describe_repositories_3():
_ = client.create_repository(repositoryName="test_repository0") _ = client.create_repository(repositoryName="test_repository0")
response = client.describe_repositories(repositoryNames=["test_repository1"]) response = client.describe_repositories(repositoryNames=["test_repository1"])
len(response["repositories"]).should.equal(1) len(response["repositories"]).should.equal(1)
respository_arn = "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1" repository_arn = "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1"
response["repositories"][0]["repositoryArn"].should.equal(respository_arn) response["repositories"][0]["repositoryArn"].should.equal(repository_arn)
respository_uri = "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1" repository_uri = "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1"
response["repositories"][0]["repositoryUri"].should.equal(respository_uri) response["repositories"][0]["repositoryUri"].should.equal(repository_uri)
@mock_ecr @mock_ecr

View File

@ -94,6 +94,7 @@ def test_register_task_definition():
"logConfiguration": {"logDriver": "json-file"}, "logConfiguration": {"logDriver": "json-file"},
} }
], ],
networkMode="bridge",
tags=[ tags=[
{"key": "createdBy", "value": "moto-unittest"}, {"key": "createdBy", "value": "moto-unittest"},
{"key": "foo", "value": "bar"}, {"key": "foo", "value": "bar"},
@ -124,6 +125,7 @@ def test_register_task_definition():
response["taskDefinition"]["containerDefinitions"][0]["logConfiguration"][ response["taskDefinition"]["containerDefinitions"][0]["logConfiguration"][
"logDriver" "logDriver"
].should.equal("json-file") ].should.equal("json-file")
response["taskDefinition"]["networkMode"].should.equal("bridge")
@mock_ecs @mock_ecs
@ -724,7 +726,7 @@ def test_delete_service():
@mock_ecs @mock_ecs
def test_update_non_existant_service(): def test_update_non_existent_service():
client = boto3.client("ecs", region_name="us-east-1") client = boto3.client("ecs", region_name="us-east-1")
try: try:
client.update_service( client.update_service(

View File

@ -1391,7 +1391,7 @@ def test_set_security_groups():
len(resp["LoadBalancers"][0]["SecurityGroups"]).should.equal(2) len(resp["LoadBalancers"][0]["SecurityGroups"]).should.equal(2)
with assert_raises(ClientError): with assert_raises(ClientError):
client.set_security_groups(LoadBalancerArn=arn, SecurityGroups=["non_existant"]) client.set_security_groups(LoadBalancerArn=arn, SecurityGroups=["non_existent"])
@mock_elbv2 @mock_elbv2

View File

@ -1,14 +1,18 @@
import random from moto.events.models import EventsBackend
import boto3
import json
import sure # noqa
from moto.events import mock_events from moto.events import mock_events
import json
import random
import unittest
import boto3
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from nose.tools import assert_raises from nose.tools import assert_raises
from moto.core import ACCOUNT_ID from moto.core import ACCOUNT_ID
from moto.events.models import EventsBackend << << << < HEAD
== == == =
>>>>>> > 100dbd529f174f18d579a1dcc066d55409f2e38f
RULES = [ RULES = [
{"Name": "test1", "ScheduleExpression": "rate(5 minutes)"}, {"Name": "test1", "ScheduleExpression": "rate(5 minutes)"},
@ -456,6 +460,11 @@ def test_delete_event_bus_errors():
ClientError, "Cannot delete event bus default." ClientError, "Cannot delete event bus default."
) )
<< << << < HEAD
== == == =
>>>>>> > 100dbd529f174f18d579a1dcc066d55409f2e38f
@mock_events @mock_events
def test_rule_tagging_happy(): def test_rule_tagging_happy():
client = generate_environment() client = generate_environment()
@ -466,7 +475,12 @@ def test_rule_tagging_happy():
client.tag_resource(ResourceARN=rule_arn, Tags=tags) client.tag_resource(ResourceARN=rule_arn, Tags=tags)
actual = client.list_tags_for_resource(ResourceARN=rule_arn).get("Tags") actual = client.list_tags_for_resource(ResourceARN=rule_arn).get("Tags")
assert tags == actual tc = unittest.TestCase("__init__")
expected = [{"Value": "value1", "Key": "key1"}, {"Value": "value2", "Key": "key2"}]
tc.assertTrue(
(expected[0] == actual[0] and expected[1] == actual[1])
or (expected[1] == actual[0] and expected[0] == actual[1])
)
client.untag_resource(ResourceARN=rule_arn, TagKeys=["key1"]) client.untag_resource(ResourceARN=rule_arn, TagKeys=["key1"])
@ -474,24 +488,25 @@ def test_rule_tagging_happy():
expected = [{"Key": "key2", "Value": "value2"}] expected = [{"Key": "key2", "Value": "value2"}]
assert expected == actual assert expected == actual
@mock_events @mock_events
def test_rule_tagging_sad(): def test_rule_tagging_sad():
b = EventsBackend("us-west-2") back_end = EventsBackend("us-west-2")
try: try:
b.tag_resource('unknown', []) back_end.tag_resource("unknown", [])
raise 'tag_resource should fail if ResourceARN is not known' raise "tag_resource should fail if ResourceARN is not known"
except JsonRESTError: except JsonRESTError:
pass pass
try: try:
b.untag_resource('unknown', []) back_end.untag_resource("unknown", [])
raise 'untag_resource should fail if ResourceARN is not known' raise "untag_resource should fail if ResourceARN is not known"
except JsonRESTError: except JsonRESTError:
pass pass
try: try:
b.list_tags_for_resource('unknown') back_end.list_tags_for_resource("unknown")
raise 'list_tags_for_resource should fail if ResourceARN is not known' raise "list_tags_for_resource should fail if ResourceARN is not known"
except JsonRESTError: except JsonRESTError:
pass pass

View File

@ -1,21 +1,21 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import boto.glacier import boto.glacier
import sure # noqa import sure # noqa
from moto import mock_glacier_deprecated from moto import mock_glacier_deprecated
@mock_glacier_deprecated @mock_glacier_deprecated
def test_create_and_delete_archive(): def test_create_and_delete_archive():
the_file = NamedTemporaryFile(delete=False) the_file = NamedTemporaryFile(delete=False)
the_file.write(b"some stuff") the_file.write(b"some stuff")
the_file.close() the_file.close()
conn = boto.glacier.connect_to_region("us-west-2") conn = boto.glacier.connect_to_region("us-west-2")
vault = conn.create_vault("my_vault") vault = conn.create_vault("my_vault")
archive_id = vault.upload_archive(the_file.name) archive_id = vault.upload_archive(the_file.name)
vault.delete_archive(archive_id) vault.delete_archive(archive_id)

View File

@ -1,31 +1,31 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import boto.glacier import boto.glacier
import sure # noqa import sure # noqa
from moto import mock_glacier_deprecated from moto import mock_glacier_deprecated
@mock_glacier_deprecated @mock_glacier_deprecated
def test_create_vault(): def test_create_vault():
conn = boto.glacier.connect_to_region("us-west-2") conn = boto.glacier.connect_to_region("us-west-2")
conn.create_vault("my_vault") conn.create_vault("my_vault")
vaults = conn.list_vaults() vaults = conn.list_vaults()
vaults.should.have.length_of(1) vaults.should.have.length_of(1)
vaults[0].name.should.equal("my_vault") vaults[0].name.should.equal("my_vault")
@mock_glacier_deprecated @mock_glacier_deprecated
def test_delete_vault(): def test_delete_vault():
conn = boto.glacier.connect_to_region("us-west-2") conn = boto.glacier.connect_to_region("us-west-2")
conn.create_vault("my_vault") conn.create_vault("my_vault")
vaults = conn.list_vaults() vaults = conn.list_vaults()
vaults.should.have.length_of(1) vaults.should.have.length_of(1)
conn.delete_vault("my_vault") conn.delete_vault("my_vault")
vaults = conn.list_vaults() vaults = conn.list_vaults()
vaults.should.have.length_of(0) vaults.should.have.length_of(0)

View File

@ -1 +1 @@
from __future__ import unicode_literals from __future__ import unicode_literals

View File

@ -1 +1 @@
from __future__ import unicode_literals from __future__ import unicode_literals

View File

@ -1,97 +1,97 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import copy import copy
from .fixtures.datacatalog import TABLE_INPUT, PARTITION_INPUT from .fixtures.datacatalog import TABLE_INPUT, PARTITION_INPUT
def create_database(client, database_name): def create_database(client, database_name):
return client.create_database(DatabaseInput={"Name": database_name}) return client.create_database(DatabaseInput={"Name": database_name})
def get_database(client, database_name): def get_database(client, database_name):
return client.get_database(Name=database_name) return client.get_database(Name=database_name)
def create_table_input(database_name, table_name, columns=[], partition_keys=[]): def create_table_input(database_name, table_name, columns=[], partition_keys=[]):
table_input = copy.deepcopy(TABLE_INPUT) table_input = copy.deepcopy(TABLE_INPUT)
table_input["Name"] = table_name table_input["Name"] = table_name
table_input["PartitionKeys"] = partition_keys table_input["PartitionKeys"] = partition_keys
table_input["StorageDescriptor"]["Columns"] = columns table_input["StorageDescriptor"]["Columns"] = columns
table_input["StorageDescriptor"][ table_input["StorageDescriptor"][
"Location" "Location"
] = "s3://my-bucket/{database_name}/{table_name}".format( ] = "s3://my-bucket/{database_name}/{table_name}".format(
database_name=database_name, table_name=table_name database_name=database_name, table_name=table_name
) )
return table_input return table_input
def create_table(client, database_name, table_name, table_input=None, **kwargs): def create_table(client, database_name, table_name, table_input=None, **kwargs):
if table_input is None: if table_input is None:
table_input = create_table_input(database_name, table_name, **kwargs) table_input = create_table_input(database_name, table_name, **kwargs)
return client.create_table(DatabaseName=database_name, TableInput=table_input) return client.create_table(DatabaseName=database_name, TableInput=table_input)
def update_table(client, database_name, table_name, table_input=None, **kwargs): def update_table(client, database_name, table_name, table_input=None, **kwargs):
if table_input is None: if table_input is None:
table_input = create_table_input(database_name, table_name, **kwargs) table_input = create_table_input(database_name, table_name, **kwargs)
return client.update_table(DatabaseName=database_name, TableInput=table_input) return client.update_table(DatabaseName=database_name, TableInput=table_input)
def get_table(client, database_name, table_name): def get_table(client, database_name, table_name):
return client.get_table(DatabaseName=database_name, Name=table_name) return client.get_table(DatabaseName=database_name, Name=table_name)
def get_tables(client, database_name): def get_tables(client, database_name):
return client.get_tables(DatabaseName=database_name) return client.get_tables(DatabaseName=database_name)
def get_table_versions(client, database_name, table_name): def get_table_versions(client, database_name, table_name):
return client.get_table_versions(DatabaseName=database_name, TableName=table_name) return client.get_table_versions(DatabaseName=database_name, TableName=table_name)
def get_table_version(client, database_name, table_name, version_id): def get_table_version(client, database_name, table_name, version_id):
return client.get_table_version( return client.get_table_version(
DatabaseName=database_name, TableName=table_name, VersionId=version_id DatabaseName=database_name, TableName=table_name, VersionId=version_id
) )
def create_partition_input(database_name, table_name, values=[], columns=[]): def create_partition_input(database_name, table_name, values=[], columns=[]):
root_path = "s3://my-bucket/{database_name}/{table_name}".format( root_path = "s3://my-bucket/{database_name}/{table_name}".format(
database_name=database_name, table_name=table_name database_name=database_name, table_name=table_name
) )
part_input = copy.deepcopy(PARTITION_INPUT) part_input = copy.deepcopy(PARTITION_INPUT)
part_input["Values"] = values part_input["Values"] = values
part_input["StorageDescriptor"]["Columns"] = columns part_input["StorageDescriptor"]["Columns"] = columns
part_input["StorageDescriptor"]["SerdeInfo"]["Parameters"]["path"] = root_path part_input["StorageDescriptor"]["SerdeInfo"]["Parameters"]["path"] = root_path
return part_input return part_input
def create_partition(client, database_name, table_name, partiton_input=None, **kwargs): def create_partition(client, database_name, table_name, partiton_input=None, **kwargs):
if partiton_input is None: if partiton_input is None:
partiton_input = create_partition_input(database_name, table_name, **kwargs) partiton_input = create_partition_input(database_name, table_name, **kwargs)
return client.create_partition( return client.create_partition(
DatabaseName=database_name, TableName=table_name, PartitionInput=partiton_input DatabaseName=database_name, TableName=table_name, PartitionInput=partiton_input
) )
def update_partition( def update_partition(
client, database_name, table_name, old_values=[], partiton_input=None, **kwargs client, database_name, table_name, old_values=[], partiton_input=None, **kwargs
): ):
if partiton_input is None: if partiton_input is None:
partiton_input = create_partition_input(database_name, table_name, **kwargs) partiton_input = create_partition_input(database_name, table_name, **kwargs)
return client.update_partition( return client.update_partition(
DatabaseName=database_name, DatabaseName=database_name,
TableName=table_name, TableName=table_name,
PartitionInput=partiton_input, PartitionInput=partiton_input,
PartitionValueList=old_values, PartitionValueList=old_values,
) )
def get_partition(client, database_name, table_name, values): def get_partition(client, database_name, table_name, values):
return client.get_partition( return client.get_partition(
DatabaseName=database_name, TableName=table_name, PartitionValues=values DatabaseName=database_name, TableName=table_name, PartitionValues=values
) )

View File

@ -132,7 +132,7 @@ def test_get_table_versions():
helpers.update_table(client, database_name, table_name, table_input) helpers.update_table(client, database_name, table_name, table_input)
version_inputs["2"] = table_input version_inputs["2"] = table_input
# Updateing with an indentical input should still create a new version # Updateing with an identical input should still create a new version
helpers.update_table(client, database_name, table_name, table_input) helpers.update_table(client, database_name, table_name, table_input)
version_inputs["3"] = table_input version_inputs["3"] = table_input

View File

@ -785,7 +785,7 @@ def test_delete_login_profile():
conn.delete_login_profile("my-user") conn.delete_login_profile("my-user")
@mock_iam() @mock_iam
def test_create_access_key(): def test_create_access_key():
conn = boto3.client("iam", region_name="us-east-1") conn = boto3.client("iam", region_name="us-east-1")
with assert_raises(ClientError): with assert_raises(ClientError):
@ -798,6 +798,19 @@ def test_create_access_key():
access_key["AccessKeyId"].should.have.length_of(20) access_key["AccessKeyId"].should.have.length_of(20)
access_key["SecretAccessKey"].should.have.length_of(40) access_key["SecretAccessKey"].should.have.length_of(40)
assert access_key["AccessKeyId"].startswith("AKIA") assert access_key["AccessKeyId"].startswith("AKIA")
conn = boto3.client(
"iam",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
)
access_key = conn.create_access_key()["AccessKey"]
(
datetime.utcnow() - access_key["CreateDate"].replace(tzinfo=None)
).seconds.should.be.within(0, 10)
access_key["AccessKeyId"].should.have.length_of(20)
access_key["SecretAccessKey"].should.have.length_of(40)
assert access_key["AccessKeyId"].startswith("AKIA")
@mock_iam_deprecated() @mock_iam_deprecated()
@ -825,8 +838,35 @@ def test_get_all_access_keys():
) )
@mock_iam
def test_list_access_keys():
conn = boto3.client("iam", region_name="us-east-1")
conn.create_user(UserName="my-user")
response = conn.list_access_keys(UserName="my-user")
assert_equals(
response["AccessKeyMetadata"], [],
)
access_key = conn.create_access_key(UserName="my-user")["AccessKey"]
response = conn.list_access_keys(UserName="my-user")
assert_equals(
sorted(response["AccessKeyMetadata"][0].keys()),
sorted(["Status", "CreateDate", "UserName", "AccessKeyId"]),
)
conn = boto3.client(
"iam",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
)
response = conn.list_access_keys()
assert_equals(
sorted(response["AccessKeyMetadata"][0].keys()),
sorted(["Status", "CreateDate", "UserName", "AccessKeyId"]),
)
@mock_iam_deprecated() @mock_iam_deprecated()
def test_delete_access_key(): def test_delete_access_key_deprecated():
conn = boto.connect_iam() conn = boto.connect_iam()
conn.create_user("my-user") conn.create_user("my-user")
access_key_id = conn.create_access_key("my-user")["create_access_key_response"][ access_key_id = conn.create_access_key("my-user")["create_access_key_response"][
@ -835,6 +875,16 @@ def test_delete_access_key():
conn.delete_access_key(access_key_id, "my-user") conn.delete_access_key(access_key_id, "my-user")
@mock_iam
def test_delete_access_key():
conn = boto3.client("iam", region_name="us-east-1")
conn.create_user(UserName="my-user")
key = conn.create_access_key(UserName="my-user")["AccessKey"]
conn.delete_access_key(AccessKeyId=key["AccessKeyId"], UserName="my-user")
key = conn.create_access_key(UserName="my-user")["AccessKey"]
conn.delete_access_key(AccessKeyId=key["AccessKeyId"])
@mock_iam() @mock_iam()
def test_mfa_devices(): def test_mfa_devices():
# Test enable device # Test enable device
@ -1326,6 +1376,9 @@ def test_update_access_key():
) )
resp = client.list_access_keys(UserName=username) resp = client.list_access_keys(UserName=username)
resp["AccessKeyMetadata"][0]["Status"].should.equal("Inactive") resp["AccessKeyMetadata"][0]["Status"].should.equal("Inactive")
client.update_access_key(AccessKeyId=key["AccessKeyId"], Status="Active")
resp = client.list_access_keys(UserName=username)
resp["AccessKeyMetadata"][0]["Status"].should.equal("Active")
@mock_iam @mock_iam

View File

@ -9,6 +9,173 @@ from botocore.exceptions import ClientError
from nose.tools import assert_raises from nose.tools import assert_raises
@mock_iot
def test_attach_policy():
client = boto3.client("iot", region_name="ap-northeast-1")
policy_name = "my-policy"
doc = "{}"
cert = client.create_keys_and_certificate(setAsActive=True)
cert_arn = cert["certificateArn"]
client.create_policy(policyName=policy_name, policyDocument=doc)
client.attach_policy(policyName=policy_name, target=cert_arn)
res = client.list_attached_policies(target=cert_arn)
res.should.have.key("policies").which.should.have.length_of(1)
res["policies"][0]["policyName"].should.equal("my-policy")
@mock_iot
def test_detach_policy():
client = boto3.client("iot", region_name="ap-northeast-1")
policy_name = "my-policy"
doc = "{}"
cert = client.create_keys_and_certificate(setAsActive=True)
cert_arn = cert["certificateArn"]
client.create_policy(policyName=policy_name, policyDocument=doc)
client.attach_policy(policyName=policy_name, target=cert_arn)
res = client.list_attached_policies(target=cert_arn)
res.should.have.key("policies").which.should.have.length_of(1)
res["policies"][0]["policyName"].should.equal("my-policy")
client.detach_policy(policyName=policy_name, target=cert_arn)
res = client.list_attached_policies(target=cert_arn)
res.should.have.key("policies").which.should.be.empty
@mock_iot
def test_list_attached_policies():
client = boto3.client("iot", region_name="ap-northeast-1")
cert = client.create_keys_and_certificate(setAsActive=True)
policies = client.list_attached_policies(target=cert["certificateArn"])
policies["policies"].should.be.empty
@mock_iot
def test_policy_versions():
client = boto3.client("iot", region_name="ap-northeast-1")
policy_name = "my-policy"
doc = "{}"
policy = client.create_policy(policyName=policy_name, policyDocument=doc)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(json.dumps({}))
policy.should.have.key("policyVersionId").which.should.equal("1")
policy = client.get_policy(policyName=policy_name)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(json.dumps({}))
policy.should.have.key("defaultVersionId").which.should.equal(
policy["defaultVersionId"]
)
policy1 = client.create_policy_version(
policyName=policy_name,
policyDocument=json.dumps({"version": "version_1"}),
setAsDefault=True,
)
policy1.should.have.key("policyArn").which.should_not.be.none
policy1.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_1"})
)
policy1.should.have.key("policyVersionId").which.should.equal("2")
policy1.should.have.key("isDefaultVersion").which.should.equal(True)
policy2 = client.create_policy_version(
policyName=policy_name,
policyDocument=json.dumps({"version": "version_2"}),
setAsDefault=False,
)
policy2.should.have.key("policyArn").which.should_not.be.none
policy2.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_2"})
)
policy2.should.have.key("policyVersionId").which.should.equal("3")
policy2.should.have.key("isDefaultVersion").which.should.equal(False)
policy = client.get_policy(policyName=policy_name)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_1"})
)
policy.should.have.key("defaultVersionId").which.should.equal(
policy1["policyVersionId"]
)
policy_versions = client.list_policy_versions(policyName=policy_name)
policy_versions.should.have.key("policyVersions").which.should.have.length_of(3)
list(
map(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])
).count(True).should.equal(1)
default_policy = list(
filter(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])
)
default_policy[0].should.have.key("versionId").should.equal(
policy1["policyVersionId"]
)
policy = client.get_policy(policyName=policy_name)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_1"})
)
policy.should.have.key("defaultVersionId").which.should.equal(
policy1["policyVersionId"]
)
client.set_default_policy_version(
policyName=policy_name, policyVersionId=policy2["policyVersionId"]
)
policy_versions = client.list_policy_versions(policyName=policy_name)
policy_versions.should.have.key("policyVersions").which.should.have.length_of(3)
list(
map(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])
).count(True).should.equal(1)
default_policy = list(
filter(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])
)
default_policy[0].should.have.key("versionId").should.equal(
policy2["policyVersionId"]
)
policy = client.get_policy(policyName=policy_name)
policy.should.have.key("policyName").which.should.equal(policy_name)
policy.should.have.key("policyArn").which.should_not.be.none
policy.should.have.key("policyDocument").which.should.equal(
json.dumps({"version": "version_2"})
)
policy.should.have.key("defaultVersionId").which.should.equal(
policy2["policyVersionId"]
)
client.delete_policy_version(policyName=policy_name, policyVersionId="1")
policy_versions = client.list_policy_versions(policyName=policy_name)
policy_versions.should.have.key("policyVersions").which.should.have.length_of(2)
client.delete_policy_version(
policyName=policy_name, policyVersionId=policy1["policyVersionId"]
)
policy_versions = client.list_policy_versions(policyName=policy_name)
policy_versions.should.have.key("policyVersions").which.should.have.length_of(1)
# should fail as it"s the default policy. Should use delete_policy instead
try:
client.delete_policy_version(
policyName=policy_name, policyVersionId=policy2["policyVersionId"]
)
assert False, "Should have failed in previous call"
except Exception as exception:
exception.response["Error"]["Message"].should.equal(
"Cannot delete the default version of a policy"
)
@mock_iot @mock_iot
def test_things(): def test_things():
client = boto3.client("iot", region_name="ap-northeast-1") client = boto3.client("iot", region_name="ap-northeast-1")
@ -994,7 +1161,10 @@ def test_create_job():
client = boto3.client("iot", region_name="eu-west-1") client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing" name = "my-thing"
job_id = "TestJob" job_id = "TestJob"
# thing # thing# job document
# job_document = {
# "field": "value"
# }
thing = client.create_thing(thingName=name) thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name) thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn") thing.should.have.key("thingArn")
@ -1020,6 +1190,63 @@ def test_create_job():
job.should.have.key("description") job.should.have.key("description")
@mock_iot
def test_list_jobs():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing# job document
# job_document = {
# "field": "value"
# }
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job1 = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job1.should.have.key("jobId").which.should.equal(job_id)
job1.should.have.key("jobArn")
job1.should.have.key("description")
job2 = client.create_job(
jobId=job_id + "1",
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job2.should.have.key("jobId").which.should.equal(job_id + "1")
job2.should.have.key("jobArn")
job2.should.have.key("description")
jobs = client.list_jobs()
jobs.should.have.key("jobs")
jobs.should_not.have.key("nextToken")
jobs["jobs"][0].should.have.key("jobId").which.should.equal(job_id)
jobs["jobs"][1].should.have.key("jobId").which.should.equal(job_id + "1")
@mock_iot @mock_iot
def test_describe_job(): def test_describe_job():
client = boto3.client("iot", region_name="eu-west-1") client = boto3.client("iot", region_name="eu-west-1")
@ -1124,3 +1351,387 @@ def test_describe_job_1():
job.should.have.key("job").which.should.have.key( job.should.have.key("job").which.should.have.key(
"jobExecutionsRolloutConfig" "jobExecutionsRolloutConfig"
).which.should.have.key("maximumPerMinute").which.should.equal(10) ).which.should.have.key("maximumPerMinute").which.should.equal(10)
@mock_iot
def test_delete_job():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job = client.describe_job(jobId=job_id)
job.should.have.key("job")
job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id)
client.delete_job(jobId=job_id)
client.list_jobs()["jobs"].should.have.length_of(0)
@mock_iot
def test_cancel_job():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job = client.describe_job(jobId=job_id)
job.should.have.key("job")
job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id)
job = client.cancel_job(jobId=job_id, reasonCode="Because", comment="You are")
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job = client.describe_job(jobId=job_id)
job.should.have.key("job")
job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("job").which.should.have.key("status").which.should.equal(
"CANCELED"
)
job.should.have.key("job").which.should.have.key(
"forceCanceled"
).which.should.equal(False)
job.should.have.key("job").which.should.have.key("reasonCode").which.should.equal(
"Because"
)
job.should.have.key("job").which.should.have.key("comment").which.should.equal(
"You are"
)
@mock_iot
def test_get_job_document_with_document_source():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job_document = client.get_job_document(jobId=job_id)
job_document.should.have.key("document").which.should.equal("")
@mock_iot
def test_get_job_document_with_document():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job_document = client.get_job_document(jobId=job_id)
job_document.should.have.key("document").which.should.equal('{"field": "value"}')
@mock_iot
def test_describe_job_execution():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
job_execution = client.describe_job_execution(jobId=job_id, thingName=name)
job_execution.should.have.key("execution")
job_execution["execution"].should.have.key("jobId").which.should.equal(job_id)
job_execution["execution"].should.have.key("status").which.should.equal("QUEUED")
job_execution["execution"].should.have.key("forceCanceled").which.should.equal(
False
)
job_execution["execution"].should.have.key("statusDetails").which.should.equal(
{"detailsMap": {}}
)
job_execution["execution"].should.have.key("thingArn").which.should.equal(
thing["thingArn"]
)
job_execution["execution"].should.have.key("queuedAt")
job_execution["execution"].should.have.key("startedAt")
job_execution["execution"].should.have.key("lastUpdatedAt")
job_execution["execution"].should.have.key("executionNumber").which.should.equal(
123
)
job_execution["execution"].should.have.key("versionNumber").which.should.equal(123)
job_execution["execution"].should.have.key(
"approximateSecondsBeforeTimedOut"
).which.should.equal(123)
job_execution = client.describe_job_execution(
jobId=job_id, thingName=name, executionNumber=123
)
job_execution.should.have.key("execution")
job_execution["execution"].should.have.key("jobId").which.should.equal(job_id)
job_execution["execution"].should.have.key("status").which.should.equal("QUEUED")
job_execution["execution"].should.have.key("forceCanceled").which.should.equal(
False
)
job_execution["execution"].should.have.key("statusDetails").which.should.equal(
{"detailsMap": {}}
)
job_execution["execution"].should.have.key("thingArn").which.should.equal(
thing["thingArn"]
)
job_execution["execution"].should.have.key("queuedAt")
job_execution["execution"].should.have.key("startedAt")
job_execution["execution"].should.have.key("lastUpdatedAt")
job_execution["execution"].should.have.key("executionNumber").which.should.equal(
123
)
job_execution["execution"].should.have.key("versionNumber").which.should.equal(123)
job_execution["execution"].should.have.key(
"approximateSecondsBeforeTimedOut"
).which.should.equal(123)
try:
client.describe_job_execution(jobId=job_id, thingName=name, executionNumber=456)
except ClientError as exc:
error_code = exc.response["Error"]["Code"]
error_code.should.equal("ResourceNotFoundException")
else:
raise Exception("Should have raised error")
@mock_iot
def test_cancel_job_execution():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
client.cancel_job_execution(jobId=job_id, thingName=name)
job_execution = client.describe_job_execution(jobId=job_id, thingName=name)
job_execution.should.have.key("execution")
job_execution["execution"].should.have.key("status").which.should.equal("CANCELED")
@mock_iot
def test_delete_job_execution():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
client.delete_job_execution(jobId=job_id, thingName=name, executionNumber=123)
try:
client.describe_job_execution(jobId=job_id, thingName=name, executionNumber=123)
except ClientError as exc:
error_code = exc.response["Error"]["Code"]
error_code.should.equal("ResourceNotFoundException")
else:
raise Exception("Should have raised error")
@mock_iot
def test_list_job_executions_for_job():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
job_execution = client.list_job_executions_for_job(jobId=job_id)
job_execution.should.have.key("executionSummaries")
job_execution["executionSummaries"][0].should.have.key(
"thingArn"
).which.should.equal(thing["thingArn"])
@mock_iot
def test_list_job_executions_for_thing():
client = boto3.client("iot", region_name="eu-west-1")
name = "my-thing"
job_id = "TestJob"
# thing
thing = client.create_thing(thingName=name)
thing.should.have.key("thingName").which.should.equal(name)
thing.should.have.key("thingArn")
# job document
job_document = {"field": "value"}
job = client.create_job(
jobId=job_id,
targets=[thing["thingArn"]],
document=json.dumps(job_document),
description="Description",
presignedUrlConfig={
"roleArn": "arn:aws:iam::1:role/service-role/iot_job_role",
"expiresInSec": 123,
},
targetSelection="CONTINUOUS",
jobExecutionsRolloutConfig={"maximumPerMinute": 10},
)
job.should.have.key("jobId").which.should.equal(job_id)
job.should.have.key("jobArn")
job.should.have.key("description")
job_execution = client.list_job_executions_for_thing(thingName=name)
job_execution.should.have.key("executionSummaries")
job_execution["executionSummaries"][0].should.have.key("jobId").which.should.equal(
job_id
)

View File

@ -223,7 +223,7 @@ def test_create_stream_without_redshift():
@mock_kinesis @mock_kinesis
def test_deescribe_non_existant_stream(): def test_deescribe_non_existent_stream():
client = boto3.client("firehose", region_name="us-east-1") client = boto3.client("firehose", region_name="us-east-1")
client.describe_delivery_stream.when.called_with( client.describe_delivery_stream.when.called_with(

View File

@ -32,7 +32,7 @@ def test_create_cluster():
@mock_kinesis_deprecated @mock_kinesis_deprecated
def test_describe_non_existant_stream(): def test_describe_non_existent_stream():
conn = boto.kinesis.connect_to_region("us-east-1") conn = boto.kinesis.connect_to_region("us-east-1")
conn.describe_stream.when.called_with("not-a-stream").should.throw( conn.describe_stream.when.called_with("not-a-stream").should.throw(
ResourceNotFoundException ResourceNotFoundException

View File

@ -1,26 +1,19 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from datetime import date
from datetime import datetime
from dateutil.tz import tzutc
import base64 import base64
import os
import re import re
import boto3
import boto.kms import boto.kms
import botocore.exceptions
import six import six
import sure # noqa import sure # noqa
from boto.exception import JSONResponseError from boto.exception import JSONResponseError
from boto.kms.exceptions import AlreadyExistsException, NotFoundException from boto.kms.exceptions import AlreadyExistsException, NotFoundException
from freezegun import freeze_time
from nose.tools import assert_raises from nose.tools import assert_raises
from parameterized import parameterized from parameterized import parameterized
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from moto.kms.models import KmsBackend from moto.kms.models import KmsBackend
from moto.kms.exceptions import NotFoundException as MotoNotFoundException from moto.kms.exceptions import NotFoundException as MotoNotFoundException
from moto import mock_kms, mock_kms_deprecated from moto import mock_kms_deprecated
PLAINTEXT_VECTORS = ( PLAINTEXT_VECTORS = (
(b"some encodeable plaintext",), (b"some encodeable plaintext",),
@ -36,23 +29,6 @@ def _get_encoded_value(plaintext):
return plaintext.encode("utf-8") return plaintext.encode("utf-8")
@mock_kms
def test_create_key():
conn = boto3.client("kms", region_name="us-east-1")
with freeze_time("2015-01-01 00:00:00"):
key = conn.create_key(
Policy="my policy",
Description="my key",
KeyUsage="ENCRYPT_DECRYPT",
Tags=[{"TagKey": "project", "TagValue": "moto"}],
)
key["KeyMetadata"]["Description"].should.equal("my key")
key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT")
key["KeyMetadata"]["Enabled"].should.equal(True)
key["KeyMetadata"]["CreationDate"].should.be.a(date)
@mock_kms_deprecated @mock_kms_deprecated
def test_describe_key(): def test_describe_key():
conn = boto.kms.connect_to_region("us-west-2") conn = boto.kms.connect_to_region("us-west-2")
@ -97,22 +73,6 @@ def test_describe_key_via_alias_not_found():
) )
@parameterized(
(
("alias/does-not-exist",),
("arn:aws:kms:us-east-1:012345678912:alias/does-not-exist",),
("invalid",),
)
)
@mock_kms
def test_describe_key_via_alias_invalid_alias(key_id):
client = boto3.client("kms", region_name="us-east-1")
client.create_key(Description="key")
with assert_raises(client.exceptions.NotFoundException):
client.describe_key(KeyId=key_id)
@mock_kms_deprecated @mock_kms_deprecated
def test_describe_key_via_arn(): def test_describe_key_via_arn():
conn = boto.kms.connect_to_region("us-west-2") conn = boto.kms.connect_to_region("us-west-2")
@ -240,71 +200,6 @@ def test_generate_data_key():
response["KeyId"].should.equal(key_arn) response["KeyId"].should.equal(key_arn)
@mock_kms
def test_boto3_generate_data_key():
kms = boto3.client("kms", region_name="us-west-2")
key = kms.create_key()
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
response = kms.generate_data_key(KeyId=key_id, NumberOfBytes=32)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["CiphertextBlob"], validate=True)
# Plaintext must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["Plaintext"], validate=True)
response["KeyId"].should.equal(key_arn)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_encrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
response["CiphertextBlob"].should_not.equal(plaintext)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["CiphertextBlob"], validate=True)
response["KeyId"].should.equal(key_arn)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_decrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
encrypt_response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
client.create_key(Description="key")
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(encrypt_response["CiphertextBlob"], validate=True)
decrypt_response = client.decrypt(CiphertextBlob=encrypt_response["CiphertextBlob"])
# Plaintext must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(decrypt_response["Plaintext"], validate=True)
decrypt_response["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response["KeyId"].should.equal(key_arn)
@mock_kms_deprecated @mock_kms_deprecated
def test_disable_key_rotation_with_missing_key(): def test_disable_key_rotation_with_missing_key():
conn = boto.kms.connect_to_region("us-west-2") conn = boto.kms.connect_to_region("us-west-2")
@ -775,25 +670,6 @@ def test__list_aliases():
len(aliases).should.equal(7) len(aliases).should.equal(7)
@parameterized(
(
("not-a-uuid",),
("alias/DoesNotExist",),
("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
(
"arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",
),
)
)
@mock_kms
def test_invalid_key_ids(key_id):
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId=key_id, NumberOfBytes=5)
@mock_kms_deprecated @mock_kms_deprecated
def test__assert_default_policy(): def test__assert_default_policy():
from moto.kms.responses import _assert_default_policy from moto.kms.responses import _assert_default_policy
@ -804,431 +680,3 @@ def test__assert_default_policy():
_assert_default_policy.when.called_with("default").should_not.throw( _assert_default_policy.when.called_with("default").should_not.throw(
MotoNotFoundException MotoNotFoundException
) )
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_kms_encrypt_boto3(plaintext):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="key")
response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext)
response = client.decrypt(CiphertextBlob=response["CiphertextBlob"])
response["Plaintext"].should.equal(_get_encoded_value(plaintext))
@mock_kms
def test_disable_key():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="disable-key")
client.disable_key(KeyId=key["KeyMetadata"]["KeyId"])
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "Disabled"
@mock_kms
def test_enable_key():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="enable-key")
client.disable_key(KeyId=key["KeyMetadata"]["KeyId"])
client.enable_key(KeyId=key["KeyMetadata"]["KeyId"])
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == True
assert result["KeyMetadata"]["KeyState"] == "Enabled"
@mock_kms
def test_schedule_key_deletion():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="schedule-key-deletion")
if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false":
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
assert response["DeletionDate"] == datetime(
2015, 1, 31, 12, 0, tzinfo=tzutc()
)
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "PendingDeletion"
assert "DeletionDate" in result["KeyMetadata"]
@mock_kms
def test_schedule_key_deletion_custom():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="schedule-key-deletion")
if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false":
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(
KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7
)
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
assert response["DeletionDate"] == datetime(
2015, 1, 8, 12, 0, tzinfo=tzutc()
)
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(
KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7
)
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "PendingDeletion"
assert "DeletionDate" in result["KeyMetadata"]
@mock_kms
def test_cancel_key_deletion():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="cancel-key-deletion")
client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
response = client.cancel_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "Disabled"
assert "DeletionDate" not in result["KeyMetadata"]
@mock_kms
def test_update_key_description():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="old_description")
key_id = key["KeyMetadata"]["KeyId"]
result = client.update_key_description(KeyId=key_id, Description="new_description")
assert "ResponseMetadata" in result
@mock_kms
def test_key_tagging_happy():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="test-key-tagging")
key_id = key["KeyMetadata"]["KeyId"]
tags = [{"TagKey": "key1", "TagValue": "value1"}, {"TagKey": "key2", "TagValue": "value2"}]
client.tag_resource(KeyId=key_id, Tags=tags)
result = client.list_resource_tags(KeyId=key_id)
actual = result.get("Tags", [])
assert tags == actual
client.untag_resource(KeyId=key_id, TagKeys=["key1"])
actual = client.list_resource_tags(KeyId=key_id).get("Tags", [])
expected = [{"TagKey": "key2", "TagValue": "value2"}]
assert expected == actual
@mock_kms
def test_key_tagging_sad():
b = KmsBackend()
try:
b.tag_resource('unknown', [])
raise 'tag_resource should fail if KeyId is not known'
except JsonRESTError:
pass
try:
b.untag_resource('unknown', [])
raise 'untag_resource should fail if KeyId is not known'
except JsonRESTError:
pass
try:
b.list_resource_tags('unknown')
raise 'list_resource_tags should fail if KeyId is not known'
except JsonRESTError:
pass
@parameterized(
(
(dict(KeySpec="AES_256"), 32),
(dict(KeySpec="AES_128"), 16),
(dict(NumberOfBytes=64), 64),
(dict(NumberOfBytes=1), 1),
(dict(NumberOfBytes=1024), 1024),
)
)
@mock_kms
def test_generate_data_key_sizes(kwargs, expected_key_length):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size")
response = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs)
assert len(response["Plaintext"]) == expected_key_length
@mock_kms
def test_generate_data_key_decrypt():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-decrypt")
resp1 = client.generate_data_key(
KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256"
)
resp2 = client.decrypt(CiphertextBlob=resp1["CiphertextBlob"])
assert resp1["Plaintext"] == resp2["Plaintext"]
@parameterized(
(
(dict(KeySpec="AES_257"),),
(dict(KeySpec="AES_128", NumberOfBytes=16),),
(dict(NumberOfBytes=2048),),
(dict(NumberOfBytes=0),),
(dict(),),
)
)
@mock_kms
def test_generate_data_key_invalid_size_params(kwargs):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size")
with assert_raises(
(botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError)
) as err:
client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs)
@parameterized(
(
("alias/DoesNotExist",),
("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
(
"arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",
),
)
)
@mock_kms
def test_generate_data_key_invalid_key(key_id):
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId=key_id, KeySpec="AES_256")
@parameterized(
(
("alias/DoesExist", False),
("arn:aws:kms:us-east-1:012345678912:alias/DoesExist", False),
("", True),
("arn:aws:kms:us-east-1:012345678912:key/", True),
)
)
@mock_kms
def test_generate_data_key_all_valid_key_ids(prefix, append_key_id):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key()
key_id = key["KeyMetadata"]["KeyId"]
client.create_alias(AliasName="alias/DoesExist", TargetKeyId=key_id)
target_id = prefix
if append_key_id:
target_id += key_id
client.generate_data_key(KeyId=key_id, NumberOfBytes=32)
@mock_kms
def test_generate_data_key_without_plaintext_decrypt():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-decrypt")
resp1 = client.generate_data_key_without_plaintext(
KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256"
)
assert "Plaintext" not in resp1
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_re_encrypt_decrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key_1 = client.create_key(Description="key 1")
key_1_id = key_1["KeyMetadata"]["KeyId"]
key_1_arn = key_1["KeyMetadata"]["Arn"]
key_2 = client.create_key(Description="key 2")
key_2_id = key_2["KeyMetadata"]["KeyId"]
key_2_arn = key_2["KeyMetadata"]["Arn"]
encrypt_response = client.encrypt(
KeyId=key_1_id, Plaintext=plaintext, EncryptionContext={"encryption": "context"}
)
re_encrypt_response = client.re_encrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
SourceEncryptionContext={"encryption": "context"},
DestinationKeyId=key_2_id,
DestinationEncryptionContext={"another": "context"},
)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(re_encrypt_response["CiphertextBlob"], validate=True)
re_encrypt_response["SourceKeyId"].should.equal(key_1_arn)
re_encrypt_response["KeyId"].should.equal(key_2_arn)
decrypt_response_1 = client.decrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
EncryptionContext={"encryption": "context"},
)
decrypt_response_1["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_1["KeyId"].should.equal(key_1_arn)
decrypt_response_2 = client.decrypt(
CiphertextBlob=re_encrypt_response["CiphertextBlob"],
EncryptionContext={"another": "context"},
)
decrypt_response_2["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_2["KeyId"].should.equal(key_2_arn)
decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"])
@mock_kms
def test_re_encrypt_to_invalid_destination():
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key 1")
key_id = key["KeyMetadata"]["KeyId"]
encrypt_response = client.encrypt(KeyId=key_id, Plaintext=b"some plaintext")
with assert_raises(client.exceptions.NotFoundException):
client.re_encrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
DestinationKeyId="alias/DoesNotExist",
)
@parameterized(((12,), (44,), (91,), (1,), (1024,)))
@mock_kms
def test_generate_random(number_of_bytes):
client = boto3.client("kms", region_name="us-west-2")
response = client.generate_random(NumberOfBytes=number_of_bytes)
response["Plaintext"].should.be.a(bytes)
len(response["Plaintext"]).should.equal(number_of_bytes)
@parameterized(
(
(2048, botocore.exceptions.ClientError),
(1025, botocore.exceptions.ClientError),
(0, botocore.exceptions.ParamValidationError),
(-1, botocore.exceptions.ParamValidationError),
(-1024, botocore.exceptions.ParamValidationError),
)
)
@mock_kms
def test_generate_random_invalid_number_of_bytes(number_of_bytes, error_type):
client = boto3.client("kms", region_name="us-west-2")
with assert_raises(error_type):
client.generate_random(NumberOfBytes=number_of_bytes)
@mock_kms
def test_enable_key_rotation_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.enable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_disable_key_rotation_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.disable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_enable_key_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.enable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_disable_key_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.disable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_cancel_key_deletion_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.cancel_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_schedule_key_deletion_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.schedule_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_get_key_rotation_status_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.get_key_rotation_status(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_get_key_policy_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.get_key_policy(
KeyId="12366f9b-1230-123d-123e-123e6ae60c02", PolicyName="default"
)
@mock_kms
def test_list_key_policies_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.list_key_policies(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_put_key_policy_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.put_key_policy(
KeyId="00000000-0000-0000-0000-000000000000",
PolicyName="default",
Policy="new policy",
)

View File

@ -0,0 +1,638 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from datetime import datetime
from dateutil.tz import tzutc
import base64
import os
import boto3
import botocore.exceptions
import six
import sure # noqa
from freezegun import freeze_time
from nose.tools import assert_raises
from parameterized import parameterized
from moto import mock_kms
PLAINTEXT_VECTORS = (
(b"some encodeable plaintext",),
(b"some unencodeable plaintext \xec\x8a\xcf\xb6r\xe9\xb5\xeb\xff\xa23\x16",),
("some unicode characters ø˚∆øˆˆ∆ßçøˆˆçßøˆ¨¥",),
)
def _get_encoded_value(plaintext):
if isinstance(plaintext, six.binary_type):
return plaintext
return plaintext.encode("utf-8")
@mock_kms
def test_create_key():
conn = boto3.client("kms", region_name="us-east-1")
key = conn.create_key(
Policy="my policy",
Description="my key",
KeyUsage="ENCRYPT_DECRYPT",
Tags=[{"TagKey": "project", "TagValue": "moto"}],
)
key["KeyMetadata"]["Arn"].should.equal(
"arn:aws:kms:us-east-1:123456789012:key/{}".format(key["KeyMetadata"]["KeyId"])
)
key["KeyMetadata"]["AWSAccountId"].should.equal("123456789012")
key["KeyMetadata"]["CreationDate"].should.be.a(datetime)
key["KeyMetadata"]["CustomerMasterKeySpec"].should.equal("SYMMETRIC_DEFAULT")
key["KeyMetadata"]["Description"].should.equal("my key")
key["KeyMetadata"]["Enabled"].should.be.ok
key["KeyMetadata"]["EncryptionAlgorithms"].should.equal(["SYMMETRIC_DEFAULT"])
key["KeyMetadata"]["KeyId"].should_not.be.empty
key["KeyMetadata"]["KeyManager"].should.equal("CUSTOMER")
key["KeyMetadata"]["KeyState"].should.equal("Enabled")
key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT")
key["KeyMetadata"]["Origin"].should.equal("AWS_KMS")
key["KeyMetadata"].should_not.have.key("SigningAlgorithms")
key = conn.create_key(KeyUsage="ENCRYPT_DECRYPT", CustomerMasterKeySpec="RSA_2048",)
sorted(key["KeyMetadata"]["EncryptionAlgorithms"]).should.equal(
["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"]
)
key["KeyMetadata"].should_not.have.key("SigningAlgorithms")
key = conn.create_key(KeyUsage="SIGN_VERIFY", CustomerMasterKeySpec="RSA_2048",)
key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms")
sorted(key["KeyMetadata"]["SigningAlgorithms"]).should.equal(
[
"RSASSA_PKCS1_V1_5_SHA_256",
"RSASSA_PKCS1_V1_5_SHA_384",
"RSASSA_PKCS1_V1_5_SHA_512",
"RSASSA_PSS_SHA_256",
"RSASSA_PSS_SHA_384",
"RSASSA_PSS_SHA_512",
]
)
key = conn.create_key(
KeyUsage="SIGN_VERIFY", CustomerMasterKeySpec="ECC_SECG_P256K1",
)
key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms")
key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_256"])
key = conn.create_key(
KeyUsage="SIGN_VERIFY", CustomerMasterKeySpec="ECC_NIST_P384",
)
key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms")
key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_384"])
key = conn.create_key(
KeyUsage="SIGN_VERIFY", CustomerMasterKeySpec="ECC_NIST_P521",
)
key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms")
key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_512"])
@mock_kms
def test_describe_key():
client = boto3.client("kms", region_name="us-east-1")
response = client.create_key(Description="my key", KeyUsage="ENCRYPT_DECRYPT",)
key_id = response["KeyMetadata"]["KeyId"]
response = client.describe_key(KeyId=key_id)
response["KeyMetadata"]["AWSAccountId"].should.equal("123456789012")
response["KeyMetadata"]["CreationDate"].should.be.a(datetime)
response["KeyMetadata"]["CustomerMasterKeySpec"].should.equal("SYMMETRIC_DEFAULT")
response["KeyMetadata"]["Description"].should.equal("my key")
response["KeyMetadata"]["Enabled"].should.be.ok
response["KeyMetadata"]["EncryptionAlgorithms"].should.equal(["SYMMETRIC_DEFAULT"])
response["KeyMetadata"]["KeyId"].should_not.be.empty
response["KeyMetadata"]["KeyManager"].should.equal("CUSTOMER")
response["KeyMetadata"]["KeyState"].should.equal("Enabled")
response["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT")
response["KeyMetadata"]["Origin"].should.equal("AWS_KMS")
response["KeyMetadata"].should_not.have.key("SigningAlgorithms")
@parameterized(
(
("alias/does-not-exist",),
("arn:aws:kms:us-east-1:012345678912:alias/does-not-exist",),
("invalid",),
)
)
@mock_kms
def test_describe_key_via_alias_invalid_alias(key_id):
client = boto3.client("kms", region_name="us-east-1")
client.create_key(Description="key")
with assert_raises(client.exceptions.NotFoundException):
client.describe_key(KeyId=key_id)
@mock_kms
def test_generate_data_key():
kms = boto3.client("kms", region_name="us-west-2")
key = kms.create_key()
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
response = kms.generate_data_key(KeyId=key_id, NumberOfBytes=32)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["CiphertextBlob"], validate=True)
# Plaintext must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["Plaintext"], validate=True)
response["KeyId"].should.equal(key_arn)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_encrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
response["CiphertextBlob"].should_not.equal(plaintext)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(response["CiphertextBlob"], validate=True)
response["KeyId"].should.equal(key_arn)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_decrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"]
encrypt_response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
client.create_key(Description="key")
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(encrypt_response["CiphertextBlob"], validate=True)
decrypt_response = client.decrypt(CiphertextBlob=encrypt_response["CiphertextBlob"])
# Plaintext must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(decrypt_response["Plaintext"], validate=True)
decrypt_response["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response["KeyId"].should.equal(key_arn)
@parameterized(
(
("not-a-uuid",),
("alias/DoesNotExist",),
("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
(
"arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",
),
)
)
@mock_kms
def test_invalid_key_ids(key_id):
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId=key_id, NumberOfBytes=5)
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_kms_encrypt(plaintext):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="key")
response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext)
response = client.decrypt(CiphertextBlob=response["CiphertextBlob"])
response["Plaintext"].should.equal(_get_encoded_value(plaintext))
@mock_kms
def test_disable_key():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="disable-key")
client.disable_key(KeyId=key["KeyMetadata"]["KeyId"])
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "Disabled"
@mock_kms
def test_enable_key():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="enable-key")
client.disable_key(KeyId=key["KeyMetadata"]["KeyId"])
client.enable_key(KeyId=key["KeyMetadata"]["KeyId"])
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == True
assert result["KeyMetadata"]["KeyState"] == "Enabled"
@mock_kms
def test_schedule_key_deletion():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="schedule-key-deletion")
if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false":
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
assert response["DeletionDate"] == datetime(
2015, 1, 31, 12, 0, tzinfo=tzutc()
)
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "PendingDeletion"
assert "DeletionDate" in result["KeyMetadata"]
@mock_kms
def test_schedule_key_deletion_custom():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="schedule-key-deletion")
if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false":
with freeze_time("2015-01-01 12:00:00"):
response = client.schedule_key_deletion(
KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7
)
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
assert response["DeletionDate"] == datetime(
2015, 1, 8, 12, 0, tzinfo=tzutc()
)
else:
# Can't manipulate time in server mode
response = client.schedule_key_deletion(
KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7
)
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "PendingDeletion"
assert "DeletionDate" in result["KeyMetadata"]
@mock_kms
def test_cancel_key_deletion():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="cancel-key-deletion")
client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
response = client.cancel_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
assert response["KeyId"] == key["KeyMetadata"]["KeyId"]
result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"])
assert result["KeyMetadata"]["Enabled"] == False
assert result["KeyMetadata"]["KeyState"] == "Disabled"
assert "DeletionDate" not in result["KeyMetadata"]
@mock_kms
def test_update_key_description():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="old_description")
key_id = key["KeyMetadata"]["KeyId"]
result = client.update_key_description(KeyId=key_id, Description="new_description")
assert "ResponseMetadata" in result
@mock_kms
def test_tag_resource():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="cancel-key-deletion")
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
keyid = response["KeyId"]
response = client.tag_resource(
KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]
)
# Shouldn't have any data, just header
assert len(response.keys()) == 1
@mock_kms
def test_list_resource_tags():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="cancel-key-deletion")
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
keyid = response["KeyId"]
response = client.tag_resource(
KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]
)
response = client.list_resource_tags(KeyId=keyid)
assert response["Tags"][0]["TagKey"] == "string"
assert response["Tags"][0]["TagValue"] == "string"
@parameterized(
(
(dict(KeySpec="AES_256"), 32),
(dict(KeySpec="AES_128"), 16),
(dict(NumberOfBytes=64), 64),
(dict(NumberOfBytes=1), 1),
(dict(NumberOfBytes=1024), 1024),
)
)
@mock_kms
def test_generate_data_key_sizes(kwargs, expected_key_length):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size")
response = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs)
assert len(response["Plaintext"]) == expected_key_length
@mock_kms
def test_generate_data_key_decrypt():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-decrypt")
resp1 = client.generate_data_key(
KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256"
)
resp2 = client.decrypt(CiphertextBlob=resp1["CiphertextBlob"])
assert resp1["Plaintext"] == resp2["Plaintext"]
@parameterized(
(
(dict(KeySpec="AES_257"),),
(dict(KeySpec="AES_128", NumberOfBytes=16),),
(dict(NumberOfBytes=2048),),
(dict(NumberOfBytes=0),),
(dict(),),
)
)
@mock_kms
def test_generate_data_key_invalid_size_params(kwargs):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size")
with assert_raises(
(botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError)
) as err:
client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs)
@parameterized(
(
("alias/DoesNotExist",),
("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
(
"arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",
),
)
)
@mock_kms
def test_generate_data_key_invalid_key(key_id):
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId=key_id, KeySpec="AES_256")
@parameterized(
(
("alias/DoesExist", False),
("arn:aws:kms:us-east-1:012345678912:alias/DoesExist", False),
("", True),
("arn:aws:kms:us-east-1:012345678912:key/", True),
)
)
@mock_kms
def test_generate_data_key_all_valid_key_ids(prefix, append_key_id):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key()
key_id = key["KeyMetadata"]["KeyId"]
client.create_alias(AliasName="alias/DoesExist", TargetKeyId=key_id)
target_id = prefix
if append_key_id:
target_id += key_id
client.generate_data_key(KeyId=key_id, NumberOfBytes=32)
@mock_kms
def test_generate_data_key_without_plaintext_decrypt():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-decrypt")
resp1 = client.generate_data_key_without_plaintext(
KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256"
)
assert "Plaintext" not in resp1
@parameterized(PLAINTEXT_VECTORS)
@mock_kms
def test_re_encrypt_decrypt(plaintext):
client = boto3.client("kms", region_name="us-west-2")
key_1 = client.create_key(Description="key 1")
key_1_id = key_1["KeyMetadata"]["KeyId"]
key_1_arn = key_1["KeyMetadata"]["Arn"]
key_2 = client.create_key(Description="key 2")
key_2_id = key_2["KeyMetadata"]["KeyId"]
key_2_arn = key_2["KeyMetadata"]["Arn"]
encrypt_response = client.encrypt(
KeyId=key_1_id, Plaintext=plaintext, EncryptionContext={"encryption": "context"}
)
re_encrypt_response = client.re_encrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
SourceEncryptionContext={"encryption": "context"},
DestinationKeyId=key_2_id,
DestinationEncryptionContext={"another": "context"},
)
# CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception):
base64.b64decode(re_encrypt_response["CiphertextBlob"], validate=True)
re_encrypt_response["SourceKeyId"].should.equal(key_1_arn)
re_encrypt_response["KeyId"].should.equal(key_2_arn)
decrypt_response_1 = client.decrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
EncryptionContext={"encryption": "context"},
)
decrypt_response_1["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_1["KeyId"].should.equal(key_1_arn)
decrypt_response_2 = client.decrypt(
CiphertextBlob=re_encrypt_response["CiphertextBlob"],
EncryptionContext={"another": "context"},
)
decrypt_response_2["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_2["KeyId"].should.equal(key_2_arn)
decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"])
@mock_kms
def test_re_encrypt_to_invalid_destination():
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(Description="key 1")
key_id = key["KeyMetadata"]["KeyId"]
encrypt_response = client.encrypt(KeyId=key_id, Plaintext=b"some plaintext")
with assert_raises(client.exceptions.NotFoundException):
client.re_encrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"],
DestinationKeyId="alias/DoesNotExist",
)
@parameterized(((12,), (44,), (91,), (1,), (1024,)))
@mock_kms
def test_generate_random(number_of_bytes):
client = boto3.client("kms", region_name="us-west-2")
response = client.generate_random(NumberOfBytes=number_of_bytes)
response["Plaintext"].should.be.a(bytes)
len(response["Plaintext"]).should.equal(number_of_bytes)
@parameterized(
(
(2048, botocore.exceptions.ClientError),
(1025, botocore.exceptions.ClientError),
(0, botocore.exceptions.ParamValidationError),
(-1, botocore.exceptions.ParamValidationError),
(-1024, botocore.exceptions.ParamValidationError),
)
)
@mock_kms
def test_generate_random_invalid_number_of_bytes(number_of_bytes, error_type):
client = boto3.client("kms", region_name="us-west-2")
with assert_raises(error_type):
client.generate_random(NumberOfBytes=number_of_bytes)
@mock_kms
def test_enable_key_rotation_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.enable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_disable_key_rotation_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.disable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_enable_key_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.enable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_disable_key_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.disable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_cancel_key_deletion_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.cancel_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_schedule_key_deletion_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.schedule_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_get_key_rotation_status_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.get_key_rotation_status(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_get_key_policy_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.get_key_policy(
KeyId="12366f9b-1230-123d-123e-123e6ae60c02", PolicyName="default"
)
@mock_kms
def test_list_key_policies_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.list_key_policies(KeyId="12366f9b-1230-123d-123e-123e6ae60c02")
@mock_kms
def test_put_key_policy_key_not_found():
client = boto3.client("kms", region_name="us-east-1")
with assert_raises(client.exceptions.NotFoundException):
client.put_key_policy(
KeyId="00000000-0000-0000-0000-000000000000",
PolicyName="default",
Policy="new policy",
)

View File

@ -102,7 +102,7 @@ def test_deserialize_ciphertext_blob(raw, serialized):
@parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS)) @parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS))
def test_encrypt_decrypt_cycle(encryption_context): def test_encrypt_decrypt_cycle(encryption_context):
plaintext = b"some secret plaintext" plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", "nop") master_key = Key("nop", "nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt( ciphertext_blob = encrypt(
@ -133,7 +133,7 @@ def test_encrypt_unknown_key_id():
def test_decrypt_invalid_ciphertext_format(): def test_decrypt_invalid_ciphertext_format():
master_key = Key("nop", "nop", "nop", "nop") master_key = Key("nop", "nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
with assert_raises(InvalidCiphertextException): with assert_raises(InvalidCiphertextException):
@ -153,7 +153,7 @@ def test_decrypt_unknwown_key_id():
def test_decrypt_invalid_ciphertext(): def test_decrypt_invalid_ciphertext():
master_key = Key("nop", "nop", "nop", "nop") master_key = Key("nop", "nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
ciphertext_blob = ( ciphertext_blob = (
master_key.id.encode("utf-8") + b"123456789012" master_key.id.encode("utf-8") + b"123456789012"
@ -171,7 +171,7 @@ def test_decrypt_invalid_ciphertext():
def test_decrypt_invalid_encryption_context(): def test_decrypt_invalid_encryption_context():
plaintext = b"some secret plaintext" plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", "nop") master_key = Key("nop", "nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt( ciphertext_blob = encrypt(

Some files were not shown because too many files have changed in this diff Show More