Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Stephan Huber 2019-08-30 14:39:29 +02:00
commit 5a9c921d97
132 changed files with 28345 additions and 4011 deletions

1
.gitignore vendored
View File

@ -15,6 +15,7 @@ python_env
.ropeproject/ .ropeproject/
.pytest_cache/ .pytest_cache/
venv/ venv/
env/
.python-version .python-version
.vscode/ .vscode/
tests/file.tmp tests/file.tmp

View File

@ -2,36 +2,56 @@ dist: xenial
language: python language: python
sudo: false sudo: false
services: services:
- docker - docker
python: python:
- 2.7 - 2.7
- 3.6 - 3.6
- 3.7 - 3.7
env: env:
- TEST_SERVER_MODE=false - TEST_SERVER_MODE=false
- TEST_SERVER_MODE=true - TEST_SERVER_MODE=true
before_install: before_install:
- export BOTO_CONFIG=/dev/null - export BOTO_CONFIG=/dev/null
install: install:
# We build moto first so the docker container doesn't try to compile it as well, also note we don't use - |
# -d for docker run so the logs show up in travis python setup.py sdist
# Python images come from here: https://hub.docker.com/_/python/
- |
python setup.py sdist
if [ "$TEST_SERVER_MODE" = "true" ]; then if [ "$TEST_SERVER_MODE" = "true" ]; then
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh & docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh &
fi fi
travis_retry pip install boto==2.45.0 travis_retry pip install boto==2.45.0
travis_retry pip install boto3 travis_retry pip install boto3
travis_retry pip install dist/moto*.gz travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1 travis_retry pip install coveralls==1.1
travis_retry pip install -r requirements-dev.txt travis_retry pip install -r requirements-dev.txt
if [ "$TEST_SERVER_MODE" = "true" ]; then if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py python wait_for.py
fi fi
script: script:
- make test - make test
after_success: after_success:
- coveralls - coveralls
before_deploy:
- git checkout $TRAVIS_BRANCH
- git fetch --unshallow
- python update_version_from_git.py
deploy:
- provider: pypi
distributions: sdist bdist_wheel
user: spulec
password:
secure: NxnPylnTfekJmGyoufCw0lMoYRskSMJzvAIyAlJJVYKwEhmiCPOrdy5qV8i8mRZ1AkUsqU3jBZ/PD56n96clHW0E3d080UleRDj6JpyALVdeLfMqZl9kLmZ8bqakWzYq3VSJKw2zGP/L4tPGf8wTK1SUv9yl/YNDsBdCkjDverw=
on:
branch:
- master
skip_cleanup: true
skip_existing: true
# - provider: pypi
# distributions: sdist bdist_wheel
# user: spulec
# password:
# secure: NxnPylnTfekJmGyoufCw0lMoYRskSMJzvAIyAlJJVYKwEhmiCPOrdy5qV8i8mRZ1AkUsqU3jBZ/PD56n96clHW0E3d080UleRDj6JpyALVdeLfMqZl9kLmZ8bqakWzYq3VSJKw2zGP/L4tPGf8wTK1SUv9yl/YNDsBdCkjDverw=
# on:
# tags: true
# skip_existing: true

View File

@ -54,5 +54,6 @@ Moto is written by Steve Pulec with contributions from:
* [William Richard](https://github.com/william-richard) * [William Richard](https://github.com/william-richard)
* [Alex Casalboni](https://github.com/alexcasalboni) * [Alex Casalboni](https://github.com/alexcasalboni)
* [Jon Beilke](https://github.com/jrbeilke) * [Jon Beilke](https://github.com/jrbeilke)
* [Bendeguz Acs](https://github.com/acsbendi)
* [Craig Anderson](https://github.com/craiga) * [Craig Anderson](https://github.com/craiga)
* [Robert Lewis](https://github.com/ralewis85) * [Robert Lewis](https://github.com/ralewis85)

View File

@ -2,6 +2,10 @@
Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project. Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project.
## Running the tests locally
Moto has a Makefile which has some helpful commands for getting setup. You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
## Is there a missing feature? ## Is there a missing feature?
Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services. Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services.

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,7 @@ endif
init: init:
@python setup.py develop @python setup.py develop
@pip install -r requirements.txt @pip install -r requirements-dev.txt
lint: lint:
flake8 moto flake8 moto

321
README.md
View File

@ -5,6 +5,9 @@
[![Build Status](https://travis-ci.org/spulec/moto.svg?branch=master)](https://travis-ci.org/spulec/moto) [![Build Status](https://travis-ci.org/spulec/moto.svg?branch=master)](https://travis-ci.org/spulec/moto)
[![Coverage Status](https://coveralls.io/repos/spulec/moto/badge.svg?branch=master)](https://coveralls.io/r/spulec/moto) [![Coverage Status](https://coveralls.io/repos/spulec/moto/badge.svg?branch=master)](https://coveralls.io/r/spulec/moto)
[![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org) [![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org)
![PyPI](https://img.shields.io/pypi/v/moto.svg)
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/moto.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg)
# In a nutshell # In a nutshell
@ -55,95 +58,96 @@ With the decorator wrapping the test, all the calls to s3 are automatically mock
It gets even better! Moto isn't just for Python code and it isn't just for S3. Look at the [standalone server mode](https://github.com/spulec/moto#stand-alone-server-mode) for more information about running Moto with other languages. Here's the status of the other AWS services implemented: It gets even better! Moto isn't just for Python code and it isn't just for S3. Look at the [standalone server mode](https://github.com/spulec/moto#stand-alone-server-mode) for more information about running Moto with other languages. Here's the status of the other AWS services implemented:
```gherkin ```gherkin
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Service Name | Decorator | Development Status | | Service Name | Decorator | Development Status |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| ACM | @mock_acm | all endpoints done | | ACM | @mock_acm | all endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| API Gateway | @mock_apigateway | core endpoints done | | API Gateway | @mock_apigateway | core endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Autoscaling | @mock_autoscaling| core endpoints done | | Autoscaling | @mock_autoscaling | core endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Cloudformation | @mock_cloudformation| core endpoints done | | Cloudformation | @mock_cloudformation | core endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Cloudwatch | @mock_cloudwatch | basic endpoints done | | Cloudwatch | @mock_cloudwatch | basic endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| CloudwatchEvents | @mock_events | all endpoints done | | CloudwatchEvents | @mock_events | all endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Cognito Identity | @mock_cognitoidentity| basic endpoints done | | Cognito Identity | @mock_cognitoidentity | basic endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Cognito Identity Provider | @mock_cognitoidp| basic endpoints done | | Cognito Identity Provider | @mock_cognitoidp | basic endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Config | @mock_config | basic endpoints done | | Config | @mock_config | basic endpoints done |
|------------------------------------------------------------------------------| | | | core endpoints done |
| Data Pipeline | @mock_datapipeline| basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Data Pipeline | @mock_datapipeline | basic endpoints done |
| DynamoDB | @mock_dynamodb | core endpoints done | |-------------------------------------------------------------------------------------|
| DynamoDB2 | @mock_dynamodb2 | all endpoints + partial indexes | | DynamoDB | @mock_dynamodb | core endpoints done |
|------------------------------------------------------------------------------| | DynamoDB2 | @mock_dynamodb2 | all endpoints + partial indexes |
| EC2 | @mock_ec2 | core endpoints done | |-------------------------------------------------------------------------------------|
| - AMI | | core endpoints done | | EC2 | @mock_ec2 | core endpoints done |
| - EBS | | core endpoints done | | - AMI | | core endpoints done |
| - Instances | | all endpoints done | | - EBS | | core endpoints done |
| - Security Groups | | core endpoints done | | - Instances | | all endpoints done |
| - Tags | | all endpoints done | | - Security Groups | | core endpoints done |
|------------------------------------------------------------------------------| | - Tags | | all endpoints done |
| ECR | @mock_ecr | basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | ECR | @mock_ecr | basic endpoints done |
| ECS | @mock_ecs | basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | ECS | @mock_ecs | basic endpoints done |
| ELB | @mock_elb | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | ELB | @mock_elb | core endpoints done |
| ELBv2 | @mock_elbv2 | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | ELBv2 | @mock_elbv2 | all endpoints done |
| EMR | @mock_emr | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | EMR | @mock_emr | core endpoints done |
| Glacier | @mock_glacier | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Glacier | @mock_glacier | core endpoints done |
| IAM | @mock_iam | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | IAM | @mock_iam | core endpoints done |
| IoT | @mock_iot | core endpoints done | |-------------------------------------------------------------------------------------|
| | @mock_iotdata | core endpoints done | | IoT | @mock_iot | core endpoints done |
|------------------------------------------------------------------------------| | | @mock_iotdata | core endpoints done |
| Lambda | @mock_lambda | basic endpoints done, requires | |-------------------------------------------------------------------------------------|
| | | docker | | Kinesis | @mock_kinesis | core endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Logs | @mock_logs | basic endpoints done | | KMS | @mock_kms | basic endpoints done |
|------------------------------------------------------------------------------| |-------------------------------------------------------------------------------------|
| Kinesis | @mock_kinesis | core endpoints done | | Lambda | @mock_lambda | basic endpoints done, requires |
|------------------------------------------------------------------------------| | | | docker |
| KMS | @mock_kms | basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Logs | @mock_logs | basic endpoints done |
| Organizations | @mock_organizations | some core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Organizations | @mock_organizations | some core endpoints done |
| Polly | @mock_polly | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Polly | @mock_polly | all endpoints done |
| RDS | @mock_rds | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | RDS | @mock_rds | core endpoints done |
| RDS2 | @mock_rds2 | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | RDS2 | @mock_rds2 | core endpoints done |
| Redshift | @mock_redshift | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Redshift | @mock_redshift | core endpoints done |
| Route53 | @mock_route53 | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | Route53 | @mock_route53 | core endpoints done |
| S3 | @mock_s3 | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | S3 | @mock_s3 | core endpoints done |
| SecretsManager | @mock_secretsmanager | basic endpoints done |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SecretsManager | @mock_secretsmanager | basic endpoints done |
| SES | @mock_ses | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SES | @mock_ses | all endpoints done |
| SNS | @mock_sns | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SNS | @mock_sns | all endpoints done |
| SQS | @mock_sqs | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SQS | @mock_sqs | core endpoints done |
| SSM | @mock_ssm | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SSM | @mock_ssm | core endpoints done |
| STS | @mock_sts | core endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | STS | @mock_sts | core endpoints done |
| SWF | @mock_swf | basic endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | SWF | @mock_swf | basic endpoints done |
| X-Ray | @mock_xray | all endpoints done | |-------------------------------------------------------------------------------------|
|------------------------------------------------------------------------------| | X-Ray | @mock_xray | all endpoints done |
|-------------------------------------------------------------------------------------|
``` ```
For a full list of endpoint [implementation coverage](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) For a full list of endpoint [implementation coverage](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md)
@ -252,6 +256,137 @@ def test_my_model_save():
mock.stop() mock.stop()
``` ```
## IAM-like Access Control
Moto also has the ability to authenticate and authorize actions, just like it's done by IAM in AWS. This functionality can be enabled by either setting the `INITIAL_NO_AUTH_ACTION_COUNT` environment variable or using the `set_initial_no_auth_action_count` decorator. Note that the current implementation is very basic, see [this file](https://github.com/spulec/moto/blob/master/moto/core/access_control.py) for more information.
### `INITIAL_NO_AUTH_ACTION_COUNT`
If this environment variable is set, moto will skip performing any authentication as many times as the variable's value, and only starts authenticating requests afterwards. If it is not set, it defaults to infinity, thus moto will never perform any authentication at all.
### `set_initial_no_auth_action_count`
This is a decorator that works similarly to the environment variable, but the settings are only valid in the function's scope. When the function returns, everything is restored.
```python
@set_initial_no_auth_action_count(4)
@mock_ec2
def test_describe_instances_allowed():
policy_document = {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:Describe*",
"Resource": "*"
}
]
}
access_key = ...
# create access key for an IAM user/assumed role that has the policy above.
# this part should call __exactly__ 4 AWS actions, so that authentication and authorization starts exactly after this
client = boto3.client('ec2', region_name='us-east-1',
aws_access_key_id=access_key['AccessKeyId'],
aws_secret_access_key=access_key['SecretAccessKey'])
# if the IAM principal whose access key is used, does not have the permission to describe instances, this will fail
instances = client.describe_instances()['Reservations'][0]['Instances']
assert len(instances) == 0
```
See [the related test suite](https://github.com/spulec/moto/blob/master/tests/test_core/test_auth.py) for more examples.
## Very Important -- Recommended Usage
There are some important caveats to be aware of when using moto:
*Failure to follow these guidelines could result in your tests mutating your __REAL__ infrastructure!*
### How do I avoid tests from mutating my real infrastructure?
You need to ensure that the mocks are actually in place. Changes made to recent versions of `botocore`
have altered some of the mock behavior. In short, you need to ensure that you _always_ do the following:
1. Ensure that your tests have dummy environment variables set up:
export AWS_ACCESS_KEY_ID='testing'
export AWS_SECRET_ACCESS_KEY='testing'
export AWS_SECURITY_TOKEN='testing'
export AWS_SESSION_TOKEN='testing'
1. __VERY IMPORTANT__: ensure that you have your mocks set up __BEFORE__ your `boto3` client is established.
This can typically happen if you import a module that has a `boto3` client instantiated outside of a function.
See the pesky imports section below on how to work around this.
### Example on usage?
If you are a user of [pytest](https://pytest.org/en/latest/), you can leverage [pytest fixtures](https://pytest.org/en/latest/fixture.html#fixture)
to help set up your mocks and other AWS resources that you would need.
Here is an example:
```python
@pytest.fixture(scope='function')
def aws_credentials():
"""Mocked AWS Credentials for moto."""
os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
os.environ['AWS_SECURITY_TOKEN'] = 'testing'
os.environ['AWS_SESSION_TOKEN'] = 'testing'
@pytest.fixture(scope='function')
def s3(aws_credentials):
with mock_s3():
yield boto3.client('s3', region_name='us-east-1')
@pytest.fixture(scope='function')
def sts(aws_credentials):
with mock_sts():
yield boto3.client('sts', region_name='us-east-1')
@pytest.fixture(scope='function')
def cloudwatch(aws_credentials):
with mock_cloudwatch():
yield boto3.client('cloudwatch', region_name='us-east-1')
... etc.
```
In the code sample above, all of the AWS/mocked fixtures take in a parameter of `aws_credentials`,
which sets the proper fake environment variables. The fake environment variables are used so that `botocore` doesn't try to locate real
credentials on your system.
Next, once you need to do anything with the mocked AWS environment, do something like:
```python
def test_create_bucket(s3):
# s3 is a fixture defined above that yields a boto3 s3 client.
# Feel free to instantiate another boto3 S3 client -- Keep note of the region though.
s3.create_bucket(Bucket="somebucket")
result = s3.list_buckets()
assert len(result['Buckets']) == 1
assert result['Buckets'][0]['Name'] == 'somebucket'
```
### What about those pesky imports?
Recall earlier, it was mentioned that mocks should be established __BEFORE__ the clients are set up. One way
to avoid import issues is to make use of local Python imports -- i.e. import the module inside of the unit
test you want to run vs. importing at the top of the file.
Example:
```python
def test_something(s3):
from some.package.that.does.something.with.s3 import some_func # <-- Local import for unit test
# ^^ Importing here ensures that the mock has been established.
sume_func() # The mock has been established from the "s3" pytest fixture, so this function that uses
# a package-level S3 client will properly use the mock and not reach out to AWS.
```
### Other caveats
For Tox, Travis CI, and other build systems, you might need to also perform a `touch ~/.aws/credentials`
command before running the tests. As long as that file is present (empty preferably) and the environment
variables above are set, you should be good to go.
## Stand-alone Server Mode ## Stand-alone Server Mode
Moto also has a stand-alone server mode. This allows you to utilize Moto also has a stand-alone server mode. This allows you to utilize
@ -318,3 +453,11 @@ boto3.resource(
```console ```console
$ pip install moto $ pip install moto
``` ```
## Releases
Releases are done from travisci. Fairly closely following this:
https://docs.travis-ci.com/user/deployment/pypi/
- Commits to `master` branch do a dev deploy to pypi.
- Commits to a tag do a real deploy to pypi.

View File

@ -17,66 +17,95 @@ with ``moto`` and its usage.
Currently implemented Services: Currently implemented Services:
------------------------------- -------------------------------
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Service Name | Decorator | Development Status | | Service Name | Decorator | Development Status |
+=======================+=====================+===================================+ +===========================+=======================+====================================+
| API Gateway | @mock_apigateway | core endpoints done | | ACM | @mock_acm | all endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Autoscaling | @mock_autoscaling | core endpoints done | | API Gateway | @mock_apigateway | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Cloudformation | @mock_cloudformation| core endpoints done | | Autoscaling | @mock_autoscaling | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Cloudwatch | @mock_cloudwatch | basic endpoints done | | Cloudformation | @mock_cloudformation | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| Data Pipeline | @mock_datapipeline | basic endpoints done | | Cloudwatch | @mock_cloudwatch | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| - DynamoDB | - @mock_dynamodb | - core endpoints done | | CloudwatchEvents | @mock_events | all endpoints done |
| - DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes| +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | Cognito Identity | @mock_cognitoidentity | all endpoints done |
| EC2 | @mock_ec2 | core endpoints done | +---------------------------+-----------------------+------------------------------------+
| - AMI | | - core endpoints done | | Cognito Identity Provider | @mock_cognitoidp | all endpoints done |
| - EBS | | - core endpoints done | +---------------------------+-----------------------+------------------------------------+
| - Instances | | - all endpoints done | | Config | @mock_config | basic endpoints done |
| - Security Groups | | - core endpoints done | +---------------------------+-----------------------+------------------------------------+
| - Tags | | - all endpoints done | | Data Pipeline | @mock_datapipeline | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| ECS | @mock_ecs | basic endpoints done | | DynamoDB | - @mock_dynamodb | - core endpoints done |
+-----------------------+---------------------+-----------------------------------+ | DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes |
| ELB | @mock_elb | core endpoints done | +---------------------------+-----------------------+------------------------------------+
| | @mock_elbv2 | core endpoints done | | EC2 | @mock_ec2 | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ | - AMI | | - core endpoints done |
| EMR | @mock_emr | core endpoints done | | - EBS | | - core endpoints done |
+-----------------------+---------------------+-----------------------------------+ | - Instances | | - all endpoints done |
| Glacier | @mock_glacier | core endpoints done | | - Security Groups | | - core endpoints done |
+-----------------------+---------------------+-----------------------------------+ | - Tags | | - all endpoints done |
| IAM | @mock_iam | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | ECR | @mock_ecr | basic endpoints done |
| Lambda | @mock_lambda | basic endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | ECS | @mock_ecs | basic endpoints done |
| Kinesis | @mock_kinesis | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | ELB | @mock_elb | core endpoints done |
| KMS | @mock_kms | basic endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | ELBv2 | @mock_elbv2 | all endpoints done |
| RDS | @mock_rds | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | EMR | @mock_emr | core endpoints done |
| RDS2 | @mock_rds2 | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | Glacier | @mock_glacier | core endpoints done |
| Redshift | @mock_redshift | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | IAM | @mock_iam | core endpoints done |
| Route53 | @mock_route53 | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | IoT | @mock_iot | core endpoints done |
| S3 | @mock_s3 | core endpoints done | | | @mock_iotdata | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| SES | @mock_ses | core endpoints done | | Kinesis | @mock_kinesis | core endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| SNS | @mock_sns | core endpoints done | | KMS | @mock_kms | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+ +---------------------------+-----------------------+------------------------------------+
| SQS | @mock_sqs | core endpoints done | | Lambda | @mock_lambda | basic endpoints done, |
+-----------------------+---------------------+-----------------------------------+ | | | requires docker |
| STS | @mock_sts | core endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | Logs | @mock_logs | basic endpoints done |
| SWF | @mock_swf | basic endpoints done | +---------------------------+-----------------------+------------------------------------+
+-----------------------+---------------------+-----------------------------------+ | Organizations | @mock_organizations | some core edpoints done |
+---------------------------+-----------------------+------------------------------------+
| Polly | @mock_polly | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| RDS | @mock_rds | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| RDS2 | @mock_rds2 | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Redshift | @mock_redshift | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Route53 | @mock_route53 | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| S3 | @mock_s3 | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SecretsManager | @mock_secretsmanager | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SES | @mock_ses | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SNS | @mock_sns | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SQS | @mock_sqs | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SSM | @mock_ssm | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| STS | @mock_sts | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SWF | @mock_swf | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| X-Ray | @mock_xray | all endpoints done |
+---------------------------+-----------------------+------------------------------------+

View File

@ -3,7 +3,7 @@ import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL) # logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto' __title__ = 'moto'
__version__ = '1.3.8' __version__ = '1.3.14.dev'
from .acm import mock_acm # flake8: noqa from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa

View File

@ -105,7 +105,7 @@ class CertBundle(BaseModel):
self.arn = arn self.arn = arn
@classmethod @classmethod
def generate_cert(cls, domain_name, sans=None): def generate_cert(cls, domain_name, region, sans=None):
if sans is None: if sans is None:
sans = set() sans = set()
else: else:
@ -152,7 +152,7 @@ class CertBundle(BaseModel):
encryption_algorithm=serialization.NoEncryption() encryption_algorithm=serialization.NoEncryption()
) )
return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION') return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION', region=region)
def validate_pk(self): def validate_pk(self):
try: try:
@ -325,7 +325,7 @@ class AWSCertificateManagerBackend(BaseBackend):
return bundle.arn return bundle.arn
def get_certificates_list(self): def get_certificates_list(self, statuses):
""" """
Get list of certificates Get list of certificates
@ -333,7 +333,9 @@ class AWSCertificateManagerBackend(BaseBackend):
:rtype: list of CertBundle :rtype: list of CertBundle
""" """
for arn in self._certificates.keys(): for arn in self._certificates.keys():
yield self.get_certificate(arn) cert = self.get_certificate(arn)
if not statuses or cert.status in statuses:
yield cert
def get_certificate(self, arn): def get_certificate(self, arn):
if arn not in self._certificates: if arn not in self._certificates:
@ -355,7 +357,7 @@ class AWSCertificateManagerBackend(BaseBackend):
if arn is not None: if arn is not None:
return arn return arn
cert = CertBundle.generate_cert(domain_name, subject_alt_names) cert = CertBundle.generate_cert(domain_name, region=self.region, sans=subject_alt_names)
if idempotency_token is not None: if idempotency_token is not None:
self._set_idempotency_token_arn(idempotency_token, cert.arn) self._set_idempotency_token_arn(idempotency_token, cert.arn)
self._certificates[cert.arn] = cert self._certificates[cert.arn] = cert

View File

@ -132,8 +132,8 @@ class AWSCertificateManagerResponse(BaseResponse):
def list_certificates(self): def list_certificates(self):
certs = [] certs = []
statuses = self._get_param('CertificateStatuses')
for cert_bundle in self.acm_backend.get_certificates_list(): for cert_bundle in self.acm_backend.get_certificates_list(statuses):
certs.append({ certs.append({
'CertificateArn': cert_bundle.arn, 'CertificateArn': cert_bundle.arn,
'DomainName': cert_bundle.common_name 'DomainName': cert_bundle.common_name

View File

@ -309,6 +309,25 @@ class ApiKey(BaseModel, dict):
self['createdDate'] = self['lastUpdatedDate'] = int(time.time()) self['createdDate'] = self['lastUpdatedDate'] = int(time.time())
self['stageKeys'] = stageKeys self['stageKeys'] = stageKeys
def update_operations(self, patch_operations):
for op in patch_operations:
if op['op'] == 'replace':
if '/name' in op['path']:
self['name'] = op['value']
elif '/customerId' in op['path']:
self['customerId'] = op['value']
elif '/description' in op['path']:
self['description'] = op['value']
elif '/enabled' in op['path']:
self['enabled'] = self._str2bool(op['value'])
else:
raise Exception(
'Patch operation "%s" not implemented' % op['op'])
return self
def _str2bool(self, v):
return v.lower() == "true"
class UsagePlan(BaseModel, dict): class UsagePlan(BaseModel, dict):
@ -599,6 +618,10 @@ class APIGatewayBackend(BaseBackend):
def get_apikey(self, api_key_id): def get_apikey(self, api_key_id):
return self.keys[api_key_id] return self.keys[api_key_id]
def update_apikey(self, api_key_id, patch_operations):
key = self.keys[api_key_id]
return key.update_operations(patch_operations)
def delete_apikey(self, api_key_id): def delete_apikey(self, api_key_id):
self.keys.pop(api_key_id) self.keys.pop(api_key_id)
return {} return {}

View File

@ -245,6 +245,9 @@ class APIGatewayResponse(BaseResponse):
if self.method == 'GET': if self.method == 'GET':
apikey_response = self.backend.get_apikey(apikey) apikey_response = self.backend.get_apikey(apikey)
elif self.method == 'PATCH':
patch_operations = self._get_param('patchOperations')
apikey_response = self.backend.update_apikey(apikey, patch_operations)
elif self.method == 'DELETE': elif self.method == 'DELETE':
apikey_response = self.backend.delete_apikey(apikey) apikey_response = self.backend.delete_apikey(apikey)
return 200, {}, json.dumps(apikey_response) return 200, {}, json.dumps(apikey_response)

View File

@ -1,9 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import six import six
import random import random
import string
def create_id(): def create_id():
size = 10 size = 10
chars = list(range(10)) + ['A-Z'] chars = list(range(10)) + list(string.ascii_lowercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size)) return ''.join(six.text_type(random.choice(chars)) for x in range(size))

View File

@ -13,3 +13,12 @@ class ResourceContentionError(RESTError):
super(ResourceContentionError, self).__init__( super(ResourceContentionError, self).__init__(
"ResourceContentionError", "ResourceContentionError",
"You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).") "You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).")
class InvalidInstanceError(AutoscalingClientError):
def __init__(self, instance_id):
super(InvalidInstanceError, self).__init__(
"ValidationError",
"Instance [{0}] is invalid."
.format(instance_id))

View File

@ -3,6 +3,8 @@ from __future__ import unicode_literals
import random import random
from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping
from moto.ec2.exceptions import InvalidInstanceIdError
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.ec2 import ec2_backends from moto.ec2 import ec2_backends
@ -10,7 +12,7 @@ from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends from moto.elbv2 import elbv2_backends
from moto.elb.exceptions import LoadBalancerNotFoundError from moto.elb.exceptions import LoadBalancerNotFoundError
from .exceptions import ( from .exceptions import (
AutoscalingClientError, ResourceContentionError, AutoscalingClientError, ResourceContentionError, InvalidInstanceError
) )
# http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown # http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown
@ -73,6 +75,26 @@ class FakeLaunchConfiguration(BaseModel):
self.associate_public_ip_address = associate_public_ip_address self.associate_public_ip_address = associate_public_ip_address
self.block_device_mapping_dict = block_device_mapping_dict self.block_device_mapping_dict = block_device_mapping_dict
@classmethod
def create_from_instance(cls, name, instance, backend):
config = backend.create_launch_configuration(
name=name,
image_id=instance.image_id,
kernel_id='',
ramdisk_id='',
key_name=instance.key_name,
security_groups=instance.security_groups,
user_data=instance.user_data,
instance_type=instance.instance_type,
instance_monitoring=False,
instance_profile_name=None,
spot_price=None,
ebs_optimized=instance.ebs_optimized,
associate_public_ip_address=instance.associate_public_ip,
block_device_mappings=instance.block_device_mapping
)
return config
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties'] properties = cloudformation_json['Properties']
@ -279,6 +301,12 @@ class FakeAutoScalingGroup(BaseModel):
if min_size is not None: if min_size is not None:
self.min_size = min_size self.min_size = min_size
if desired_capacity is None:
if min_size is not None and min_size > len(self.instance_states):
desired_capacity = min_size
if max_size is not None and max_size < len(self.instance_states):
desired_capacity = max_size
if launch_config_name: if launch_config_name:
self.launch_config = self.autoscaling_backend.launch_configurations[ self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name] launch_config_name]
@ -414,7 +442,8 @@ class AutoScalingBackend(BaseBackend):
health_check_type, load_balancers, health_check_type, load_balancers,
target_group_arns, placement_group, target_group_arns, placement_group,
termination_policies, tags, termination_policies, tags,
new_instances_protected_from_scale_in=False): new_instances_protected_from_scale_in=False,
instance_id=None):
def make_int(value): def make_int(value):
return int(value) if value is not None else value return int(value) if value is not None else value
@ -427,6 +456,13 @@ class AutoScalingBackend(BaseBackend):
health_check_period = 300 health_check_period = 300
else: else:
health_check_period = make_int(health_check_period) health_check_period = make_int(health_check_period)
if launch_config_name is None and instance_id is not None:
try:
instance = self.ec2_backend.get_instance(instance_id)
launch_config_name = name
FakeLaunchConfiguration.create_from_instance(launch_config_name, instance, self)
except InvalidInstanceIdError:
raise InvalidInstanceError(instance_id)
group = FakeAutoScalingGroup( group = FakeAutoScalingGroup(
name=name, name=name,
@ -684,6 +720,18 @@ class AutoScalingBackend(BaseBackend):
for instance in protected_instances: for instance in protected_instances:
instance.protected_from_scale_in = protected_from_scale_in instance.protected_from_scale_in = protected_from_scale_in
def notify_terminate_instances(self, instance_ids):
for autoscaling_group_name, autoscaling_group in self.autoscaling_groups.items():
original_instance_count = len(autoscaling_group.instance_states)
autoscaling_group.instance_states = list(filter(
lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states
))
difference = original_instance_count - len(autoscaling_group.instance_states)
if difference > 0:
autoscaling_group.replace_autoscaling_group_instances(difference, autoscaling_group.get_propagated_tags())
self.update_attached_elbs(autoscaling_group_name)
autoscaling_backends = {} autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items(): for region, ec2_backend in ec2_backends.items():

View File

@ -48,7 +48,7 @@ class AutoScalingResponse(BaseResponse):
start = all_names.index(marker) + 1 start = all_names.index(marker) + 1
else: else:
start = 0 start = 0
max_records = self._get_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier max_records = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
launch_configurations_resp = all_launch_configurations[start:start + max_records] launch_configurations_resp = all_launch_configurations[start:start + max_records]
next_token = None next_token = None
if len(all_launch_configurations) > start + max_records: if len(all_launch_configurations) > start + max_records:
@ -74,6 +74,7 @@ class AutoScalingResponse(BaseResponse):
desired_capacity=self._get_int_param('DesiredCapacity'), desired_capacity=self._get_int_param('DesiredCapacity'),
max_size=self._get_int_param('MaxSize'), max_size=self._get_int_param('MaxSize'),
min_size=self._get_int_param('MinSize'), min_size=self._get_int_param('MinSize'),
instance_id=self._get_param('InstanceId'),
launch_config_name=self._get_param('LaunchConfigurationName'), launch_config_name=self._get_param('LaunchConfigurationName'),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'), vpc_zone_identifier=self._get_param('VPCZoneIdentifier'),
default_cooldown=self._get_int_param('DefaultCooldown'), default_cooldown=self._get_int_param('DefaultCooldown'),

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import base64 import base64
import time
from collections import defaultdict from collections import defaultdict
import copy import copy
import datetime import datetime
@ -31,6 +32,7 @@ from moto.logs.models import logs_backends
from moto.s3.exceptions import MissingBucket, MissingKey from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings from moto import settings
from .utils import make_function_arn, make_function_ver_arn from .utils import make_function_arn, make_function_ver_arn
from moto.sqs import sqs_backends
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -231,6 +233,10 @@ class LambdaFunction(BaseModel):
config.update({"VpcId": "vpc-123abc"}) config.update({"VpcId": "vpc-123abc"})
return config return config
@property
def physical_resource_id(self):
return self.function_name
def __repr__(self): def __repr__(self):
return json.dumps(self.get_configuration()) return json.dumps(self.get_configuration())
@ -425,24 +431,59 @@ class LambdaFunction(BaseModel):
class EventSourceMapping(BaseModel): class EventSourceMapping(BaseModel):
def __init__(self, spec): def __init__(self, spec):
# required # required
self.function_name = spec['FunctionName'] self.function_arn = spec['FunctionArn']
self.event_source_arn = spec['EventSourceArn'] self.event_source_arn = spec['EventSourceArn']
self.starting_position = spec['StartingPosition'] self.uuid = str(uuid.uuid4())
self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple())
# BatchSize service default/max mapping
batch_size_map = {
'kinesis': (100, 10000),
'dynamodb': (100, 1000),
'sqs': (10, 10),
}
source_type = self.event_source_arn.split(":")[2].lower()
batch_size_entry = batch_size_map.get(source_type)
if batch_size_entry:
# Use service default if not provided
batch_size = int(spec.get('BatchSize', batch_size_entry[0]))
if batch_size > batch_size_entry[1]:
raise ValueError("InvalidParameterValueException",
"BatchSize {} exceeds the max of {}".format(batch_size, batch_size_entry[1]))
else:
self.batch_size = batch_size
else:
raise ValueError("InvalidParameterValueException",
"Unsupported event source type")
# optional # optional
self.batch_size = spec.get('BatchSize', 100) self.starting_position = spec.get('StartingPosition', 'TRIM_HORIZON')
self.enabled = spec.get('Enabled', True) self.enabled = spec.get('Enabled', True)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp', self.starting_position_timestamp = spec.get('StartingPositionTimestamp',
None) None)
def get_configuration(self):
return {
'UUID': self.uuid,
'BatchSize': self.batch_size,
'EventSourceArn': self.event_source_arn,
'FunctionArn': self.function_arn,
'LastModified': self.last_modified,
'LastProcessingResult': '',
'State': 'Enabled' if self.enabled else 'Disabled',
'StateTransitionReason': 'User initiated'
}
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name): region_name):
properties = cloudformation_json['Properties'] properties = cloudformation_json['Properties']
func = lambda_backends[region_name].get_function(properties['FunctionName'])
spec = { spec = {
'FunctionName': properties['FunctionName'], 'FunctionArn': func.function_arn,
'EventSourceArn': properties['EventSourceArn'], 'EventSourceArn': properties['EventSourceArn'],
'StartingPosition': properties['StartingPosition'] 'StartingPosition': properties['StartingPosition'],
'BatchSize': properties.get('BatchSize', 100)
} }
optional_properties = 'BatchSize Enabled StartingPositionTimestamp'.split() optional_properties = 'BatchSize Enabled StartingPositionTimestamp'.split()
for prop in optional_properties: for prop in optional_properties:
@ -462,8 +503,10 @@ class LambdaVersion(BaseModel):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name): region_name):
properties = cloudformation_json['Properties'] properties = cloudformation_json['Properties']
function_name = properties['FunctionName']
func = lambda_backends[region_name].publish_function(function_name)
spec = { spec = {
'Version': properties.get('Version') 'Version': func.version
} }
return LambdaVersion(spec) return LambdaVersion(spec)
@ -511,6 +554,9 @@ class LambdaStorage(object):
def get_arn(self, arn): def get_arn(self, arn):
return self._arns.get(arn, None) return self._arns.get(arn, None)
def get_function_by_name_or_arn(self, input):
return self.get_function(input) or self.get_arn(input)
def put_function(self, fn): def put_function(self, fn):
""" """
:param fn: Function :param fn: Function
@ -592,6 +638,7 @@ class LambdaStorage(object):
class LambdaBackend(BaseBackend): class LambdaBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name):
self._lambdas = LambdaStorage() self._lambdas = LambdaStorage()
self._event_source_mappings = {}
self.region_name = region_name self.region_name = region_name
def reset(self): def reset(self):
@ -613,6 +660,40 @@ class LambdaBackend(BaseBackend):
fn.version = ver.version fn.version = ver.version
return fn return fn
def create_event_source_mapping(self, spec):
required = [
'EventSourceArn',
'FunctionName',
]
for param in required:
if not spec.get(param):
raise RESTError('InvalidParameterValueException', 'Missing {}'.format(param))
# Validate function name
func = self._lambdas.get_function_by_name_or_arn(spec.pop('FunctionName', ''))
if not func:
raise RESTError('ResourceNotFoundException', 'Invalid FunctionName')
# Validate queue
for queue in sqs_backends[self.region_name].queues.values():
if queue.queue_arn == spec['EventSourceArn']:
if queue.lambda_event_source_mappings.get('func.function_arn'):
# TODO: Correct exception?
raise RESTError('ResourceConflictException', 'The resource already exists.')
if queue.fifo_queue:
raise RESTError('InvalidParameterValueException',
'{} is FIFO'.format(queue.queue_arn))
else:
spec.update({'FunctionArn': func.function_arn})
esm = EventSourceMapping(spec)
self._event_source_mappings[esm.uuid] = esm
# Set backend function on queue
queue.lambda_event_source_mappings[esm.function_arn] = esm
return esm
raise RESTError('ResourceNotFoundException', 'Invalid EventSourceArn')
def publish_function(self, function_name): def publish_function(self, function_name):
return self._lambdas.publish_function(function_name) return self._lambdas.publish_function(function_name)
@ -622,6 +703,33 @@ class LambdaBackend(BaseBackend):
def list_versions_by_function(self, function_name): def list_versions_by_function(self, function_name):
return self._lambdas.list_versions_by_function(function_name) return self._lambdas.list_versions_by_function(function_name)
def get_event_source_mapping(self, uuid):
return self._event_source_mappings.get(uuid)
def delete_event_source_mapping(self, uuid):
return self._event_source_mappings.pop(uuid)
def update_event_source_mapping(self, uuid, spec):
esm = self.get_event_source_mapping(uuid)
if esm:
if spec.get('FunctionName'):
func = self._lambdas.get_function_by_name_or_arn(spec.get('FunctionName'))
esm.function_arn = func.function_arn
if 'BatchSize' in spec:
esm.batch_size = spec['BatchSize']
if 'Enabled' in spec:
esm.enabled = spec['Enabled']
return esm
return False
def list_event_source_mappings(self, event_source_arn, function_name):
esms = list(self._event_source_mappings.values())
if event_source_arn:
esms = list(filter(lambda x: x.event_source_arn == event_source_arn, esms))
if function_name:
esms = list(filter(lambda x: x.function_name == function_name, esms))
return esms
def get_function_by_arn(self, function_arn): def get_function_by_arn(self, function_arn):
return self._lambdas.get_arn(function_arn) return self._lambdas.get_arn(function_arn)
@ -631,7 +739,43 @@ class LambdaBackend(BaseBackend):
def list_functions(self): def list_functions(self):
return self._lambdas.all() return self._lambdas.all()
def send_message(self, function_name, message, subject=None, qualifier=None): def send_sqs_batch(self, function_arn, messages, queue_arn):
success = True
for message in messages:
func = self.get_function_by_arn(function_arn)
result = self._send_sqs_message(func, message, queue_arn)
if not result:
success = False
return success
def _send_sqs_message(self, func, message, queue_arn):
event = {
"Records": [
{
"messageId": message.id,
"receiptHandle": message.receipt_handle,
"body": message.body,
"attributes": {
"ApproximateReceiveCount": "1",
"SentTimestamp": "1545082649183",
"SenderId": "AIDAIENQZJOLO23YVJ4VO",
"ApproximateFirstReceiveTimestamp": "1545082649185"
},
"messageAttributes": {},
"md5OfBody": "098f6bcd4621d373cade4e832627b4f6",
"eventSource": "aws:sqs",
"eventSourceARN": queue_arn,
"awsRegion": self.region_name
}
]
}
request_headers = {}
response_headers = {}
func.invoke(json.dumps(event), request_headers, response_headers)
return 'x-amz-function-error' not in response_headers
def send_sns_message(self, function_name, message, subject=None, qualifier=None):
event = { event = {
"Records": [ "Records": [
{ {

View File

@ -39,6 +39,31 @@ class LambdaResponse(BaseResponse):
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def event_source_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
querystring = self.querystring
event_source_arn = querystring.get('EventSourceArn', [None])[0]
function_name = querystring.get('FunctionName', [None])[0]
return self._list_event_source_mappings(event_source_arn, function_name)
elif request.method == 'POST':
return self._create_event_source_mapping(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
def event_source_mapping(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
path = request.path if hasattr(request, 'path') else path_url(request.url)
uuid = path.split('/')[-1]
if request.method == 'GET':
return self._get_event_source_mapping(uuid)
elif request.method == 'PUT':
return self._update_event_source_mapping(uuid)
elif request.method == 'DELETE':
return self._delete_event_source_mapping(uuid)
else:
raise ValueError("Cannot handle request")
def function(self, request, full_url, headers): def function(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == 'GET':
@ -177,6 +202,45 @@ class LambdaResponse(BaseResponse):
config = fn.get_configuration() config = fn.get_configuration()
return 201, {}, json.dumps(config) return 201, {}, json.dumps(config)
def _create_event_source_mapping(self, request, full_url, headers):
try:
fn = self.lambda_backend.create_event_source_mapping(self.json_body)
except ValueError as e:
return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}})
else:
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _list_event_source_mappings(self, event_source_arn, function_name):
esms = self.lambda_backend.list_event_source_mappings(event_source_arn, function_name)
result = {
'EventSourceMappings': [esm.get_configuration() for esm in esms]
}
return 200, {}, json.dumps(result)
def _get_event_source_mapping(self, uuid):
result = self.lambda_backend.get_event_source_mapping(uuid)
if result:
return 200, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"
def _update_event_source_mapping(self, uuid):
result = self.lambda_backend.update_event_source_mapping(uuid, self.json_body)
if result:
return 202, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"
def _delete_event_source_mapping(self, uuid):
esm = self.lambda_backend.delete_event_source_mapping(uuid)
if esm:
json_result = esm.get_configuration()
json_result.update({'State': 'Deleting'})
return 202, {}, json.dumps(json_result)
else:
return 404, {}, "{}"
def _publish_function(self, request, full_url, headers): def _publish_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 2)[-2] function_name = self.path.rsplit('/', 2)[-2]

View File

@ -11,6 +11,8 @@ url_paths = {
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root, '{0}/(?P<api_version>[^/]+)/functions/?$': response.root,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function, r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$': response.versions, r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$': response.versions,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/?$': response.event_source_mappings,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$': response.event_source_mapping,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke, r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async, r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag, r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag,

View File

@ -514,10 +514,13 @@ class BatchBackend(BaseBackend):
return self._job_definitions.get(arn) return self._job_definitions.get(arn)
def get_job_definition_by_name(self, name): def get_job_definition_by_name(self, name):
for comp_env in self._job_definitions.values(): latest_revision = -1
if comp_env.name == name: latest_job = None
return comp_env for job_def in self._job_definitions.values():
return None if job_def.name == name and job_def.revision > latest_revision:
latest_job = job_def
latest_revision = job_def.revision
return latest_job
def get_job_definition_by_name_revision(self, name, revision): def get_job_definition_by_name_revision(self, name, revision):
for job_def in self._job_definitions.values(): for job_def in self._job_definitions.values():
@ -534,10 +537,13 @@ class BatchBackend(BaseBackend):
:return: Job definition or None :return: Job definition or None
:rtype: JobDefinition or None :rtype: JobDefinition or None
""" """
env = self.get_job_definition_by_arn(identifier) job_def = self.get_job_definition_by_arn(identifier)
if env is None: if job_def is None:
env = self.get_job_definition_by_name(identifier) if ':' in identifier:
return env job_def = self.get_job_definition_by_name_revision(*identifier.split(':', 1))
else:
job_def = self.get_job_definition_by_name(identifier)
return job_def
def get_job_definitions(self, identifier): def get_job_definitions(self, identifier):
""" """
@ -984,9 +990,7 @@ class BatchBackend(BaseBackend):
# TODO parameters, retries (which is a dict raw from request), job dependancies and container overrides are ignored for now # TODO parameters, retries (which is a dict raw from request), job dependancies and container overrides are ignored for now
# Look for job definition # Look for job definition
job_def = self.get_job_definition_by_arn(job_def_id) job_def = self.get_job_definition(job_def_id)
if job_def is None and ':' in job_def_id:
job_def = self.get_job_definition_by_name_revision(*job_def_id.split(':', 1))
if job_def is None: if job_def is None:
raise ClientException('Job definition {0} does not exist'.format(job_def_id)) raise ClientException('Job definition {0} does not exist'.format(job_def_id))

View File

@ -246,7 +246,8 @@ def resource_name_property_from_type(resource_type):
def generate_resource_name(resource_type, stack_name, logical_id): def generate_resource_name(resource_type, stack_name, logical_id):
if resource_type == "AWS::ElasticLoadBalancingV2::TargetGroup": if resource_type in ["AWS::ElasticLoadBalancingV2::TargetGroup",
"AWS::ElasticLoadBalancingV2::LoadBalancer"]:
# Target group names need to be less than 32 characters, so when cloudformation creates a name for you # Target group names need to be less than 32 characters, so when cloudformation creates a name for you
# it makes sure to stay under that limit # it makes sure to stay under that limit
name_prefix = '{0}-{1}'.format(stack_name, logical_id) name_prefix = '{0}-{1}'.format(stack_name, logical_id)

View File

@ -4,6 +4,7 @@ import six
import random import random
import yaml import yaml
import os import os
import string
from cfnlint import decode, core from cfnlint import decode, core
@ -29,7 +30,7 @@ def generate_stackset_arn(stackset_id, region_name):
def random_suffix(): def random_suffix():
size = 12 size = 12
chars = list(range(10)) + ['A-Z'] chars = list(range(10)) + list(string.ascii_uppercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size)) return ''.join(six.text_type(random.choice(chars)) for x in range(size))

View File

@ -275,7 +275,7 @@ GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://m
<Label>{{ label }}</Label> <Label>{{ label }}</Label>
<Datapoints> <Datapoints>
{% for datapoint in datapoints %} {% for datapoint in datapoints %}
<Datapoint> <member>
{% if datapoint.sum is not none %} {% if datapoint.sum is not none %}
<Sum>{{ datapoint.sum }}</Sum> <Sum>{{ datapoint.sum }}</Sum>
{% endif %} {% endif %}
@ -302,7 +302,7 @@ GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://m
<Timestamp>{{ datapoint.timestamp }}</Timestamp> <Timestamp>{{ datapoint.timestamp }}</Timestamp>
<Unit>{{ datapoint.unit }}</Unit> <Unit>{{ datapoint.unit }}</Unit>
</Datapoint> </member>
{% endfor %} {% endfor %}
</Datapoints> </Datapoints>
</GetMetricStatisticsResult> </GetMetricStatisticsResult>

View File

@ -95,6 +95,15 @@ class CognitoIdentityBackend(BaseBackend):
}) })
return response return response
def get_open_id_token(self, identity_id):
response = json.dumps(
{
"IdentityId": identity_id,
"Token": get_random_identity_id(self.region)
}
)
return response
cognitoidentity_backends = {} cognitoidentity_backends = {}
for region in boto.cognito.identity.regions(): for region in boto.cognito.identity.regions():

View File

@ -35,3 +35,8 @@ class CognitoIdentityResponse(BaseResponse):
return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity( return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity(
self._get_param('IdentityId') or get_random_identity_id(self.region) self._get_param('IdentityId') or get_random_identity_id(self.region)
) )
def get_open_id_token(self):
return cognitoidentity_backends[self.region].get_open_id_token(
self._get_param("IdentityId") or get_random_identity_id(self.region)
)

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import datetime import datetime
import functools import functools
import hashlib
import itertools import itertools
import json import json
import os import os
@ -154,20 +155,37 @@ class CognitoIdpUserPool(BaseModel):
class CognitoIdpUserPoolDomain(BaseModel): class CognitoIdpUserPoolDomain(BaseModel):
def __init__(self, user_pool_id, domain): def __init__(self, user_pool_id, domain, custom_domain_config=None):
self.user_pool_id = user_pool_id self.user_pool_id = user_pool_id
self.domain = domain self.domain = domain
self.custom_domain_config = custom_domain_config or {}
def to_json(self): def _distribution_name(self):
return { if self.custom_domain_config and \
"UserPoolId": self.user_pool_id, 'CertificateArn' in self.custom_domain_config:
"AWSAccountId": str(uuid.uuid4()), hash = hashlib.md5(
"CloudFrontDistribution": None, self.custom_domain_config['CertificateArn'].encode('utf-8')
"Domain": self.domain, ).hexdigest()
"S3Bucket": None, return "{hash}.cloudfront.net".format(hash=hash[:16])
"Status": "ACTIVE", return None
"Version": None,
} def to_json(self, extended=True):
distribution = self._distribution_name()
if extended:
return {
"UserPoolId": self.user_pool_id,
"AWSAccountId": str(uuid.uuid4()),
"CloudFrontDistribution": distribution,
"Domain": self.domain,
"S3Bucket": None,
"Status": "ACTIVE",
"Version": None,
}
elif distribution:
return {
"CloudFrontDomain": distribution,
}
return None
class CognitoIdpUserPoolClient(BaseModel): class CognitoIdpUserPoolClient(BaseModel):
@ -338,11 +356,13 @@ class CognitoIdpBackend(BaseBackend):
del self.user_pools[user_pool_id] del self.user_pools[user_pool_id]
# User pool domain # User pool domain
def create_user_pool_domain(self, user_pool_id, domain): def create_user_pool_domain(self, user_pool_id, domain, custom_domain_config=None):
if user_pool_id not in self.user_pools: if user_pool_id not in self.user_pools:
raise ResourceNotFoundError(user_pool_id) raise ResourceNotFoundError(user_pool_id)
user_pool_domain = CognitoIdpUserPoolDomain(user_pool_id, domain) user_pool_domain = CognitoIdpUserPoolDomain(
user_pool_id, domain, custom_domain_config=custom_domain_config
)
self.user_pool_domains[domain] = user_pool_domain self.user_pool_domains[domain] = user_pool_domain
return user_pool_domain return user_pool_domain
@ -358,6 +378,14 @@ class CognitoIdpBackend(BaseBackend):
del self.user_pool_domains[domain] del self.user_pool_domains[domain]
def update_user_pool_domain(self, domain, custom_domain_config):
if domain not in self.user_pool_domains:
raise ResourceNotFoundError(domain)
user_pool_domain = self.user_pool_domains[domain]
user_pool_domain.custom_domain_config = custom_domain_config
return user_pool_domain
# User pool client # User pool client
def create_user_pool_client(self, user_pool_id, extended_config): def create_user_pool_client(self, user_pool_id, extended_config):
user_pool = self.user_pools.get(user_pool_id) user_pool = self.user_pools.get(user_pool_id)

View File

@ -50,7 +50,13 @@ class CognitoIdpResponse(BaseResponse):
def create_user_pool_domain(self): def create_user_pool_domain(self):
domain = self._get_param("Domain") domain = self._get_param("Domain")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
cognitoidp_backends[self.region].create_user_pool_domain(user_pool_id, domain) custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].create_user_pool_domain(
user_pool_id, domain, custom_domain_config
)
domain_description = user_pool_domain.to_json(extended=False)
if domain_description:
return json.dumps(domain_description)
return "" return ""
def describe_user_pool_domain(self): def describe_user_pool_domain(self):
@ -69,6 +75,17 @@ class CognitoIdpResponse(BaseResponse):
cognitoidp_backends[self.region].delete_user_pool_domain(domain) cognitoidp_backends[self.region].delete_user_pool_domain(domain)
return "" return ""
def update_user_pool_domain(self):
domain = self._get_param("Domain")
custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].update_user_pool_domain(
domain, custom_domain_config
)
domain_description = user_pool_domain.to_json(extended=False)
if domain_description:
return json.dumps(domain_description)
return ""
# User pool client # User pool client
def create_user_pool_client(self): def create_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId") user_pool_id = self.parameters.pop("UserPoolId")

View File

@ -52,6 +52,18 @@ class InvalidResourceTypeException(JsonRESTError):
super(InvalidResourceTypeException, self).__init__("ValidationException", message) super(InvalidResourceTypeException, self).__init__("ValidationException", message)
class NoSuchConfigurationAggregatorException(JsonRESTError):
code = 400
def __init__(self, number=1):
if number == 1:
message = 'The configuration aggregator does not exist. Check the configuration aggregator name and try again.'
else:
message = 'At least one of the configuration aggregators does not exist. Check the configuration aggregator' \
' names and try again.'
super(NoSuchConfigurationAggregatorException, self).__init__("NoSuchConfigurationAggregatorException", message)
class NoSuchConfigurationRecorderException(JsonRESTError): class NoSuchConfigurationRecorderException(JsonRESTError):
code = 400 code = 400
@ -78,6 +90,14 @@ class NoSuchBucketException(JsonRESTError):
super(NoSuchBucketException, self).__init__("NoSuchBucketException", message) super(NoSuchBucketException, self).__init__("NoSuchBucketException", message)
class InvalidNextTokenException(JsonRESTError):
code = 400
def __init__(self):
message = 'The nextToken provided is invalid'
super(InvalidNextTokenException, self).__init__("InvalidNextTokenException", message)
class InvalidS3KeyPrefixException(JsonRESTError): class InvalidS3KeyPrefixException(JsonRESTError):
code = 400 code = 400
@ -147,3 +167,66 @@ class LastDeliveryChannelDeleteFailedException(JsonRESTError):
message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \ message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \
'because there is a running configuration recorder.'.format(name=name) 'because there is a running configuration recorder.'.format(name=name)
super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message) super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message)
class TooManyAccountSources(JsonRESTError):
code = 400
def __init__(self, length):
locations = ['com.amazonaws.xyz'] * length
message = 'Value \'[{locations}]\' at \'accountAggregationSources\' failed to satisfy constraint: ' \
'Member must have length less than or equal to 1'.format(locations=', '.join(locations))
super(TooManyAccountSources, self).__init__("ValidationException", message)
class DuplicateTags(JsonRESTError):
code = 400
def __init__(self):
super(DuplicateTags, self).__init__(
'InvalidInput', 'Duplicate tag keys found. Please note that Tag keys are case insensitive.')
class TagKeyTooBig(JsonRESTError):
code = 400
def __init__(self, tag, param='tags.X.member.key'):
super(TagKeyTooBig, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 128".format(tag, param))
class TagValueTooBig(JsonRESTError):
code = 400
def __init__(self, tag):
super(TagValueTooBig, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy "
"constraint: Member must have length less than or equal to 256".format(tag))
class InvalidParameterValueException(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidParameterValueException, self).__init__('InvalidParameterValueException', message)
class InvalidTagCharacters(JsonRESTError):
code = 400
def __init__(self, tag, param='tags.X.member.key'):
message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(tag, param)
message += 'constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+'
super(InvalidTagCharacters, self).__init__('ValidationException', message)
class TooManyTags(JsonRESTError):
code = 400
def __init__(self, tags, param='tags'):
super(TooManyTags, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 50.".format(tags, param))

View File

@ -1,6 +1,9 @@
import json import json
import re
import time import time
import pkg_resources import pkg_resources
import random
import string
from datetime import datetime from datetime import datetime
@ -12,37 +15,125 @@ from moto.config.exceptions import InvalidResourceTypeException, InvalidDelivery
NoSuchConfigurationRecorderException, NoAvailableConfigurationRecorderException, \ NoSuchConfigurationRecorderException, NoAvailableConfigurationRecorderException, \
InvalidDeliveryChannelNameException, NoSuchBucketException, InvalidS3KeyPrefixException, \ InvalidDeliveryChannelNameException, NoSuchBucketException, InvalidS3KeyPrefixException, \
InvalidSNSTopicARNException, MaxNumberOfDeliveryChannelsExceededException, NoAvailableDeliveryChannelException, \ InvalidSNSTopicARNException, MaxNumberOfDeliveryChannelsExceededException, NoAvailableDeliveryChannelException, \
NoSuchDeliveryChannelException, LastDeliveryChannelDeleteFailedException NoSuchDeliveryChannelException, LastDeliveryChannelDeleteFailedException, TagKeyTooBig, \
TooManyTags, TagValueTooBig, TooManyAccountSources, InvalidParameterValueException, InvalidNextTokenException, \
NoSuchConfigurationAggregatorException, InvalidTagCharacters, DuplicateTags
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
DEFAULT_ACCOUNT_ID = 123456789012 DEFAULT_ACCOUNT_ID = 123456789012
POP_STRINGS = [
'capitalizeStart',
'CapitalizeStart',
'capitalizeArn',
'CapitalizeArn',
'capitalizeARN',
'CapitalizeARN'
]
DEFAULT_PAGE_SIZE = 100
def datetime2int(date): def datetime2int(date):
return int(time.mktime(date.timetuple())) return int(time.mktime(date.timetuple()))
def snake_to_camels(original): def snake_to_camels(original, cap_start, cap_arn):
parts = original.split('_') parts = original.split('_')
camel_cased = parts[0].lower() + ''.join(p.title() for p in parts[1:]) camel_cased = parts[0].lower() + ''.join(p.title() for p in parts[1:])
camel_cased = camel_cased.replace('Arn', 'ARN') # Config uses 'ARN' instead of 'Arn'
if cap_arn:
camel_cased = camel_cased.replace('Arn', 'ARN') # Some config services use 'ARN' instead of 'Arn'
if cap_start:
camel_cased = camel_cased[0].upper() + camel_cased[1::]
return camel_cased return camel_cased
def random_string():
"""Returns a random set of 8 lowercase letters for the Config Aggregator ARN"""
chars = []
for x in range(0, 8):
chars.append(random.choice(string.ascii_lowercase))
return "".join(chars)
def validate_tag_key(tag_key, exception_param='tags.X.member.key'):
"""Validates the tag key.
:param tag_key: The tag key to check against.
:param exception_param: The exception parameter to send over to help format the message. This is to reflect
the difference between the tag and untag APIs.
:return:
"""
# Validate that the key length is correct:
if len(tag_key) > 128:
raise TagKeyTooBig(tag_key, param=exception_param)
# Validate that the tag key fits the proper Regex:
# [\w\s_.:/=+\-@]+ SHOULD be the same as the Java regex on the AWS documentation: [\p{L}\p{Z}\p{N}_.:/=+\-@]+
match = re.findall(r'[\w\s_.:/=+\-@]+', tag_key)
# Kudos if you can come up with a better way of doing a global search :)
if not len(match) or len(match[0]) < len(tag_key):
raise InvalidTagCharacters(tag_key, param=exception_param)
def check_tag_duplicate(all_tags, tag_key):
"""Validates that a tag key is not a duplicate
:param all_tags: Dict to check if there is a duplicate tag.
:param tag_key: The tag key to check against.
:return:
"""
if all_tags.get(tag_key):
raise DuplicateTags()
def validate_tags(tags):
proper_tags = {}
if len(tags) > 50:
raise TooManyTags(tags)
for tag in tags:
# Validate the Key:
validate_tag_key(tag['Key'])
check_tag_duplicate(proper_tags, tag['Key'])
# Validate the Value:
if len(tag['Value']) > 256:
raise TagValueTooBig(tag['Value'])
proper_tags[tag['Key']] = tag['Value']
return proper_tags
class ConfigEmptyDictable(BaseModel): class ConfigEmptyDictable(BaseModel):
"""Base class to make serialization easy. This assumes that the sub-class will NOT return 'None's in the JSON.""" """Base class to make serialization easy. This assumes that the sub-class will NOT return 'None's in the JSON."""
def __init__(self, capitalize_start=False, capitalize_arn=True):
"""Assists with the serialization of the config object
:param capitalize_start: For some Config services, the first letter is lowercase -- for others it's capital
:param capitalize_arn: For some Config services, the API expects 'ARN' and for others, it expects 'Arn'
"""
self.capitalize_start = capitalize_start
self.capitalize_arn = capitalize_arn
def to_dict(self): def to_dict(self):
data = {} data = {}
for item, value in self.__dict__.items(): for item, value in self.__dict__.items():
if value is not None: if value is not None:
if isinstance(value, ConfigEmptyDictable): if isinstance(value, ConfigEmptyDictable):
data[snake_to_camels(item)] = value.to_dict() data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value.to_dict()
else: else:
data[snake_to_camels(item)] = value data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value
# Cleanse the extra properties:
for prop in POP_STRINGS:
data.pop(prop, None)
return data return data
@ -50,8 +141,9 @@ class ConfigEmptyDictable(BaseModel):
class ConfigRecorderStatus(ConfigEmptyDictable): class ConfigRecorderStatus(ConfigEmptyDictable):
def __init__(self, name): def __init__(self, name):
self.name = name super(ConfigRecorderStatus, self).__init__()
self.name = name
self.recording = False self.recording = False
self.last_start_time = None self.last_start_time = None
self.last_stop_time = None self.last_stop_time = None
@ -75,12 +167,16 @@ class ConfigRecorderStatus(ConfigEmptyDictable):
class ConfigDeliverySnapshotProperties(ConfigEmptyDictable): class ConfigDeliverySnapshotProperties(ConfigEmptyDictable):
def __init__(self, delivery_frequency): def __init__(self, delivery_frequency):
super(ConfigDeliverySnapshotProperties, self).__init__()
self.delivery_frequency = delivery_frequency self.delivery_frequency = delivery_frequency
class ConfigDeliveryChannel(ConfigEmptyDictable): class ConfigDeliveryChannel(ConfigEmptyDictable):
def __init__(self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None): def __init__(self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None):
super(ConfigDeliveryChannel, self).__init__()
self.name = name self.name = name
self.s3_bucket_name = s3_bucket_name self.s3_bucket_name = s3_bucket_name
self.s3_key_prefix = prefix self.s3_key_prefix = prefix
@ -91,6 +187,8 @@ class ConfigDeliveryChannel(ConfigEmptyDictable):
class RecordingGroup(ConfigEmptyDictable): class RecordingGroup(ConfigEmptyDictable):
def __init__(self, all_supported=True, include_global_resource_types=False, resource_types=None): def __init__(self, all_supported=True, include_global_resource_types=False, resource_types=None):
super(RecordingGroup, self).__init__()
self.all_supported = all_supported self.all_supported = all_supported
self.include_global_resource_types = include_global_resource_types self.include_global_resource_types = include_global_resource_types
self.resource_types = resource_types self.resource_types = resource_types
@ -99,6 +197,8 @@ class RecordingGroup(ConfigEmptyDictable):
class ConfigRecorder(ConfigEmptyDictable): class ConfigRecorder(ConfigEmptyDictable):
def __init__(self, role_arn, recording_group, name='default', status=None): def __init__(self, role_arn, recording_group, name='default', status=None):
super(ConfigRecorder, self).__init__()
self.name = name self.name = name
self.role_arn = role_arn self.role_arn = role_arn
self.recording_group = recording_group self.recording_group = recording_group
@ -109,18 +209,118 @@ class ConfigRecorder(ConfigEmptyDictable):
self.status = status self.status = status
class AccountAggregatorSource(ConfigEmptyDictable):
def __init__(self, account_ids, aws_regions=None, all_aws_regions=None):
super(AccountAggregatorSource, self).__init__(capitalize_start=True)
# Can't have both the regions and all_regions flag present -- also can't have them both missing:
if aws_regions and all_aws_regions:
raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies '
'the use of all regions. You must choose one of these options.')
if not (aws_regions or all_aws_regions):
raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported '
'regions and try again.')
self.account_ids = account_ids
self.aws_regions = aws_regions
if not all_aws_regions:
all_aws_regions = False
self.all_aws_regions = all_aws_regions
class OrganizationAggregationSource(ConfigEmptyDictable):
def __init__(self, role_arn, aws_regions=None, all_aws_regions=None):
super(OrganizationAggregationSource, self).__init__(capitalize_start=True, capitalize_arn=False)
# Can't have both the regions and all_regions flag present -- also can't have them both missing:
if aws_regions and all_aws_regions:
raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies '
'the use of all regions. You must choose one of these options.')
if not (aws_regions or all_aws_regions):
raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported '
'regions and try again.')
self.role_arn = role_arn
self.aws_regions = aws_regions
if not all_aws_regions:
all_aws_regions = False
self.all_aws_regions = all_aws_regions
class ConfigAggregator(ConfigEmptyDictable):
def __init__(self, name, region, account_sources=None, org_source=None, tags=None):
super(ConfigAggregator, self).__init__(capitalize_start=True, capitalize_arn=False)
self.configuration_aggregator_name = name
self.configuration_aggregator_arn = 'arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}'.format(
region=region,
id=DEFAULT_ACCOUNT_ID,
random=random_string()
)
self.account_aggregation_sources = account_sources
self.organization_aggregation_source = org_source
self.creation_time = datetime2int(datetime.utcnow())
self.last_updated_time = datetime2int(datetime.utcnow())
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
self.tags = tags or {}
# Override the to_dict so that we can format the tags properly...
def to_dict(self):
result = super(ConfigAggregator, self).to_dict()
# Override the account aggregation sources if present:
if self.account_aggregation_sources:
result['AccountAggregationSources'] = [a.to_dict() for a in self.account_aggregation_sources]
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
# if self.tags:
# result['Tags'] = [{'Key': key, 'Value': value} for key, value in self.tags.items()]
return result
class ConfigAggregationAuthorization(ConfigEmptyDictable):
def __init__(self, current_region, authorized_account_id, authorized_aws_region, tags=None):
super(ConfigAggregationAuthorization, self).__init__(capitalize_start=True, capitalize_arn=False)
self.aggregation_authorization_arn = 'arn:aws:config:{region}:{id}:aggregation-authorization/' \
'{auth_account}/{auth_region}'.format(region=current_region,
id=DEFAULT_ACCOUNT_ID,
auth_account=authorized_account_id,
auth_region=authorized_aws_region)
self.authorized_account_id = authorized_account_id
self.authorized_aws_region = authorized_aws_region
self.creation_time = datetime2int(datetime.utcnow())
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
self.tags = tags or {}
class ConfigBackend(BaseBackend): class ConfigBackend(BaseBackend):
def __init__(self): def __init__(self):
self.recorders = {} self.recorders = {}
self.delivery_channels = {} self.delivery_channels = {}
self.config_aggregators = {}
self.aggregation_authorizations = {}
@staticmethod @staticmethod
def _validate_resource_types(resource_list): def _validate_resource_types(resource_list):
# Load the service file: # Load the service file:
resource_package = 'botocore' resource_package = 'botocore'
resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json')) resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json'))
conifg_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path)) config_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path))
# Verify that each entry exists in the supported list: # Verify that each entry exists in the supported list:
bad_list = [] bad_list = []
@ -128,11 +328,11 @@ class ConfigBackend(BaseBackend):
# For PY2: # For PY2:
r_str = str(resource) r_str = str(resource)
if r_str not in conifg_schema['shapes']['ResourceType']['enum']: if r_str not in config_schema['shapes']['ResourceType']['enum']:
bad_list.append(r_str) bad_list.append(r_str)
if bad_list: if bad_list:
raise InvalidResourceTypeException(bad_list, conifg_schema['shapes']['ResourceType']['enum']) raise InvalidResourceTypeException(bad_list, config_schema['shapes']['ResourceType']['enum'])
@staticmethod @staticmethod
def _validate_delivery_snapshot_properties(properties): def _validate_delivery_snapshot_properties(properties):
@ -147,6 +347,158 @@ class ConfigBackend(BaseBackend):
raise InvalidDeliveryFrequency(properties.get('deliveryFrequency', None), raise InvalidDeliveryFrequency(properties.get('deliveryFrequency', None),
conifg_schema['shapes']['MaximumExecutionFrequency']['enum']) conifg_schema['shapes']['MaximumExecutionFrequency']['enum'])
def put_configuration_aggregator(self, config_aggregator, region):
# Validate the name:
if len(config_aggregator['ConfigurationAggregatorName']) > 256:
raise NameTooLongException(config_aggregator['ConfigurationAggregatorName'], 'configurationAggregatorName')
account_sources = None
org_source = None
# Tag validation:
tags = validate_tags(config_aggregator.get('Tags', []))
# Exception if both AccountAggregationSources and OrganizationAggregationSource are supplied:
if config_aggregator.get('AccountAggregationSources') and config_aggregator.get('OrganizationAggregationSource'):
raise InvalidParameterValueException('The configuration aggregator cannot be created because your request contains both the'
' AccountAggregationSource and the OrganizationAggregationSource. Include only '
'one aggregation source and try again.')
# If neither are supplied:
if not config_aggregator.get('AccountAggregationSources') and not config_aggregator.get('OrganizationAggregationSource'):
raise InvalidParameterValueException('The configuration aggregator cannot be created because your request is missing either '
'the AccountAggregationSource or the OrganizationAggregationSource. Include the '
'appropriate aggregation source and try again.')
if config_aggregator.get('AccountAggregationSources'):
# Currently, only 1 account aggregation source can be set:
if len(config_aggregator['AccountAggregationSources']) > 1:
raise TooManyAccountSources(len(config_aggregator['AccountAggregationSources']))
account_sources = []
for a in config_aggregator['AccountAggregationSources']:
account_sources.append(AccountAggregatorSource(a['AccountIds'], aws_regions=a.get('AwsRegions'),
all_aws_regions=a.get('AllAwsRegions')))
else:
org_source = OrganizationAggregationSource(config_aggregator['OrganizationAggregationSource']['RoleArn'],
aws_regions=config_aggregator['OrganizationAggregationSource'].get('AwsRegions'),
all_aws_regions=config_aggregator['OrganizationAggregationSource'].get(
'AllAwsRegions'))
# Grab the existing one if it exists and update it:
if not self.config_aggregators.get(config_aggregator['ConfigurationAggregatorName']):
aggregator = ConfigAggregator(config_aggregator['ConfigurationAggregatorName'], region, account_sources=account_sources,
org_source=org_source, tags=tags)
self.config_aggregators[config_aggregator['ConfigurationAggregatorName']] = aggregator
else:
aggregator = self.config_aggregators[config_aggregator['ConfigurationAggregatorName']]
aggregator.tags = tags
aggregator.account_aggregation_sources = account_sources
aggregator.organization_aggregation_source = org_source
aggregator.last_updated_time = datetime2int(datetime.utcnow())
return aggregator.to_dict()
def describe_configuration_aggregators(self, names, token, limit):
limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit
agg_list = []
result = {'ConfigurationAggregators': []}
if names:
for name in names:
if not self.config_aggregators.get(name):
raise NoSuchConfigurationAggregatorException(number=len(names))
agg_list.append(name)
else:
agg_list = list(self.config_aggregators.keys())
# Empty?
if not agg_list:
return result
# Sort by name:
sorted_aggregators = sorted(agg_list)
# Get the start:
if not token:
start = 0
else:
# Tokens for this moto feature are just the next names of the items in the list:
if not self.config_aggregators.get(token):
raise InvalidNextTokenException()
start = sorted_aggregators.index(token)
# Get the list of items to collect:
agg_list = sorted_aggregators[start:(start + limit)]
result['ConfigurationAggregators'] = [self.config_aggregators[agg].to_dict() for agg in agg_list]
if len(sorted_aggregators) > (start + limit):
result['NextToken'] = sorted_aggregators[start + limit]
return result
def delete_configuration_aggregator(self, config_aggregator):
if not self.config_aggregators.get(config_aggregator):
raise NoSuchConfigurationAggregatorException()
del self.config_aggregators[config_aggregator]
def put_aggregation_authorization(self, current_region, authorized_account, authorized_region, tags):
# Tag validation:
tags = validate_tags(tags or [])
# Does this already exist?
key = '{}/{}'.format(authorized_account, authorized_region)
agg_auth = self.aggregation_authorizations.get(key)
if not agg_auth:
agg_auth = ConfigAggregationAuthorization(current_region, authorized_account, authorized_region, tags=tags)
self.aggregation_authorizations['{}/{}'.format(authorized_account, authorized_region)] = agg_auth
else:
# Only update the tags:
agg_auth.tags = tags
return agg_auth.to_dict()
def describe_aggregation_authorizations(self, token, limit):
limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit
result = {'AggregationAuthorizations': []}
if not self.aggregation_authorizations:
return result
# Sort by name:
sorted_authorizations = sorted(self.aggregation_authorizations.keys())
# Get the start:
if not token:
start = 0
else:
# Tokens for this moto feature are just the next names of the items in the list:
if not self.aggregation_authorizations.get(token):
raise InvalidNextTokenException()
start = sorted_authorizations.index(token)
# Get the list of items to collect:
auth_list = sorted_authorizations[start:(start + limit)]
result['AggregationAuthorizations'] = [self.aggregation_authorizations[auth].to_dict() for auth in auth_list]
if len(sorted_authorizations) > (start + limit):
result['NextToken'] = sorted_authorizations[start + limit]
return result
def delete_aggregation_authorization(self, authorized_account, authorized_region):
# This will always return a 200 -- regardless if there is or isn't an existing
# aggregation authorization.
key = '{}/{}'.format(authorized_account, authorized_region)
self.aggregation_authorizations.pop(key, None)
def put_configuration_recorder(self, config_recorder): def put_configuration_recorder(self, config_recorder):
# Validate the name: # Validate the name:
if not config_recorder.get('name'): if not config_recorder.get('name'):

View File

@ -13,6 +13,39 @@ class ConfigResponse(BaseResponse):
self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder')) self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder'))
return "" return ""
def put_configuration_aggregator(self):
aggregator = self.config_backend.put_configuration_aggregator(json.loads(self.body), self.region)
schema = {'ConfigurationAggregator': aggregator}
return json.dumps(schema)
def describe_configuration_aggregators(self):
aggregators = self.config_backend.describe_configuration_aggregators(self._get_param('ConfigurationAggregatorNames'),
self._get_param('NextToken'),
self._get_param('Limit'))
return json.dumps(aggregators)
def delete_configuration_aggregator(self):
self.config_backend.delete_configuration_aggregator(self._get_param('ConfigurationAggregatorName'))
return ""
def put_aggregation_authorization(self):
agg_auth = self.config_backend.put_aggregation_authorization(self.region,
self._get_param('AuthorizedAccountId'),
self._get_param('AuthorizedAwsRegion'),
self._get_param('Tags'))
schema = {'AggregationAuthorization': agg_auth}
return json.dumps(schema)
def describe_aggregation_authorizations(self):
authorizations = self.config_backend.describe_aggregation_authorizations(self._get_param('NextToken'), self._get_param('Limit'))
return json.dumps(authorizations)
def delete_aggregation_authorization(self):
self.config_backend.delete_aggregation_authorization(self._get_param('AuthorizedAccountId'), self._get_param('AuthorizedAwsRegion'))
return ""
def describe_configuration_recorders(self): def describe_configuration_recorders(self):
recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames')) recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames'))
schema = {'ConfigurationRecorders': recorders} schema = {'ConfigurationRecorders': recorders}

View File

@ -1,4 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .models import BaseModel, BaseBackend, moto_api_backend # flake8: noqa from .models import BaseModel, BaseBackend, moto_api_backend # flake8: noqa
from .responses import ActionAuthenticatorMixin
moto_api_backends = {"global": moto_api_backend} moto_api_backends = {"global": moto_api_backend}
set_initial_no_auth_action_count = ActionAuthenticatorMixin.set_initial_no_auth_action_count

View File

@ -65,3 +65,42 @@ class JsonRESTError(RESTError):
def get_body(self, *args, **kwargs): def get_body(self, *args, **kwargs):
return self.description return self.description
class SignatureDoesNotMatchError(RESTError):
code = 403
def __init__(self):
super(SignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch',
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.")
class InvalidClientTokenIdError(RESTError):
code = 403
def __init__(self):
super(InvalidClientTokenIdError, self).__init__(
'InvalidClientTokenId',
"The security token included in the request is invalid.")
class AccessDeniedError(RESTError):
code = 403
def __init__(self, user_arn, action):
super(AccessDeniedError, self).__init__(
'AccessDenied',
"User: {user_arn} is not authorized to perform: {operation}".format(
user_arn=user_arn,
operation=action
))
class AuthFailureError(RESTError):
code = 401
def __init__(self):
super(AuthFailureError, self).__init__(
'AuthFailure',
"AWS was not able to validate the provided access credentials")

View File

@ -12,6 +12,7 @@ from collections import defaultdict
from botocore.handlers import BUILTIN_HANDLERS from botocore.handlers import BUILTIN_HANDLERS
from botocore.awsrequest import AWSResponse from botocore.awsrequest import AWSResponse
import mock
from moto import settings from moto import settings
import responses import responses
from moto.packages.httpretty import HTTPretty from moto.packages.httpretty import HTTPretty
@ -22,11 +23,6 @@ from .utils import (
) )
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
os.environ.setdefault("AWS_ACCESS_KEY_ID", "foobar_key")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "foobar_secret")
class BaseMockAWS(object): class BaseMockAWS(object):
nested_count = 0 nested_count = 0
@ -42,6 +38,10 @@ class BaseMockAWS(object):
self.backends_for_urls.update(self.backends) self.backends_for_urls.update(self.backends)
self.backends_for_urls.update(default_backends) self.backends_for_urls.update(default_backends)
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
FAKE_KEYS = {"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret"}
self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
if self.__class__.nested_count == 0: if self.__class__.nested_count == 0:
self.reset() self.reset()
@ -52,11 +52,14 @@ class BaseMockAWS(object):
def __enter__(self): def __enter__(self):
self.start() self.start()
return self
def __exit__(self, *args): def __exit__(self, *args):
self.stop() self.stop()
def start(self, reset=True): def start(self, reset=True):
self.env_variables_mocks.start()
self.__class__.nested_count += 1 self.__class__.nested_count += 1
if reset: if reset:
for backend in self.backends.values(): for backend in self.backends.values():
@ -65,6 +68,7 @@ class BaseMockAWS(object):
self.enable_patching() self.enable_patching()
def stop(self): def stop(self):
self.env_variables_mocks.stop()
self.__class__.nested_count -= 1 self.__class__.nested_count -= 1
if self.__class__.nested_count < 0: if self.__class__.nested_count < 0:
@ -465,10 +469,14 @@ class BaseModel(object):
class BaseBackend(object): class BaseBackend(object):
def reset(self): def _reset_model_refs(self):
# Remove all references to the models stored
for service, models in model_data.items(): for service, models in model_data.items():
for model_name, model in models.items(): for model_name, model in models.items():
model.instances = [] model.instances = []
def reset(self):
self._reset_model_refs()
self.__dict__ = {} self.__dict__ = {}
self.__init__() self.__init__()

View File

@ -1,13 +1,17 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import functools
from collections import defaultdict from collections import defaultdict
import datetime import datetime
import json import json
import logging import logging
import re import re
import io import io
import requests
import pytz import pytz
from moto.core.access_control import IAMRequest, S3IAMRequest
from moto.core.exceptions import DryRunClientError from moto.core.exceptions import DryRunClientError
from jinja2 import Environment, DictLoader, TemplateNotFound from jinja2 import Environment, DictLoader, TemplateNotFound
@ -22,7 +26,7 @@ from werkzeug.exceptions import HTTPException
import boto3 import boto3
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core.utils import camelcase_to_underscores, method_names_from_class from moto.core.utils import camelcase_to_underscores, method_names_from_class
from moto import settings
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -103,7 +107,54 @@ class _TemplateEnvironmentMixin(object):
return self.environment.get_template(template_id) return self.environment.get_template(template_id)
class BaseResponse(_TemplateEnvironmentMixin): class ActionAuthenticatorMixin(object):
request_count = 0
def _authenticate_and_authorize_action(self, iam_request_cls):
if ActionAuthenticatorMixin.request_count >= settings.INITIAL_NO_AUTH_ACTION_COUNT:
iam_request = iam_request_cls(method=self.method, path=self.path, data=self.data, headers=self.headers)
iam_request.check_signature()
iam_request.check_action_permitted()
else:
ActionAuthenticatorMixin.request_count += 1
def _authenticate_and_authorize_normal_action(self):
self._authenticate_and_authorize_action(IAMRequest)
def _authenticate_and_authorize_s3_action(self):
self._authenticate_and_authorize_action(S3IAMRequest)
@staticmethod
def set_initial_no_auth_action_count(initial_no_auth_action_count):
def decorator(function):
def wrapper(*args, **kwargs):
if settings.TEST_SERVER_MODE:
response = requests.post("http://localhost:5000/moto-api/reset-auth", data=str(initial_no_auth_action_count).encode())
original_initial_no_auth_action_count = response.json()['PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT']
else:
original_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT
original_request_count = ActionAuthenticatorMixin.request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count
ActionAuthenticatorMixin.request_count = 0
try:
result = function(*args, **kwargs)
finally:
if settings.TEST_SERVER_MODE:
requests.post("http://localhost:5000/moto-api/reset-auth", data=str(original_initial_no_auth_action_count).encode())
else:
ActionAuthenticatorMixin.request_count = original_request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = original_initial_no_auth_action_count
return result
functools.update_wrapper(wrapper, function)
wrapper.__wrapped__ = function
return wrapper
return decorator
class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
default_region = 'us-east-1' default_region = 'us-east-1'
# to extract region, use [^.] # to extract region, use [^.]
@ -167,6 +218,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
self.uri = full_url self.uri = full_url
self.path = urlparse(full_url).path self.path = urlparse(full_url).path
self.querystring = querystring self.querystring = querystring
self.data = querystring
self.method = request.method self.method = request.method
self.region = self.get_region_from_url(request, full_url) self.region = self.get_region_from_url(request, full_url)
self.uri_match = None self.uri_match = None
@ -273,6 +325,13 @@ class BaseResponse(_TemplateEnvironmentMixin):
def call_action(self): def call_action(self):
headers = self.response_headers headers = self.response_headers
try:
self._authenticate_and_authorize_normal_action()
except HTTPException as http_error:
response = http_error.description, dict(status=http_error.code)
return self._send_response(headers, response)
action = camelcase_to_underscores(self._get_action()) action = camelcase_to_underscores(self._get_action())
method_names = method_names_from_class(self.__class__) method_names = method_names_from_class(self.__class__)
if action in method_names: if action in method_names:
@ -285,16 +344,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
if isinstance(response, six.string_types): if isinstance(response, six.string_types):
return 200, headers, response return 200, headers, response
else: else:
if len(response) == 2: return self._send_response(headers, response)
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
return status, headers, body
if not action: if not action:
return 404, headers, '' return 404, headers, ''
@ -302,6 +352,19 @@ class BaseResponse(_TemplateEnvironmentMixin):
raise NotImplementedError( raise NotImplementedError(
"The {0} action has not been implemented".format(action)) "The {0} action has not been implemented".format(action))
@staticmethod
def _send_response(headers, response):
if len(response) == 2:
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
return status, headers, body
def _get_param(self, param_name, if_none=None): def _get_param(self, param_name, if_none=None):
val = self.querystring.get(param_name) val = self.querystring.get(param_name)
if val is not None: if val is not None:
@ -569,6 +632,14 @@ class MotoAPIResponse(BaseResponse):
return 200, {}, json.dumps({"status": "ok"}) return 200, {}, json.dumps({"status": "ok"})
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"}) return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"})
def reset_auth_response(self, request, full_url, headers):
if request.method == "POST":
previous_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT
settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode())
ActionAuthenticatorMixin.request_count = 0
return 200, {}, json.dumps({"status": "ok", "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(previous_initial_no_auth_action_count)})
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"})
def model_data(self, request, full_url, headers): def model_data(self, request, full_url, headers):
from moto.core.models import model_data from moto.core.models import model_data

View File

@ -11,4 +11,5 @@ url_paths = {
'{0}/moto-api/$': response_instance.dashboard, '{0}/moto-api/$': response_instance.dashboard,
'{0}/moto-api/data.json': response_instance.model_data, '{0}/moto-api/data.json': response_instance.model_data,
'{0}/moto-api/reset': response_instance.reset_response, '{0}/moto-api/reset': response_instance.reset_response,
'{0}/moto-api/reset-auth': response_instance.reset_auth_response,
} }

File diff suppressed because it is too large Load Diff

View File

@ -6,13 +6,16 @@ import decimal
import json import json
import re import re
import uuid import uuid
import six
import boto3 import boto3
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from .comparisons import get_comparison_func, get_filter_expression, Op from .comparisons import get_comparison_func
from .comparisons import get_filter_expression
from .comparisons import get_expected
from .exceptions import InvalidIndexNameError from .exceptions import InvalidIndexNameError
@ -68,10 +71,34 @@ class DynamoType(object):
except ValueError: except ValueError:
return float(self.value) return float(self.value)
elif self.is_set(): elif self.is_set():
return set(self.value) sub_type = self.type[0]
return set([DynamoType({sub_type: v}).cast_value
for v in self.value])
elif self.is_list():
return [DynamoType(v).cast_value for v in self.value]
elif self.is_map():
return dict([
(k, DynamoType(v).cast_value)
for k, v in self.value.items()])
else: else:
return self.value return self.value
def child_attr(self, key):
"""
Get Map or List children by key. str for Map, int for List.
Returns DynamoType or None.
"""
if isinstance(key, six.string_types) and self.is_map() and key in self.value:
return DynamoType(self.value[key])
if isinstance(key, int) and self.is_list():
idx = key
if idx >= 0 and idx < len(self.value):
return DynamoType(self.value[idx])
return None
def to_json(self): def to_json(self):
return {self.type: self.value} return {self.type: self.value}
@ -89,6 +116,12 @@ class DynamoType(object):
def is_set(self): def is_set(self):
return self.type == 'SS' or self.type == 'NS' or self.type == 'BS' return self.type == 'SS' or self.type == 'NS' or self.type == 'BS'
def is_list(self):
return self.type == 'L'
def is_map(self):
return self.type == 'M'
def same_type(self, other): def same_type(self, other):
return self.type == other.type return self.type == other.type
@ -265,7 +298,9 @@ class Item(BaseModel):
new_value = list(update_action['Value'].values())[0] new_value = list(update_action['Value'].values())[0]
if action == 'PUT': if action == 'PUT':
# TODO deal with other types # TODO deal with other types
if isinstance(new_value, list) or isinstance(new_value, set): if isinstance(new_value, list):
self.attrs[attribute_name] = DynamoType({"L": new_value})
elif isinstance(new_value, set):
self.attrs[attribute_name] = DynamoType({"SS": new_value}) self.attrs[attribute_name] = DynamoType({"SS": new_value})
elif isinstance(new_value, dict): elif isinstance(new_value, dict):
self.attrs[attribute_name] = DynamoType({"M": new_value}) self.attrs[attribute_name] = DynamoType({"M": new_value})
@ -504,7 +539,9 @@ class Table(BaseModel):
keys.append(range_key) keys.append(range_key)
return keys return keys
def put_item(self, item_attrs, expected=None, overwrite=False): def put_item(self, item_attrs, expected=None, condition_expression=None,
expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
if self.has_range_key: if self.has_range_key:
range_value = DynamoType(item_attrs.get(self.range_key_attr)) range_value = DynamoType(item_attrs.get(self.range_key_attr))
@ -527,29 +564,15 @@ class Table(BaseModel):
self.range_key_type, item_attrs) self.range_key_type, item_attrs)
if not overwrite: if not overwrite:
if current is None: if not get_expected(expected).expr(current):
current_attr = {} raise ValueError('The conditional request failed')
elif hasattr(current, 'attrs'): condition_op = get_filter_expression(
current_attr = current.attrs condition_expression,
else: expression_attribute_names,
current_attr = current expression_attribute_values)
if not condition_op.expr(current):
raise ValueError('The conditional request failed')
for key, val in expected.items():
if 'Exists' in val and val['Exists'] is False \
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
if key in current_attr:
raise ValueError("The conditional request failed")
elif key not in current_attr:
raise ValueError("The conditional request failed")
elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value:
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
dynamo_types = [
DynamoType(ele) for ele in
val.get("AttributeValueList", [])
]
if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types):
raise ValueError('The conditional request failed')
if range_value: if range_value:
self.items[hash_value][range_value] = item self.items[hash_value][range_value] = item
else: else:
@ -724,7 +747,7 @@ class Table(BaseModel):
if idx_col_set.issubset(set(hash_set.attrs)): if idx_col_set.issubset(set(hash_set.attrs)):
yield hash_set yield hash_set
def scan(self, filters, limit, exclusive_start_key, filter_expression=None, index_name=None): def scan(self, filters, limit, exclusive_start_key, filter_expression=None, index_name=None, projection_expression=None):
results = [] results = []
scanned_count = 0 scanned_count = 0
all_indexes = self.all_indexes() all_indexes = self.all_indexes()
@ -763,6 +786,14 @@ class Table(BaseModel):
if passes_all_conditions: if passes_all_conditions:
results.append(item) results.append(item)
if projection_expression:
expressions = [x.strip() for x in projection_expression.split(',')]
results = copy.deepcopy(results)
for result in results:
for attr in list(result.attrs):
if attr not in expressions:
result.attrs.pop(attr)
results, last_evaluated_key = self._trim_results(results, limit, results, last_evaluated_key = self._trim_results(results, limit,
exclusive_start_key, index_name) exclusive_start_key, index_name)
return results, scanned_count, last_evaluated_key return results, scanned_count, last_evaluated_key
@ -894,11 +925,15 @@ class DynamoDBBackend(BaseBackend):
table.global_indexes = list(gsis_by_name.values()) table.global_indexes = list(gsis_by_name.values())
return table return table
def put_item(self, table_name, item_attrs, expected=None, overwrite=False): def put_item(self, table_name, item_attrs, expected=None,
condition_expression=None, expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
table = self.tables.get(table_name) table = self.tables.get(table_name)
if not table: if not table:
return None return None
return table.put_item(item_attrs, expected, overwrite) return table.put_item(item_attrs, expected, condition_expression,
expression_attribute_names,
expression_attribute_values, overwrite)
def get_table_keys_name(self, table_name, keys): def get_table_keys_name(self, table_name, keys):
""" """
@ -954,15 +989,12 @@ class DynamoDBBackend(BaseBackend):
range_values = [DynamoType(range_value) range_values = [DynamoType(range_value)
for range_value in range_value_dicts] for range_value in range_value_dicts]
if filter_expression is not None: filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
return table.query(hash_key, range_comparison, range_values, limit, return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs) exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs)
def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values, index_name): def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values, index_name, projection_expression):
table = self.tables.get(table_name) table = self.tables.get(table_name)
if not table: if not table:
return None, None, None return None, None, None
@ -972,15 +1004,14 @@ class DynamoDBBackend(BaseBackend):
dynamo_types = [DynamoType(value) for value in comparison_values] dynamo_types = [DynamoType(value) for value in comparison_values]
scan_filters[key] = (comparison_operator, dynamo_types) scan_filters[key] = (comparison_operator, dynamo_types)
if filter_expression is not None: filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name) projection_expression = ','.join([expr_names.get(attr, attr) for attr in projection_expression.replace(' ', '').split(',')])
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name, projection_expression)
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names, def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected=None): expression_attribute_values, expected=None, condition_expression=None):
table = self.get_table(table_name) table = self.get_table(table_name)
if all([table.hash_key_attr in key, table.range_key_attr in key]): if all([table.hash_key_attr in key, table.range_key_attr in key]):
@ -999,32 +1030,17 @@ class DynamoDBBackend(BaseBackend):
item = table.get_item(hash_value, range_value) item = table.get_item(hash_value, range_value)
if item is None:
item_attr = {}
elif hasattr(item, 'attrs'):
item_attr = item.attrs
else:
item_attr = item
if not expected: if not expected:
expected = {} expected = {}
for key, val in expected.items(): if not get_expected(expected).expr(item):
if 'Exists' in val and val['Exists'] is False \ raise ValueError('The conditional request failed')
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': condition_op = get_filter_expression(
if key in item_attr: condition_expression,
raise ValueError("The conditional request failed") expression_attribute_names,
elif key not in item_attr: expression_attribute_values)
raise ValueError("The conditional request failed") if not condition_op.expr(item):
elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value: raise ValueError('The conditional request failed')
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
dynamo_types = [
DynamoType(ele) for ele in
val.get("AttributeValueList", [])
]
if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types):
raise ValueError('The conditional request failed')
# Update does not fail on new items, so create one # Update does not fail on new items, so create one
if item is None: if item is None:

View File

@ -32,67 +32,6 @@ def get_empty_str_error():
)) ))
def condition_expression_to_expected(condition_expression, expression_attribute_names, expression_attribute_values):
"""
Limited condition expression syntax parsing.
Supports Global Negation ex: NOT(inner expressions).
Supports simple AND conditions ex: cond_a AND cond_b and cond_c.
Atomic expressions supported are attribute_exists(key), attribute_not_exists(key) and #key = :value.
"""
expected = {}
if condition_expression and 'OR' not in condition_expression:
reverse_re = re.compile('^NOT\s*\((.*)\)$')
reverse_m = reverse_re.match(condition_expression.strip())
reverse = False
if reverse_m:
reverse = True
condition_expression = reverse_m.group(1)
cond_items = [c.strip() for c in condition_expression.split('AND')]
if cond_items:
exists_re = re.compile('^attribute_exists\s*\((.*)\)$')
not_exists_re = re.compile(
'^attribute_not_exists\s*\((.*)\)$')
equals_re = re.compile('^(#?\w+)\s*=\s*(\:?\w+)')
for cond in cond_items:
exists_m = exists_re.match(cond)
not_exists_m = not_exists_re.match(cond)
equals_m = equals_re.match(cond)
if exists_m:
attribute_name = expression_attribute_names_lookup(exists_m.group(1), expression_attribute_names)
expected[attribute_name] = {'Exists': True if not reverse else False}
elif not_exists_m:
attribute_name = expression_attribute_names_lookup(not_exists_m.group(1), expression_attribute_names)
expected[attribute_name] = {'Exists': False if not reverse else True}
elif equals_m:
attribute_name = expression_attribute_names_lookup(equals_m.group(1), expression_attribute_names)
attribute_value = expression_attribute_values_lookup(equals_m.group(2), expression_attribute_values)
expected[attribute_name] = {
'AttributeValueList': [attribute_value],
'ComparisonOperator': 'EQ' if not reverse else 'NEQ'}
return expected
def expression_attribute_names_lookup(attribute_name, expression_attribute_names):
if attribute_name.startswith('#') and attribute_name in expression_attribute_names:
return expression_attribute_names[attribute_name]
else:
return attribute_name
def expression_attribute_values_lookup(attribute_value, expression_attribute_values):
if isinstance(attribute_value, six.string_types) and \
attribute_value.startswith(':') and\
attribute_value in expression_attribute_values:
return expression_attribute_values[attribute_value]
else:
return attribute_value
class DynamoHandler(BaseResponse): class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers): def get_endpoint_name(self, headers):
@ -166,7 +105,7 @@ class DynamoHandler(BaseResponse):
when BillingMode is PAY_PER_REQUEST') when BillingMode is PAY_PER_REQUEST')
throughput = None throughput = None
else: # Provisioned (default billing mode) else: # Provisioned (default billing mode)
throughput = body["ProvisionedThroughput"] throughput = body.get("ProvisionedThroughput")
# getting the schema # getting the schema
key_schema = body['KeySchema'] key_schema = body['KeySchema']
# getting attribute definition # getting attribute definition
@ -288,18 +227,18 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected # Attempt to parse simple ConditionExpressions into an Expected
# expression # expression
if not expected: condition_expression = self.body.get('ConditionExpression')
condition_expression = self.body.get('ConditionExpression') expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression, if condition_expression:
expression_attribute_names, overwrite = False
expression_attribute_values)
if expected:
overwrite = False
try: try:
result = self.dynamodb_backend.put_item(name, item, expected, overwrite) result = self.dynamodb_backend.put_item(
name, item, expected, condition_expression,
expression_attribute_names, expression_attribute_values,
overwrite)
except ValueError: except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er, 'A condition specified in the operation could not be evaluated.') return self.error(er, 'A condition specified in the operation could not be evaluated.')
@ -379,6 +318,9 @@ class DynamoHandler(BaseResponse):
for table_name, table_request in table_batches.items(): for table_name, table_request in table_batches.items():
keys = table_request['Keys'] keys = table_request['Keys']
if self._contains_duplicates(keys):
er = 'com.amazon.coral.validate#ValidationException'
return self.error(er, 'Provided list of item keys contains duplicates')
attributes_to_get = table_request.get('AttributesToGet') attributes_to_get = table_request.get('AttributesToGet')
results["Responses"][table_name] = [] results["Responses"][table_name] = []
for key in keys: for key in keys:
@ -394,6 +336,15 @@ class DynamoHandler(BaseResponse):
}) })
return dynamo_json_dump(results) return dynamo_json_dump(results)
def _contains_duplicates(self, keys):
unique_keys = []
for k in keys:
if k in unique_keys:
return True
else:
unique_keys.append(k)
return False
def query(self): def query(self):
name = self.body['TableName'] name = self.body['TableName']
# {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}} # {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}}
@ -558,7 +509,7 @@ class DynamoHandler(BaseResponse):
filter_expression = self.body.get('FilterExpression') filter_expression = self.body.get('FilterExpression')
expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
projection_expression = self.body.get('ProjectionExpression', '')
exclusive_start_key = self.body.get('ExclusiveStartKey') exclusive_start_key = self.body.get('ExclusiveStartKey')
limit = self.body.get("Limit") limit = self.body.get("Limit")
index_name = self.body.get('IndexName') index_name = self.body.get('IndexName')
@ -570,7 +521,8 @@ class DynamoHandler(BaseResponse):
filter_expression, filter_expression,
expression_attribute_names, expression_attribute_names,
expression_attribute_values, expression_attribute_values,
index_name) index_name,
projection_expression)
except InvalidIndexNameError as err: except InvalidIndexNameError as err:
er = 'com.amazonaws.dynamodb.v20111205#ValidationException' er = 'com.amazonaws.dynamodb.v20111205#ValidationException'
return self.error(er, str(err)) return self.error(er, str(err))
@ -625,7 +577,7 @@ class DynamoHandler(BaseResponse):
name = self.body['TableName'] name = self.body['TableName']
key = self.body['Key'] key = self.body['Key']
return_values = self.body.get('ReturnValues', 'NONE') return_values = self.body.get('ReturnValues', 'NONE')
update_expression = self.body.get('UpdateExpression') update_expression = self.body.get('UpdateExpression', '').strip()
attribute_updates = self.body.get('AttributeUpdates') attribute_updates = self.body.get('AttributeUpdates')
expression_attribute_names = self.body.get( expression_attribute_names = self.body.get(
'ExpressionAttributeNames', {}) 'ExpressionAttributeNames', {})
@ -652,24 +604,20 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected # Attempt to parse simple ConditionExpressions into an Expected
# expression # expression
if not expected: condition_expression = self.body.get('ConditionExpression')
condition_expression = self.body.get('ConditionExpression') expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression,
expression_attribute_names,
expression_attribute_values)
# Support spaces between operators in an update expression # Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c` # E.g. `a = b + c` -> `a=b+c`
if update_expression: if update_expression:
update_expression = re.sub( update_expression = re.sub(
'\s*([=\+-])\s*', '\\1', update_expression) r'\s*([=\+-])\s*', '\\1', update_expression)
try: try:
item = self.dynamodb_backend.update_item( item = self.dynamodb_backend.update_item(
name, key, update_expression, attribute_updates, expression_attribute_names, name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected expression_attribute_values, expected, condition_expression
) )
except ValueError: except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'

View File

@ -332,6 +332,15 @@ class InvalidParameterValueErrorTagNull(EC2ClientError):
"Tag value cannot be null. Use empty string instead.") "Tag value cannot be null. Use empty string instead.")
class InvalidParameterValueErrorUnknownAttribute(EC2ClientError):
def __init__(self, parameter_value):
super(InvalidParameterValueErrorUnknownAttribute, self).__init__(
"InvalidParameterValue",
"Value ({0}) for parameter attribute is invalid. Unknown attribute."
.format(parameter_value))
class InvalidInternetGatewayIdError(EC2ClientError): class InvalidInternetGatewayIdError(EC2ClientError):
def __init__(self, internet_gateway_id): def __init__(self, internet_gateway_id):
@ -430,6 +439,16 @@ class OperationNotPermitted(EC2ClientError):
) )
class InvalidAvailabilityZoneError(EC2ClientError):
def __init__(self, availability_zone_value, valid_availability_zones):
super(InvalidAvailabilityZoneError, self).__init__(
"InvalidParameterValue",
"Value ({0}) for parameter availabilityZone is invalid. "
"Subnets can currently only be created in the following availability zones: {1}.".format(availability_zone_value, valid_availability_zones)
)
class NetworkAclEntryAlreadyExistsError(EC2ClientError): class NetworkAclEntryAlreadyExistsError(EC2ClientError):
def __init__(self, rule_number): def __init__(self, rule_number):
@ -504,3 +523,11 @@ class OperationNotPermitted3(EC2ClientError):
pcx_id, pcx_id,
acceptor_region) acceptor_region)
) )
class InvalidLaunchTemplateNameError(EC2ClientError):
def __init__(self):
super(InvalidLaunchTemplateNameError, self).__init__(
"InvalidLaunchTemplateName.AlreadyExistsException",
"Launch template name already in use."
)

View File

@ -20,7 +20,6 @@ from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType
from boto.ec2.spotinstancerequest import SpotInstanceRequest as BotoSpotRequest from boto.ec2.spotinstancerequest import SpotInstanceRequest as BotoSpotRequest
from boto.ec2.launchspecification import LaunchSpecification from boto.ec2.launchspecification import LaunchSpecification
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend from moto.core import BaseBackend
from moto.core.models import Model, BaseModel from moto.core.models import Model, BaseModel
@ -36,6 +35,7 @@ from .exceptions import (
InvalidAMIIdError, InvalidAMIIdError,
InvalidAMIAttributeItemValueError, InvalidAMIAttributeItemValueError,
InvalidAssociationIdError, InvalidAssociationIdError,
InvalidAvailabilityZoneError,
InvalidCIDRBlockParameterError, InvalidCIDRBlockParameterError,
InvalidCIDRSubnetError, InvalidCIDRSubnetError,
InvalidCustomerGatewayIdError, InvalidCustomerGatewayIdError,
@ -48,11 +48,13 @@ from .exceptions import (
InvalidKeyPairDuplicateError, InvalidKeyPairDuplicateError,
InvalidKeyPairFormatError, InvalidKeyPairFormatError,
InvalidKeyPairNameError, InvalidKeyPairNameError,
InvalidLaunchTemplateNameError,
InvalidNetworkAclIdError, InvalidNetworkAclIdError,
InvalidNetworkAttachmentIdError, InvalidNetworkAttachmentIdError,
InvalidNetworkInterfaceIdError, InvalidNetworkInterfaceIdError,
InvalidParameterValueError, InvalidParameterValueError,
InvalidParameterValueErrorTagNull, InvalidParameterValueErrorTagNull,
InvalidParameterValueErrorUnknownAttribute,
InvalidPermissionNotFoundError, InvalidPermissionNotFoundError,
InvalidPermissionDuplicateError, InvalidPermissionDuplicateError,
InvalidRouteTableIdError, InvalidRouteTableIdError,
@ -96,6 +98,7 @@ from .utils import (
random_internet_gateway_id, random_internet_gateway_id,
random_ip, random_ip,
random_ipv6_cidr, random_ipv6_cidr,
random_launch_template_id,
random_nat_gateway_id, random_nat_gateway_id,
random_key_pair, random_key_pair,
random_private_ip, random_private_ip,
@ -140,6 +143,8 @@ AMIS = json.load(
__name__, 'resources/amis.json'), 'r') __name__, 'resources/amis.json'), 'r')
) )
OWNER_ID = "111122223333"
def utc_date_and_time(): def utc_date_and_time():
return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z') return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z')
@ -199,7 +204,7 @@ class TaggedEC2Resource(BaseModel):
class NetworkInterface(TaggedEC2Resource): class NetworkInterface(TaggedEC2Resource):
def __init__(self, ec2_backend, subnet, private_ip_address, device_index=0, def __init__(self, ec2_backend, subnet, private_ip_address, device_index=0,
public_ip_auto_assign=True, group_ids=None): public_ip_auto_assign=True, group_ids=None, description=None):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = random_eni_id() self.id = random_eni_id()
self.device_index = device_index self.device_index = device_index
@ -207,6 +212,7 @@ class NetworkInterface(TaggedEC2Resource):
self.subnet = subnet self.subnet = subnet
self.instance = None self.instance = None
self.attachment_id = None self.attachment_id = None
self.description = description
self.public_ip = None self.public_ip = None
self.public_ip_auto_assign = public_ip_auto_assign self.public_ip_auto_assign = public_ip_auto_assign
@ -244,11 +250,13 @@ class NetworkInterface(TaggedEC2Resource):
subnet = None subnet = None
private_ip_address = properties.get('PrivateIpAddress', None) private_ip_address = properties.get('PrivateIpAddress', None)
description = properties.get('Description', None)
network_interface = ec2_backend.create_network_interface( network_interface = ec2_backend.create_network_interface(
subnet, subnet,
private_ip_address, private_ip_address,
group_ids=security_group_ids group_ids=security_group_ids,
description=description
) )
return network_interface return network_interface
@ -296,6 +304,8 @@ class NetworkInterface(TaggedEC2Resource):
return [group.id for group in self._group_set] return [group.id for group in self._group_set]
elif filter_name == 'availability-zone': elif filter_name == 'availability-zone':
return self.subnet.availability_zone return self.subnet.availability_zone
elif filter_name == 'description':
return self.description
else: else:
return super(NetworkInterface, self).get_filter_value( return super(NetworkInterface, self).get_filter_value(
filter_name, 'DescribeNetworkInterfaces') filter_name, 'DescribeNetworkInterfaces')
@ -306,9 +316,9 @@ class NetworkInterfaceBackend(object):
self.enis = {} self.enis = {}
super(NetworkInterfaceBackend, self).__init__() super(NetworkInterfaceBackend, self).__init__()
def create_network_interface(self, subnet, private_ip_address, group_ids=None, **kwargs): def create_network_interface(self, subnet, private_ip_address, group_ids=None, description=None, **kwargs):
eni = NetworkInterface( eni = NetworkInterface(
self, subnet, private_ip_address, group_ids=group_ids, **kwargs) self, subnet, private_ip_address, group_ids=group_ids, description=description, **kwargs)
self.enis[eni.id] = eni self.enis[eni.id] = eni
return eni return eni
@ -341,6 +351,12 @@ class NetworkInterfaceBackend(object):
if group.id in _filter_value: if group.id in _filter_value:
enis.append(eni) enis.append(eni)
break break
elif _filter == 'private-ip-address:':
enis = [eni for eni in enis if eni.private_ip_address in _filter_value]
elif _filter == 'subnet-id':
enis = [eni for eni in enis if eni.subnet.id in _filter_value]
elif _filter == 'description':
enis = [eni for eni in enis if eni.description in _filter_value]
else: else:
self.raise_not_implemented_error( self.raise_not_implemented_error(
"The filter '{0}' for DescribeNetworkInterfaces".format(_filter)) "The filter '{0}' for DescribeNetworkInterfaces".format(_filter))
@ -382,6 +398,10 @@ class NetworkInterfaceBackend(object):
class Instance(TaggedEC2Resource, BotoInstance): class Instance(TaggedEC2Resource, BotoInstance):
VALID_ATTRIBUTES = {'instanceType', 'kernel', 'ramdisk', 'userData', 'disableApiTermination',
'instanceInitiatedShutdownBehavior', 'rootDeviceName', 'blockDeviceMapping',
'productCodes', 'sourceDestCheck', 'groupSet', 'ebsOptimized', 'sriovNetSupport'}
def __init__(self, ec2_backend, image_id, user_data, security_groups, **kwargs): def __init__(self, ec2_backend, image_id, user_data, security_groups, **kwargs):
super(Instance, self).__init__() super(Instance, self).__init__()
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
@ -404,11 +424,13 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.launch_time = utc_date_and_time() self.launch_time = utc_date_and_time()
self.ami_launch_index = kwargs.get("ami_launch_index", 0) self.ami_launch_index = kwargs.get("ami_launch_index", 0)
self.disable_api_termination = kwargs.get("disable_api_termination", False) self.disable_api_termination = kwargs.get("disable_api_termination", False)
self.instance_initiated_shutdown_behavior = kwargs.get("instance_initiated_shutdown_behavior", "stop")
self.sriov_net_support = "simple"
self._spot_fleet_id = kwargs.get("spot_fleet_id", None) self._spot_fleet_id = kwargs.get("spot_fleet_id", None)
associate_public_ip = kwargs.get("associate_public_ip", False) self.associate_public_ip = kwargs.get("associate_public_ip", False)
if in_ec2_classic: if in_ec2_classic:
# If we are in EC2-Classic, autoassign a public IP # If we are in EC2-Classic, autoassign a public IP
associate_public_ip = True self.associate_public_ip = True
amis = self.ec2_backend.describe_images(filters={'image-id': image_id}) amis = self.ec2_backend.describe_images(filters={'image-id': image_id})
ami = amis[0] if amis else None ami = amis[0] if amis else None
@ -439,9 +461,9 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.vpc_id = subnet.vpc_id self.vpc_id = subnet.vpc_id
self._placement.zone = subnet.availability_zone self._placement.zone = subnet.availability_zone
if associate_public_ip is None: if self.associate_public_ip is None:
# Mapping public ip hasnt been explicitly enabled or disabled # Mapping public ip hasnt been explicitly enabled or disabled
associate_public_ip = subnet.map_public_ip_on_launch == 'true' self.associate_public_ip = subnet.map_public_ip_on_launch == 'true'
elif placement: elif placement:
self._placement.zone = placement self._placement.zone = placement
else: else:
@ -453,7 +475,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.prep_nics( self.prep_nics(
kwargs.get("nics", {}), kwargs.get("nics", {}),
private_ip=kwargs.get("private_ip"), private_ip=kwargs.get("private_ip"),
associate_public_ip=associate_public_ip associate_public_ip=self.associate_public_ip
) )
def __del__(self): def __del__(self):
@ -787,14 +809,22 @@ class InstanceBackend(object):
setattr(instance, key, value) setattr(instance, key, value)
return instance return instance
def modify_instance_security_groups(self, instance_id, new_group_list): def modify_instance_security_groups(self, instance_id, new_group_id_list):
instance = self.get_instance(instance_id) instance = self.get_instance(instance_id)
new_group_list = []
for new_group_id in new_group_id_list:
new_group_list.append(self.get_security_group_from_id(new_group_id))
setattr(instance, 'security_groups', new_group_list) setattr(instance, 'security_groups', new_group_list)
return instance return instance
def describe_instance_attribute(self, instance_id, key): def describe_instance_attribute(self, instance_id, attribute):
if key == 'group_set': if attribute not in Instance.VALID_ATTRIBUTES:
raise InvalidParameterValueErrorUnknownAttribute(attribute)
if attribute == 'groupSet':
key = 'security_groups' key = 'security_groups'
else:
key = camelcase_to_underscores(attribute)
instance = self.get_instance(instance_id) instance = self.get_instance(instance_id)
value = getattr(instance, key) value = getattr(instance, key)
return instance, value return instance, value
@ -1060,7 +1090,7 @@ class TagBackend(object):
class Ami(TaggedEC2Resource): class Ami(TaggedEC2Resource):
def __init__(self, ec2_backend, ami_id, instance=None, source_ami=None, def __init__(self, ec2_backend, ami_id, instance=None, source_ami=None,
name=None, description=None, owner_id=111122223333, name=None, description=None, owner_id=OWNER_ID,
public=False, virtualization_type=None, architecture=None, public=False, virtualization_type=None, architecture=None,
state='available', creation_date=None, platform=None, state='available', creation_date=None, platform=None,
image_type='machine', image_location=None, hypervisor=None, image_type='machine', image_location=None, hypervisor=None,
@ -1173,7 +1203,7 @@ class AmiBackend(object):
ami = Ami(self, ami_id, instance=instance, source_ami=None, ami = Ami(self, ami_id, instance=instance, source_ami=None,
name=name, description=description, name=name, description=description,
owner_id=context.get_current_user() if context else '111122223333') owner_id=context.get_current_user() if context else OWNER_ID)
self.amis[ami_id] = ami self.amis[ami_id] = ami
return ami return ami
@ -1288,17 +1318,107 @@ class Region(object):
class Zone(object): class Zone(object):
def __init__(self, name, region_name): def __init__(self, name, region_name, zone_id):
self.name = name self.name = name
self.region_name = region_name self.region_name = region_name
self.zone_id = zone_id
class RegionsAndZonesBackend(object): class RegionsAndZonesBackend(object):
regions = [Region(ri.name, ri.endpoint) for ri in boto.ec2.regions()] regions = [Region(ri.name, ri.endpoint) for ri in boto.ec2.regions()]
zones = dict( zones = {
(region, [Zone(region + c, region) for c in 'abc']) 'ap-south-1': [
for region in [r.name for r in regions]) Zone(region_name="ap-south-1", name="ap-south-1a", zone_id="aps1-az1"),
Zone(region_name="ap-south-1", name="ap-south-1b", zone_id="aps1-az3")
],
'eu-west-3': [
Zone(region_name="eu-west-3", name="eu-west-3a", zone_id="euw3-az1"),
Zone(region_name="eu-west-3", name="eu-west-3b", zone_id="euw3-az2"),
Zone(region_name="eu-west-3", name="eu-west-3c", zone_id="euw3-az3")
],
'eu-north-1': [
Zone(region_name="eu-north-1", name="eu-north-1a", zone_id="eun1-az1"),
Zone(region_name="eu-north-1", name="eu-north-1b", zone_id="eun1-az2"),
Zone(region_name="eu-north-1", name="eu-north-1c", zone_id="eun1-az3")
],
'eu-west-2': [
Zone(region_name="eu-west-2", name="eu-west-2a", zone_id="euw2-az2"),
Zone(region_name="eu-west-2", name="eu-west-2b", zone_id="euw2-az3"),
Zone(region_name="eu-west-2", name="eu-west-2c", zone_id="euw2-az1")
],
'eu-west-1': [
Zone(region_name="eu-west-1", name="eu-west-1a", zone_id="euw1-az3"),
Zone(region_name="eu-west-1", name="eu-west-1b", zone_id="euw1-az1"),
Zone(region_name="eu-west-1", name="eu-west-1c", zone_id="euw1-az2")
],
'ap-northeast-3': [
Zone(region_name="ap-northeast-3", name="ap-northeast-2a", zone_id="apne3-az1")
],
'ap-northeast-2': [
Zone(region_name="ap-northeast-2", name="ap-northeast-2a", zone_id="apne2-az1"),
Zone(region_name="ap-northeast-2", name="ap-northeast-2c", zone_id="apne2-az3")
],
'ap-northeast-1': [
Zone(region_name="ap-northeast-1", name="ap-northeast-1a", zone_id="apne1-az4"),
Zone(region_name="ap-northeast-1", name="ap-northeast-1c", zone_id="apne1-az1"),
Zone(region_name="ap-northeast-1", name="ap-northeast-1d", zone_id="apne1-az2")
],
'sa-east-1': [
Zone(region_name="sa-east-1", name="sa-east-1a", zone_id="sae1-az1"),
Zone(region_name="sa-east-1", name="sa-east-1c", zone_id="sae1-az3")
],
'ca-central-1': [
Zone(region_name="ca-central-1", name="ca-central-1a", zone_id="cac1-az1"),
Zone(region_name="ca-central-1", name="ca-central-1b", zone_id="cac1-az2")
],
'ap-southeast-1': [
Zone(region_name="ap-southeast-1", name="ap-southeast-1a", zone_id="apse1-az1"),
Zone(region_name="ap-southeast-1", name="ap-southeast-1b", zone_id="apse1-az2"),
Zone(region_name="ap-southeast-1", name="ap-southeast-1c", zone_id="apse1-az3")
],
'ap-southeast-2': [
Zone(region_name="ap-southeast-2", name="ap-southeast-2a", zone_id="apse2-az1"),
Zone(region_name="ap-southeast-2", name="ap-southeast-2b", zone_id="apse2-az3"),
Zone(region_name="ap-southeast-2", name="ap-southeast-2c", zone_id="apse2-az2")
],
'eu-central-1': [
Zone(region_name="eu-central-1", name="eu-central-1a", zone_id="euc1-az2"),
Zone(region_name="eu-central-1", name="eu-central-1b", zone_id="euc1-az3"),
Zone(region_name="eu-central-1", name="eu-central-1c", zone_id="euc1-az1")
],
'us-east-1': [
Zone(region_name="us-east-1", name="us-east-1a", zone_id="use1-az6"),
Zone(region_name="us-east-1", name="us-east-1b", zone_id="use1-az1"),
Zone(region_name="us-east-1", name="us-east-1c", zone_id="use1-az2"),
Zone(region_name="us-east-1", name="us-east-1d", zone_id="use1-az4"),
Zone(region_name="us-east-1", name="us-east-1e", zone_id="use1-az3"),
Zone(region_name="us-east-1", name="us-east-1f", zone_id="use1-az5")
],
'us-east-2': [
Zone(region_name="us-east-2", name="us-east-2a", zone_id="use2-az1"),
Zone(region_name="us-east-2", name="us-east-2b", zone_id="use2-az2"),
Zone(region_name="us-east-2", name="us-east-2c", zone_id="use2-az3")
],
'us-west-1': [
Zone(region_name="us-west-1", name="us-west-1a", zone_id="usw1-az3"),
Zone(region_name="us-west-1", name="us-west-1b", zone_id="usw1-az1")
],
'us-west-2': [
Zone(region_name="us-west-2", name="us-west-2a", zone_id="usw2-az2"),
Zone(region_name="us-west-2", name="us-west-2b", zone_id="usw2-az1"),
Zone(region_name="us-west-2", name="us-west-2c", zone_id="usw2-az3")
],
'cn-north-1': [
Zone(region_name="cn-north-1", name="cn-north-1a", zone_id="cnn1-az1"),
Zone(region_name="cn-north-1", name="cn-north-1b", zone_id="cnn1-az2")
],
'us-gov-west-1': [
Zone(region_name="us-gov-west-1", name="us-gov-west-1a", zone_id="usgw1-az1"),
Zone(region_name="us-gov-west-1", name="us-gov-west-1b", zone_id="usgw1-az2"),
Zone(region_name="us-gov-west-1", name="us-gov-west-1c", zone_id="usgw1-az3")
]
}
def describe_regions(self, region_names=[]): def describe_regions(self, region_names=[]):
if len(region_names) == 0: if len(region_names) == 0:
@ -1351,7 +1471,7 @@ class SecurityGroup(TaggedEC2Resource):
self.egress_rules = [SecurityRule(-1, None, None, ['0.0.0.0/0'], [])] self.egress_rules = [SecurityRule(-1, None, None, ['0.0.0.0/0'], [])]
self.enis = {} self.enis = {}
self.vpc_id = vpc_id self.vpc_id = vpc_id
self.owner_id = "123456789012" self.owner_id = OWNER_ID
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -1872,7 +1992,7 @@ class Volume(TaggedEC2Resource):
class Snapshot(TaggedEC2Resource): class Snapshot(TaggedEC2Resource):
def __init__(self, ec2_backend, snapshot_id, volume, description, encrypted=False, owner_id='123456789012'): def __init__(self, ec2_backend, snapshot_id, volume, description, encrypted=False, owner_id=OWNER_ID):
self.id = snapshot_id self.id = snapshot_id
self.volume = volume self.volume = volume
self.description = description self.description = description
@ -2374,7 +2494,7 @@ class VPCPeeringConnectionBackend(object):
class Subnet(TaggedEC2Resource): class Subnet(TaggedEC2Resource):
def __init__(self, ec2_backend, subnet_id, vpc_id, cidr_block, availability_zone, default_for_az, def __init__(self, ec2_backend, subnet_id, vpc_id, cidr_block, availability_zone, default_for_az,
map_public_ip_on_launch): map_public_ip_on_launch, owner_id=OWNER_ID, assign_ipv6_address_on_creation=False):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = subnet_id self.id = subnet_id
self.vpc_id = vpc_id self.vpc_id = vpc_id
@ -2383,6 +2503,9 @@ class Subnet(TaggedEC2Resource):
self._availability_zone = availability_zone self._availability_zone = availability_zone
self.default_for_az = default_for_az self.default_for_az = default_for_az
self.map_public_ip_on_launch = map_public_ip_on_launch self.map_public_ip_on_launch = map_public_ip_on_launch
self.owner_id = owner_id
self.assign_ipv6_address_on_creation = assign_ipv6_address_on_creation
self.ipv6_cidr_block_associations = []
# Theory is we assign ip's as we go (as 16,777,214 usable IPs in a /8) # Theory is we assign ip's as we go (as 16,777,214 usable IPs in a /8)
self._subnet_ip_generator = self.cidr.hosts() self._subnet_ip_generator = self.cidr.hosts()
@ -2412,7 +2535,7 @@ class Subnet(TaggedEC2Resource):
@property @property
def availability_zone(self): def availability_zone(self):
return self._availability_zone return self._availability_zone.name
@property @property
def physical_resource_id(self): def physical_resource_id(self):
@ -2509,7 +2632,7 @@ class SubnetBackend(object):
return subnets[subnet_id] return subnets[subnet_id]
raise InvalidSubnetIdError(subnet_id) raise InvalidSubnetIdError(subnet_id)
def create_subnet(self, vpc_id, cidr_block, availability_zone): def create_subnet(self, vpc_id, cidr_block, availability_zone, context=None):
subnet_id = random_subnet_id() subnet_id = random_subnet_id()
vpc = self.get_vpc(vpc_id) # Validate VPC exists and the supplied CIDR block is a subnet of the VPC's vpc = self.get_vpc(vpc_id) # Validate VPC exists and the supplied CIDR block is a subnet of the VPC's
vpc_cidr_block = ipaddress.IPv4Network(six.text_type(vpc.cidr_block), strict=False) vpc_cidr_block = ipaddress.IPv4Network(six.text_type(vpc.cidr_block), strict=False)
@ -2529,8 +2652,15 @@ class SubnetBackend(object):
# consider it the default # consider it the default
default_for_az = str(availability_zone not in self.subnets).lower() default_for_az = str(availability_zone not in self.subnets).lower()
map_public_ip_on_launch = default_for_az map_public_ip_on_launch = default_for_az
subnet = Subnet(self, subnet_id, vpc_id, cidr_block, availability_zone, if availability_zone is None:
default_for_az, map_public_ip_on_launch) availability_zone = 'us-east-1a'
try:
availability_zone_data = next(zone for zones in RegionsAndZonesBackend.zones.values() for zone in zones if zone.name == availability_zone)
except StopIteration:
raise InvalidAvailabilityZoneError(availability_zone, ", ".join([zone.name for zones in RegionsAndZonesBackend.zones.values() for zone in zones]))
subnet = Subnet(self, subnet_id, vpc_id, cidr_block, availability_zone_data,
default_for_az, map_public_ip_on_launch,
owner_id=context.get_current_user() if context else OWNER_ID, assign_ipv6_address_on_creation=False)
# AWS associates a new subnet with the default Network ACL # AWS associates a new subnet with the default Network ACL
self.associate_default_network_acl_with_subnet(subnet_id, vpc_id) self.associate_default_network_acl_with_subnet(subnet_id, vpc_id)
@ -2558,11 +2688,12 @@ class SubnetBackend(object):
return subnets.pop(subnet_id, None) return subnets.pop(subnet_id, None)
raise InvalidSubnetIdError(subnet_id) raise InvalidSubnetIdError(subnet_id)
def modify_subnet_attribute(self, subnet_id, map_public_ip): def modify_subnet_attribute(self, subnet_id, attr_name, attr_value):
subnet = self.get_subnet(subnet_id) subnet = self.get_subnet(subnet_id)
if map_public_ip not in ('true', 'false'): if attr_name in ('map_public_ip_on_launch', 'assign_ipv6_address_on_creation'):
raise InvalidParameterValueError(map_public_ip) setattr(subnet, attr_name, attr_value)
subnet.map_public_ip_on_launch = map_public_ip else:
raise InvalidParameterValueError(attr_name)
class SubnetRouteTableAssociation(object): class SubnetRouteTableAssociation(object):
@ -3983,6 +4114,92 @@ class NatGatewayBackend(object):
return self.nat_gateways.pop(nat_gateway_id) return self.nat_gateways.pop(nat_gateway_id)
class LaunchTemplateVersion(object):
def __init__(self, template, number, data, description):
self.template = template
self.number = number
self.data = data
self.description = description
self.create_time = utc_date_and_time()
class LaunchTemplate(TaggedEC2Resource):
def __init__(self, backend, name, template_data, version_description):
self.ec2_backend = backend
self.name = name
self.id = random_launch_template_id()
self.create_time = utc_date_and_time()
self.versions = []
self.create_version(template_data, version_description)
self.default_version_number = 1
def create_version(self, data, description):
num = len(self.versions) + 1
version = LaunchTemplateVersion(self, num, data, description)
self.versions.append(version)
return version
def is_default(self, version):
return self.default_version == version.number
def get_version(self, num):
return self.versions[num - 1]
def default_version(self):
return self.versions[self.default_version_number - 1]
def latest_version(self):
return self.versions[-1]
@property
def latest_version_number(self):
return self.latest_version().number
def get_filter_value(self, filter_name):
if filter_name == 'launch-template-name':
return self.name
else:
return super(LaunchTemplate, self).get_filter_value(
filter_name, "DescribeLaunchTemplates")
class LaunchTemplateBackend(object):
def __init__(self):
self.launch_template_name_to_ids = {}
self.launch_templates = OrderedDict()
self.launch_template_insert_order = []
super(LaunchTemplateBackend, self).__init__()
def create_launch_template(self, name, description, template_data):
if name in self.launch_template_name_to_ids:
raise InvalidLaunchTemplateNameError()
template = LaunchTemplate(self, name, template_data, description)
self.launch_templates[template.id] = template
self.launch_template_name_to_ids[template.name] = template.id
self.launch_template_insert_order.append(template.id)
return template
def get_launch_template(self, template_id):
return self.launch_templates[template_id]
def get_launch_template_by_name(self, name):
return self.get_launch_template(self.launch_template_name_to_ids[name])
def get_launch_templates(self, template_names=None, template_ids=None, filters=None):
if template_names and not template_ids:
template_ids = []
for name in template_names:
template_ids.append(self.launch_template_name_to_ids[name])
if template_ids:
templates = [self.launch_templates[tid] for tid in template_ids]
else:
templates = list(self.launch_templates.values())
return generic_filter(filters, templates)
class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend, class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend,
RegionsAndZonesBackend, SecurityGroupBackend, AmiBackend, RegionsAndZonesBackend, SecurityGroupBackend, AmiBackend,
VPCBackend, SubnetBackend, SubnetRouteTableAssociationBackend, VPCBackend, SubnetBackend, SubnetRouteTableAssociationBackend,
@ -3992,7 +4209,7 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend,
VPCGatewayAttachmentBackend, SpotFleetBackend, VPCGatewayAttachmentBackend, SpotFleetBackend,
SpotRequestBackend, ElasticAddressBackend, KeyPairBackend, SpotRequestBackend, ElasticAddressBackend, KeyPairBackend,
DHCPOptionsSetBackend, NetworkAclBackend, VpnGatewayBackend, DHCPOptionsSetBackend, NetworkAclBackend, VpnGatewayBackend,
CustomerGatewayBackend, NatGatewayBackend): CustomerGatewayBackend, NatGatewayBackend, LaunchTemplateBackend):
def __init__(self, region_name): def __init__(self, region_name):
self.region_name = region_name self.region_name = region_name
super(EC2Backend, self).__init__() super(EC2Backend, self).__init__()
@ -4047,6 +4264,8 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend,
elif resource_prefix == EC2_RESOURCE_TO_PREFIX['internet-gateway']: elif resource_prefix == EC2_RESOURCE_TO_PREFIX['internet-gateway']:
self.describe_internet_gateways( self.describe_internet_gateways(
internet_gateway_ids=[resource_id]) internet_gateway_ids=[resource_id])
elif resource_prefix == EC2_RESOURCE_TO_PREFIX['launch-template']:
self.get_launch_template(resource_id)
elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-acl']: elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-acl']:
self.get_all_network_acls() self.get_all_network_acls()
elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-interface']: elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-interface']:

View File

@ -14,6 +14,7 @@ from .instances import InstanceResponse
from .internet_gateways import InternetGateways from .internet_gateways import InternetGateways
from .ip_addresses import IPAddresses from .ip_addresses import IPAddresses
from .key_pairs import KeyPairs from .key_pairs import KeyPairs
from .launch_templates import LaunchTemplates
from .monitoring import Monitoring from .monitoring import Monitoring
from .network_acls import NetworkACLs from .network_acls import NetworkACLs
from .placement_groups import PlacementGroups from .placement_groups import PlacementGroups
@ -49,6 +50,7 @@ class EC2Response(
InternetGateways, InternetGateways,
IPAddresses, IPAddresses,
KeyPairs, KeyPairs,
LaunchTemplates,
Monitoring, Monitoring,
NetworkACLs, NetworkACLs,
PlacementGroups, PlacementGroups,

View File

@ -10,9 +10,10 @@ class ElasticNetworkInterfaces(BaseResponse):
private_ip_address = self._get_param('PrivateIpAddress') private_ip_address = self._get_param('PrivateIpAddress')
groups = self._get_multi_param('SecurityGroupId') groups = self._get_multi_param('SecurityGroupId')
subnet = self.ec2_backend.get_subnet(subnet_id) subnet = self.ec2_backend.get_subnet(subnet_id)
description = self._get_param('Description')
if self.is_not_dryrun('CreateNetworkInterface'): if self.is_not_dryrun('CreateNetworkInterface'):
eni = self.ec2_backend.create_network_interface( eni = self.ec2_backend.create_network_interface(
subnet, private_ip_address, groups) subnet, private_ip_address, groups, description)
template = self.response_template( template = self.response_template(
CREATE_NETWORK_INTERFACE_RESPONSE) CREATE_NETWORK_INTERFACE_RESPONSE)
return template.render(eni=eni) return template.render(eni=eni)
@ -78,7 +79,11 @@ CREATE_NETWORK_INTERFACE_RESPONSE = """
<subnetId>{{ eni.subnet.id }}</subnetId> <subnetId>{{ eni.subnet.id }}</subnetId>
<vpcId>{{ eni.subnet.vpc_id }}</vpcId> <vpcId>{{ eni.subnet.vpc_id }}</vpcId>
<availabilityZone>us-west-2a</availabilityZone> <availabilityZone>us-west-2a</availabilityZone>
{% if eni.description %}
<description>{{ eni.description }}</description>
{% else %}
<description/> <description/>
{% endif %}
<ownerId>498654062920</ownerId> <ownerId>498654062920</ownerId>
<requesterManaged>false</requesterManaged> <requesterManaged>false</requesterManaged>
<status>pending</status> <status>pending</status>
@ -121,7 +126,7 @@ DESCRIBE_NETWORK_INTERFACES_RESPONSE = """<DescribeNetworkInterfacesResponse xml
<subnetId>{{ eni.subnet.id }}</subnetId> <subnetId>{{ eni.subnet.id }}</subnetId>
<vpcId>{{ eni.subnet.vpc_id }}</vpcId> <vpcId>{{ eni.subnet.vpc_id }}</vpcId>
<availabilityZone>us-west-2a</availabilityZone> <availabilityZone>us-west-2a</availabilityZone>
<description>Primary network interface</description> <description>{{ eni.description }}</description>
<ownerId>190610284047</ownerId> <ownerId>190610284047</ownerId>
<requesterManaged>false</requesterManaged> <requesterManaged>false</requesterManaged>
{% if eni.attachment_id %} {% if eni.attachment_id %}

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from boto.ec2.instancetype import InstanceType from boto.ec2.instancetype import InstanceType
from moto.autoscaling import autoscaling_backends
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring, \ from moto.ec2.utils import filters_from_querystring, \
@ -46,6 +48,7 @@ class InstanceResponse(BaseResponse):
associate_public_ip = self._get_param('AssociatePublicIpAddress') associate_public_ip = self._get_param('AssociatePublicIpAddress')
key_name = self._get_param('KeyName') key_name = self._get_param('KeyName')
ebs_optimized = self._get_param('EbsOptimized') ebs_optimized = self._get_param('EbsOptimized')
instance_initiated_shutdown_behavior = self._get_param("InstanceInitiatedShutdownBehavior")
tags = self._parse_tag_specification("TagSpecification") tags = self._parse_tag_specification("TagSpecification")
region_name = self.region region_name = self.region
@ -55,7 +58,7 @@ class InstanceResponse(BaseResponse):
instance_type=instance_type, placement=placement, region_name=region_name, subnet_id=subnet_id, instance_type=instance_type, placement=placement, region_name=region_name, subnet_id=subnet_id,
owner_id=owner_id, key_name=key_name, security_group_ids=security_group_ids, owner_id=owner_id, key_name=key_name, security_group_ids=security_group_ids,
nics=nics, private_ip=private_ip, associate_public_ip=associate_public_ip, nics=nics, private_ip=private_ip, associate_public_ip=associate_public_ip,
tags=tags, ebs_optimized=ebs_optimized) tags=tags, ebs_optimized=ebs_optimized, instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior)
template = self.response_template(EC2_RUN_INSTANCES) template = self.response_template(EC2_RUN_INSTANCES)
return template.render(reservation=new_reservation) return template.render(reservation=new_reservation)
@ -64,6 +67,7 @@ class InstanceResponse(BaseResponse):
instance_ids = self._get_multi_param('InstanceId') instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('TerminateInstance'): if self.is_not_dryrun('TerminateInstance'):
instances = self.ec2_backend.terminate_instances(instance_ids) instances = self.ec2_backend.terminate_instances(instance_ids)
autoscaling_backends[self.region].notify_terminate_instances(instance_ids)
template = self.response_template(EC2_TERMINATE_INSTANCES) template = self.response_template(EC2_TERMINATE_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
@ -113,12 +117,11 @@ class InstanceResponse(BaseResponse):
# TODO this and modify below should raise IncorrectInstanceState if # TODO this and modify below should raise IncorrectInstanceState if
# instance not in stopped state # instance not in stopped state
attribute = self._get_param('Attribute') attribute = self._get_param('Attribute')
key = camelcase_to_underscores(attribute)
instance_id = self._get_param('InstanceId') instance_id = self._get_param('InstanceId')
instance, value = self.ec2_backend.describe_instance_attribute( instance, value = self.ec2_backend.describe_instance_attribute(
instance_id, key) instance_id, attribute)
if key == "group_set": if attribute == "groupSet":
template = self.response_template( template = self.response_template(
EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE) EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE)
else: else:
@ -597,7 +600,9 @@ EC2_DESCRIBE_INSTANCE_ATTRIBUTE = """<DescribeInstanceAttributeResponse xmlns="h
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId> <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<instanceId>{{ instance.id }}</instanceId> <instanceId>{{ instance.id }}</instanceId>
<{{ attribute }}> <{{ attribute }}>
{% if value is not none %}
<value>{{ value }}</value> <value>{{ value }}</value>
{% endif %}
</{{ attribute }}> </{{ attribute }}>
</DescribeInstanceAttributeResponse>""" </DescribeInstanceAttributeResponse>"""
@ -605,9 +610,9 @@ EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE = """<DescribeInstanceAttributeResponse
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId> <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<instanceId>{{ instance.id }}</instanceId> <instanceId>{{ instance.id }}</instanceId>
<{{ attribute }}> <{{ attribute }}>
{% for sg_id in value %} {% for sg in value %}
<item> <item>
<groupId>{{ sg_id }}</groupId> <groupId>{{ sg.id }}</groupId>
</item> </item>
{% endfor %} {% endfor %}
</{{ attribute }}> </{{ attribute }}>

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import random import random
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring from moto.ec2.utils import filters_from_querystring
@ -16,6 +17,7 @@ class Subnets(BaseResponse):
vpc_id, vpc_id,
cidr_block, cidr_block,
availability_zone, availability_zone,
context=self,
) )
template = self.response_template(CREATE_SUBNET_RESPONSE) template = self.response_template(CREATE_SUBNET_RESPONSE)
return template.render(subnet=subnet) return template.render(subnet=subnet)
@ -35,9 +37,14 @@ class Subnets(BaseResponse):
def modify_subnet_attribute(self): def modify_subnet_attribute(self):
subnet_id = self._get_param('SubnetId') subnet_id = self._get_param('SubnetId')
map_public_ip = self._get_param('MapPublicIpOnLaunch.Value')
self.ec2_backend.modify_subnet_attribute(subnet_id, map_public_ip) for attribute in ('MapPublicIpOnLaunch', 'AssignIpv6AddressOnCreation'):
return MODIFY_SUBNET_ATTRIBUTE_RESPONSE if self.querystring.get('%s.Value' % attribute):
attr_name = camelcase_to_underscores(attribute)
attr_value = self.querystring.get('%s.Value' % attribute)[0]
self.ec2_backend.modify_subnet_attribute(
subnet_id, attr_name, attr_value)
return MODIFY_SUBNET_ATTRIBUTE_RESPONSE
CREATE_SUBNET_RESPONSE = """ CREATE_SUBNET_RESPONSE = """
@ -49,17 +56,14 @@ CREATE_SUBNET_RESPONSE = """
<vpcId>{{ subnet.vpc_id }}</vpcId> <vpcId>{{ subnet.vpc_id }}</vpcId>
<cidrBlock>{{ subnet.cidr_block }}</cidrBlock> <cidrBlock>{{ subnet.cidr_block }}</cidrBlock>
<availableIpAddressCount>251</availableIpAddressCount> <availableIpAddressCount>251</availableIpAddressCount>
<availabilityZone>{{ subnet.availability_zone }}</availabilityZone> <availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone>
<tagSet> <availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId>
{% for tag in subnet.get_tags() %} <defaultForAz>{{ subnet.default_for_az }}</defaultForAz>
<item> <mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch>
<resourceId>{{ tag.resource_id }}</resourceId> <ownerId>{{ subnet.owner_id }}</ownerId>
<resourceType>{{ tag.resource_type }}</resourceType> <assignIpv6AddressOnCreation>{{ subnet.assign_ipv6_address_on_creation }}</assignIpv6AddressOnCreation>
<key>{{ tag.key }}</key> <ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet>
<value>{{ tag.value }}</value> <subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn>
</item>
{% endfor %}
</tagSet>
</subnet> </subnet>
</CreateSubnetResponse>""" </CreateSubnetResponse>"""
@ -80,19 +84,26 @@ DESCRIBE_SUBNETS_RESPONSE = """
<vpcId>{{ subnet.vpc_id }}</vpcId> <vpcId>{{ subnet.vpc_id }}</vpcId>
<cidrBlock>{{ subnet.cidr_block }}</cidrBlock> <cidrBlock>{{ subnet.cidr_block }}</cidrBlock>
<availableIpAddressCount>251</availableIpAddressCount> <availableIpAddressCount>251</availableIpAddressCount>
<availabilityZone>{{ subnet.availability_zone }}</availabilityZone> <availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone>
<availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId>
<defaultForAz>{{ subnet.default_for_az }}</defaultForAz> <defaultForAz>{{ subnet.default_for_az }}</defaultForAz>
<mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch> <mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch>
<tagSet> <ownerId>{{ subnet.owner_id }}</ownerId>
{% for tag in subnet.get_tags() %} <assignIpv6AddressOnCreation>{{ subnet.assign_ipv6_address_on_creation }}</assignIpv6AddressOnCreation>
<item> <ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet>
<resourceId>{{ tag.resource_id }}</resourceId> <subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn>
<resourceType>{{ tag.resource_type }}</resourceType> {% if subnet.get_tags() %}
<key>{{ tag.key }}</key> <tagSet>
<value>{{ tag.value }}</value> {% for tag in subnet.get_tags() %}
</item> <item>
{% endfor %} <resourceId>{{ tag.resource_id }}</resourceId>
</tagSet> <resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
{% endif %}
</item> </item>
{% endfor %} {% endfor %}
</subnetSet> </subnetSet>

View File

@ -20,6 +20,7 @@ EC2_RESOURCE_TO_PREFIX = {
'image': 'ami', 'image': 'ami',
'instance': 'i', 'instance': 'i',
'internet-gateway': 'igw', 'internet-gateway': 'igw',
'launch-template': 'lt',
'nat-gateway': 'nat', 'nat-gateway': 'nat',
'network-acl': 'acl', 'network-acl': 'acl',
'network-acl-subnet-assoc': 'aclassoc', 'network-acl-subnet-assoc': 'aclassoc',
@ -161,6 +162,10 @@ def random_nat_gateway_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX['nat-gateway'], size=17) return random_id(prefix=EC2_RESOURCE_TO_PREFIX['nat-gateway'], size=17)
def random_launch_template_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX['launch-template'], size=17)
def random_public_ip(): def random_public_ip():
return '54.214.{0}.{1}'.format(random.choice(range(255)), return '54.214.{0}.{1}'.format(random.choice(range(255)),
random.choice(range(255))) random.choice(range(255)))

View File

@ -1,7 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import hashlib import hashlib
import re
from copy import copy from copy import copy
from datetime import datetime
from random import random from random import random
from botocore.exceptions import ParamValidationError from botocore.exceptions import ParamValidationError
@ -105,7 +107,7 @@ class Image(BaseObject):
self.repository = repository self.repository = repository
self.registry_id = registry_id self.registry_id = registry_id
self.image_digest = digest self.image_digest = digest
self.image_pushed_at = None self.image_pushed_at = str(datetime.utcnow().isoformat())
def _create_digest(self): def _create_digest(self):
image_contents = 'docker_image{0}'.format(int(random() * 10 ** 6)) image_contents = 'docker_image{0}'.format(int(random() * 10 ** 6))
@ -119,6 +121,12 @@ class Image(BaseObject):
def get_image_manifest(self): def get_image_manifest(self):
return self.image_manifest return self.image_manifest
def remove_tag(self, tag):
if tag is not None and tag in self.image_tags:
self.image_tags.remove(tag)
if self.image_tags:
self.image_tag = self.image_tags[-1]
def update_tag(self, tag): def update_tag(self, tag):
self.image_tag = tag self.image_tag = tag
if tag not in self.image_tags and tag is not None: if tag not in self.image_tags and tag is not None:
@ -151,7 +159,7 @@ class Image(BaseObject):
response_object['repositoryName'] = self.repository response_object['repositoryName'] = self.repository
response_object['registryId'] = self.registry_id response_object['registryId'] = self.registry_id
response_object['imageSizeInBytes'] = self.image_size_in_bytes response_object['imageSizeInBytes'] = self.image_size_in_bytes
response_object['imagePushedAt'] = '2017-05-09' response_object['imagePushedAt'] = self.image_pushed_at
return {k: v for k, v in response_object.items() if v is not None and v != []} return {k: v for k, v in response_object.items() if v is not None and v != []}
@property @property
@ -165,6 +173,13 @@ class Image(BaseObject):
response_object['registryId'] = self.registry_id response_object['registryId'] = self.registry_id
return {k: v for k, v in response_object.items() if v is not None and v != [None]} return {k: v for k, v in response_object.items() if v is not None and v != [None]}
@property
def response_batch_delete_image(self):
response_object = {}
response_object['imageDigest'] = self.get_image_digest()
response_object['imageTag'] = self.image_tag
return {k: v for k, v in response_object.items() if v is not None and v != [None]}
class ECRBackend(BaseBackend): class ECRBackend(BaseBackend):
@ -310,6 +325,106 @@ class ECRBackend(BaseBackend):
return response return response
def batch_delete_image(self, repository_name, registry_id=None, image_ids=None):
if repository_name in self.repositories:
repository = self.repositories[repository_name]
else:
raise RepositoryNotFoundException(
repository_name, registry_id or DEFAULT_REGISTRY_ID
)
if not image_ids:
raise ParamValidationError(
msg='Missing required parameter in input: "imageIds"'
)
response = {
"imageIds": [],
"failures": []
}
for image_id in image_ids:
image_found = False
# Is request missing both digest and tag?
if "imageDigest" not in image_id and "imageTag" not in image_id:
response["failures"].append(
{
"imageId": {},
"failureCode": "MissingDigestAndTag",
"failureReason": "Invalid request parameters: both tag and digest cannot be null",
}
)
continue
# If we have a digest, is it valid?
if "imageDigest" in image_id:
pattern = re.compile("^[0-9a-zA-Z_+\.-]+:[0-9a-fA-F]{64}")
if not pattern.match(image_id.get("imageDigest")):
response["failures"].append(
{
"imageId": {
"imageDigest": image_id.get("imageDigest", "null")
},
"failureCode": "InvalidImageDigest",
"failureReason": "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'",
}
)
continue
for num, image in enumerate(repository.images):
# Search by matching both digest and tag
if "imageDigest" in image_id and "imageTag" in image_id:
if (
image_id["imageDigest"] == image.get_image_digest() and
image_id["imageTag"] in image.image_tags
):
image_found = True
for image_tag in reversed(image.image_tags):
repository.images[num].image_tag = image_tag
response["imageIds"].append(
image.response_batch_delete_image
)
repository.images[num].remove_tag(image_tag)
del repository.images[num]
# Search by matching digest
elif "imageDigest" in image_id and image.get_image_digest() == image_id["imageDigest"]:
image_found = True
for image_tag in reversed(image.image_tags):
repository.images[num].image_tag = image_tag
response["imageIds"].append(image.response_batch_delete_image)
repository.images[num].remove_tag(image_tag)
del repository.images[num]
# Search by matching tag
elif "imageTag" in image_id and image_id["imageTag"] in image.image_tags:
image_found = True
repository.images[num].image_tag = image_id["imageTag"]
response["imageIds"].append(image.response_batch_delete_image)
if len(image.image_tags) > 1:
repository.images[num].remove_tag(image_id["imageTag"])
else:
repository.images.remove(image)
if not image_found:
failure_response = {
"imageId": {},
"failureCode": "ImageNotFound",
"failureReason": "Requested image not found",
}
if "imageDigest" in image_id:
failure_response["imageId"]["imageDigest"] = image_id.get("imageDigest", "null")
if "imageTag" in image_id:
failure_response["imageId"]["imageTag"] = image_id.get("imageTag", "null")
response["failures"].append(failure_response)
return response
ecr_backends = {} ecr_backends = {}
for region, ec2_backend in ec2_backends.items(): for region, ec2_backend in ec2_backends.items():

View File

@ -84,9 +84,12 @@ class ECRResponse(BaseResponse):
'ECR.batch_check_layer_availability is not yet implemented') 'ECR.batch_check_layer_availability is not yet implemented')
def batch_delete_image(self): def batch_delete_image(self):
if self.is_not_dryrun('BatchDeleteImage'): repository_str = self._get_param('repositoryName')
raise NotImplementedError( registry_id = self._get_param('registryId')
'ECR.batch_delete_image is not yet implemented') image_ids = self._get_param('imageIds')
response = self.ecr_backend.batch_delete_image(repository_str, registry_id, image_ids)
return json.dumps(response)
def batch_get_image(self): def batch_get_image(self):
repository_str = self._get_param('repositoryName') repository_str = self._get_param('repositoryName')

View File

@ -1,5 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError, JsonRESTError
class ServiceNotFoundException(RESTError): class ServiceNotFoundException(RESTError):
@ -11,3 +11,13 @@ class ServiceNotFoundException(RESTError):
message="The service {0} does not exist".format(service_name), message="The service {0} does not exist".format(service_name),
template='error_json', template='error_json',
) )
class TaskDefinitionNotFoundException(JsonRESTError):
code = 400
def __init__(self):
super(TaskDefinitionNotFoundException, self).__init__(
error_type="ClientException",
message="The specified task definition does not exist.",
)

View File

@ -1,4 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import re
import uuid import uuid
from datetime import datetime from datetime import datetime
from random import random, randint from random import random, randint
@ -7,10 +8,14 @@ import boto3
import pytz import pytz
from moto.core.exceptions import JsonRESTError from moto.core.exceptions import JsonRESTError
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time
from moto.ec2 import ec2_backends from moto.ec2 import ec2_backends
from copy import copy from copy import copy
from .exceptions import ServiceNotFoundException from .exceptions import (
ServiceNotFoundException,
TaskDefinitionNotFoundException
)
class BaseObject(BaseModel): class BaseObject(BaseModel):
@ -103,12 +108,13 @@ class Cluster(BaseObject):
class TaskDefinition(BaseObject): class TaskDefinition(BaseObject):
def __init__(self, family, revision, container_definitions, volumes=None): def __init__(self, family, revision, container_definitions, volumes=None, tags=None):
self.family = family self.family = family
self.revision = revision self.revision = revision
self.arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/{0}:{1}'.format( self.arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/{0}:{1}'.format(
family, revision) family, revision)
self.container_definitions = container_definitions self.container_definitions = container_definitions
self.tags = tags if tags is not None else []
if volumes is None: if volumes is None:
self.volumes = [] self.volumes = []
else: else:
@ -119,6 +125,7 @@ class TaskDefinition(BaseObject):
response_object = self.gen_response_object() response_object = self.gen_response_object()
response_object['taskDefinitionArn'] = response_object['arn'] response_object['taskDefinitionArn'] = response_object['arn']
del response_object['arn'] del response_object['arn']
del response_object['tags']
return response_object return response_object
@property @property
@ -225,9 +232,9 @@ class Service(BaseObject):
for deployment in response_object['deployments']: for deployment in response_object['deployments']:
if isinstance(deployment['createdAt'], datetime): if isinstance(deployment['createdAt'], datetime):
deployment['createdAt'] = deployment['createdAt'].isoformat() deployment['createdAt'] = unix_time(deployment['createdAt'].replace(tzinfo=None))
if isinstance(deployment['updatedAt'], datetime): if isinstance(deployment['updatedAt'], datetime):
deployment['updatedAt'] = deployment['updatedAt'].isoformat() deployment['updatedAt'] = unix_time(deployment['updatedAt'].replace(tzinfo=None))
return response_object return response_object
@ -422,11 +429,9 @@ class EC2ContainerServiceBackend(BaseBackend):
revision = int(revision) revision = int(revision)
else: else:
family = task_definition_name family = task_definition_name
revision = len(self.task_definitions.get(family, [])) revision = self._get_last_task_definition_revision_id(family)
if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]): if family in self.task_definitions and revision in self.task_definitions[family]:
return self.task_definitions[family][revision - 1]
elif family in self.task_definitions and revision == -1:
return self.task_definitions[family][revision] return self.task_definitions[family][revision]
else: else:
raise Exception( raise Exception(
@ -466,15 +471,16 @@ class EC2ContainerServiceBackend(BaseBackend):
else: else:
raise Exception("{0} is not a cluster".format(cluster_name)) raise Exception("{0} is not a cluster".format(cluster_name))
def register_task_definition(self, family, container_definitions, volumes): def register_task_definition(self, family, container_definitions, volumes, tags=None):
if family in self.task_definitions: if family in self.task_definitions:
revision = len(self.task_definitions[family]) + 1 last_id = self._get_last_task_definition_revision_id(family)
revision = (last_id or 0) + 1
else: else:
self.task_definitions[family] = [] self.task_definitions[family] = {}
revision = 1 revision = 1
task_definition = TaskDefinition( task_definition = TaskDefinition(
family, revision, container_definitions, volumes) family, revision, container_definitions, volumes, tags)
self.task_definitions[family].append(task_definition) self.task_definitions[family][revision] = task_definition
return task_definition return task_definition
@ -484,16 +490,18 @@ class EC2ContainerServiceBackend(BaseBackend):
""" """
task_arns = [] task_arns = []
for task_definition_list in self.task_definitions.values(): for task_definition_list in self.task_definitions.values():
task_arns.extend( task_arns.extend([
[task_definition.arn for task_definition in task_definition_list]) task_definition.arn
for task_definition in task_definition_list.values()
])
return task_arns return task_arns
def deregister_task_definition(self, task_definition_str): def deregister_task_definition(self, task_definition_str):
task_definition_name = task_definition_str.split('/')[-1] task_definition_name = task_definition_str.split('/')[-1]
family, revision = task_definition_name.split(':') family, revision = task_definition_name.split(':')
revision = int(revision) revision = int(revision)
if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]): if family in self.task_definitions and revision in self.task_definitions[family]:
return self.task_definitions[family].pop(revision - 1) return self.task_definitions[family].pop(revision)
else: else:
raise Exception( raise Exception(
"{0} is not a task_definition".format(task_definition_name)) "{0} is not a task_definition".format(task_definition_name))
@ -950,6 +958,29 @@ class EC2ContainerServiceBackend(BaseBackend):
yield task_fam yield task_fam
def list_tags_for_resource(self, resource_arn):
"""Currently only implemented for task definitions"""
match = re.match(
"^arn:aws:ecs:(?P<region>[^:]+):(?P<account_id>[^:]+):(?P<service>[^:]+)/(?P<id>.*)$",
resource_arn)
if not match:
raise JsonRESTError('InvalidParameterException', 'The ARN provided is invalid.')
service = match.group("service")
if service == "task-definition":
for task_definition in self.task_definitions.values():
for revision in task_definition.values():
if revision.arn == resource_arn:
return revision.tags
else:
raise TaskDefinitionNotFoundException()
raise NotImplementedError()
def _get_last_task_definition_revision_id(self, family):
definitions = self.task_definitions.get(family, {})
if definitions:
return max(definitions.keys())
available_regions = boto3.session.Session().get_available_regions("ecs") available_regions = boto3.session.Session().get_available_regions("ecs")
ecs_backends = {region: EC2ContainerServiceBackend(region) for region in available_regions} ecs_backends = {region: EC2ContainerServiceBackend(region) for region in available_regions}

View File

@ -62,8 +62,9 @@ class EC2ContainerServiceResponse(BaseResponse):
family = self._get_param('family') family = self._get_param('family')
container_definitions = self._get_param('containerDefinitions') container_definitions = self._get_param('containerDefinitions')
volumes = self._get_param('volumes') volumes = self._get_param('volumes')
tags = self._get_param('tags')
task_definition = self.ecs_backend.register_task_definition( task_definition = self.ecs_backend.register_task_definition(
family, container_definitions, volumes) family, container_definitions, volumes, tags)
return json.dumps({ return json.dumps({
'taskDefinition': task_definition.response_object 'taskDefinition': task_definition.response_object
}) })
@ -313,3 +314,8 @@ class EC2ContainerServiceResponse(BaseResponse):
results = self.ecs_backend.list_task_definition_families(family_prefix, status, max_results, next_token) results = self.ecs_backend.list_task_definition_families(family_prefix, status, max_results, next_token)
return json.dumps({'families': list(results)}) return json.dumps({'families': list(results)})
def list_tags_for_resource(self):
resource_arn = self._get_param('resourceArn')
tags = self.ecs_backend.list_tags_for_resource(resource_arn)
return json.dumps({'tags': tags})

View File

@ -2,9 +2,11 @@ from __future__ import unicode_literals
import datetime import datetime
import re import re
from jinja2 import Template
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import camelcase_to_underscores
from moto.ec2.models import ec2_backends from moto.ec2.models import ec2_backends
from moto.acm.models import acm_backends from moto.acm.models import acm_backends
from .utils import make_arn_for_target_group from .utils import make_arn_for_target_group
@ -35,12 +37,13 @@ from .exceptions import (
class FakeHealthStatus(BaseModel): class FakeHealthStatus(BaseModel):
def __init__(self, instance_id, port, health_port, status, reason=None): def __init__(self, instance_id, port, health_port, status, reason=None, description=None):
self.instance_id = instance_id self.instance_id = instance_id
self.port = port self.port = port
self.health_port = health_port self.health_port = health_port
self.status = status self.status = status
self.reason = reason self.reason = reason
self.description = description
class FakeTargetGroup(BaseModel): class FakeTargetGroup(BaseModel):
@ -69,7 +72,7 @@ class FakeTargetGroup(BaseModel):
self.protocol = protocol self.protocol = protocol
self.port = port self.port = port
self.healthcheck_protocol = healthcheck_protocol or 'HTTP' self.healthcheck_protocol = healthcheck_protocol or 'HTTP'
self.healthcheck_port = healthcheck_port or 'traffic-port' self.healthcheck_port = healthcheck_port or str(self.port)
self.healthcheck_path = healthcheck_path or '/' self.healthcheck_path = healthcheck_path or '/'
self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30 self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30
self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5 self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5
@ -112,10 +115,14 @@ class FakeTargetGroup(BaseModel):
raise TooManyTagsError() raise TooManyTagsError()
self.tags[key] = value self.tags[key] = value
def health_for(self, target): def health_for(self, target, ec2_backend):
t = self.targets.get(target['id']) t = self.targets.get(target['id'])
if t is None: if t is None:
raise InvalidTargetError() raise InvalidTargetError()
if t['id'].startswith("i-"): # EC2 instance ID
instance = ec2_backend.get_instance_by_id(t['id'])
if instance.state == "stopped":
return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'unused', 'Target.InvalidState', 'Target is in the stopped state')
return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'healthy') return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'healthy')
@classmethod @classmethod
@ -208,13 +215,12 @@ class FakeListener(BaseModel):
action_type = action['Type'] action_type = action['Type']
if action_type == 'forward': if action_type == 'forward':
default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']}) default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']})
elif action_type == 'redirect': elif action_type in ['redirect', 'authenticate-cognito']:
redirect_action = {'type': action_type, } redirect_action = {'type': action_type}
for redirect_config_key, redirect_config_value in action['RedirectConfig'].items(): key = 'RedirectConfig' if action_type == 'redirect' else 'AuthenticateCognitoConfig'
for redirect_config_key, redirect_config_value in action[key].items():
# need to match the output of _get_list_prefix # need to match the output of _get_list_prefix
if redirect_config_key == 'StatusCode': redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value
redirect_config_key = 'status_code'
redirect_action['redirect_config._' + redirect_config_key.lower()] = redirect_config_value
default_actions.append(redirect_action) default_actions.append(redirect_action)
else: else:
raise InvalidActionTypeError(action_type, i + 1) raise InvalidActionTypeError(action_type, i + 1)
@ -226,6 +232,32 @@ class FakeListener(BaseModel):
return listener return listener
class FakeAction(BaseModel):
def __init__(self, data):
self.data = data
self.type = data.get("type")
def to_xml(self):
template = Template("""<Type>{{ action.type }}</Type>
{% if action.type == "forward" %}
<TargetGroupArn>{{ action.data["target_group_arn"] }}</TargetGroupArn>
{% elif action.type == "redirect" %}
<RedirectConfig>
<Protocol>{{ action.data["redirect_config._protocol"] }}</Protocol>
<Port>{{ action.data["redirect_config._port"] }}</Port>
<StatusCode>{{ action.data["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% elif action.type == "authenticate-cognito" %}
<AuthenticateCognitoConfig>
<UserPoolArn>{{ action.data["authenticate_cognito_config._user_pool_arn"] }}</UserPoolArn>
<UserPoolClientId>{{ action.data["authenticate_cognito_config._user_pool_client_id"] }}</UserPoolClientId>
<UserPoolDomain>{{ action.data["authenticate_cognito_config._user_pool_domain"] }}</UserPoolDomain>
</AuthenticateCognitoConfig>
{% endif %}
""")
return template.render(action=self)
class FakeRule(BaseModel): class FakeRule(BaseModel):
def __init__(self, listener_arn, conditions, priority, actions, is_default): def __init__(self, listener_arn, conditions, priority, actions, is_default):
@ -397,6 +429,7 @@ class ELBv2Backend(BaseBackend):
return new_load_balancer return new_load_balancer
def create_rule(self, listener_arn, conditions, priority, actions): def create_rule(self, listener_arn, conditions, priority, actions):
actions = [FakeAction(action) for action in actions]
listeners = self.describe_listeners(None, [listener_arn]) listeners = self.describe_listeners(None, [listener_arn])
if not listeners: if not listeners:
raise ListenerNotFoundError() raise ListenerNotFoundError()
@ -424,20 +457,7 @@ class ELBv2Backend(BaseBackend):
if rule.priority == priority: if rule.priority == priority:
raise PriorityInUseError() raise PriorityInUseError()
# validate Actions self._validate_actions(actions)
target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
for i, action in enumerate(actions):
index = i + 1
action_type = action['type']
if action_type == 'forward':
action_target_group_arn = action['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
elif action_type == 'redirect':
# nothing to do
pass
else:
raise InvalidActionTypeError(action_type, index)
# TODO: check for error 'TooManyRegistrationsForTargetId' # TODO: check for error 'TooManyRegistrationsForTargetId'
# TODO: check for error 'TooManyRules' # TODO: check for error 'TooManyRules'
@ -447,6 +467,21 @@ class ELBv2Backend(BaseBackend):
listener.register(rule) listener.register(rule)
return [rule] return [rule]
def _validate_actions(self, actions):
# validate Actions
target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
for i, action in enumerate(actions):
index = i + 1
action_type = action.type
if action_type == 'forward':
action_target_group_arn = action.data['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
elif action_type in ['redirect', 'authenticate-cognito']:
pass
else:
raise InvalidActionTypeError(action_type, index)
def create_target_group(self, name, **kwargs): def create_target_group(self, name, **kwargs):
if len(name) > 32: if len(name) > 32:
raise InvalidTargetGroupNameError( raise InvalidTargetGroupNameError(
@ -490,26 +525,22 @@ class ELBv2Backend(BaseBackend):
return target_group return target_group
def create_listener(self, load_balancer_arn, protocol, port, ssl_policy, certificate, default_actions): def create_listener(self, load_balancer_arn, protocol, port, ssl_policy, certificate, default_actions):
default_actions = [FakeAction(action) for action in default_actions]
balancer = self.load_balancers.get(load_balancer_arn) balancer = self.load_balancers.get(load_balancer_arn)
if balancer is None: if balancer is None:
raise LoadBalancerNotFoundError() raise LoadBalancerNotFoundError()
if port in balancer.listeners: if port in balancer.listeners:
raise DuplicateListenerError() raise DuplicateListenerError()
self._validate_actions(default_actions)
arn = load_balancer_arn.replace(':loadbalancer/', ':listener/') + "/%s%s" % (port, id(self)) arn = load_balancer_arn.replace(':loadbalancer/', ':listener/') + "/%s%s" % (port, id(self))
listener = FakeListener(load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions) listener = FakeListener(load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions)
balancer.listeners[listener.arn] = listener balancer.listeners[listener.arn] = listener
for i, action in enumerate(default_actions): for action in default_actions:
action_type = action['type'] if action.type == 'forward':
if action_type == 'forward': target_group = self.target_groups[action.data['target_group_arn']]
if action['target_group_arn'] in self.target_groups.keys(): target_group.load_balancer_arns.append(load_balancer_arn)
target_group = self.target_groups[action['target_group_arn']]
target_group.load_balancer_arns.append(load_balancer_arn)
elif action_type == 'redirect':
# nothing to do
pass
else:
raise InvalidActionTypeError(action_type, i + 1)
return listener return listener
@ -643,6 +674,7 @@ class ELBv2Backend(BaseBackend):
raise ListenerNotFoundError() raise ListenerNotFoundError()
def modify_rule(self, rule_arn, conditions, actions): def modify_rule(self, rule_arn, conditions, actions):
actions = [FakeAction(action) for action in actions]
# if conditions or actions is empty list, do not update the attributes # if conditions or actions is empty list, do not update the attributes
if not conditions and not actions: if not conditions and not actions:
raise InvalidModifyRuleArgumentsError() raise InvalidModifyRuleArgumentsError()
@ -668,20 +700,7 @@ class ELBv2Backend(BaseBackend):
# TODO: check pattern of value for 'path-pattern' # TODO: check pattern of value for 'path-pattern'
# validate Actions # validate Actions
target_group_arns = [target_group.arn for target_group in self.target_groups.values()] self._validate_actions(actions)
if actions:
for i, action in enumerate(actions):
index = i + 1
action_type = action['type']
if action_type == 'forward':
action_target_group_arn = action['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
elif action_type == 'redirect':
# nothing to do
pass
else:
raise InvalidActionTypeError(action_type, index)
# TODO: check for error 'TooManyRegistrationsForTargetId' # TODO: check for error 'TooManyRegistrationsForTargetId'
# TODO: check for error 'TooManyRules' # TODO: check for error 'TooManyRules'
@ -712,7 +731,7 @@ class ELBv2Backend(BaseBackend):
if not targets: if not targets:
targets = target_group.targets.values() targets = target_group.targets.values()
return [target_group.health_for(target) for target in targets] return [target_group.health_for(target, self.ec2_backend) for target in targets]
def set_rule_priorities(self, rule_priorities): def set_rule_priorities(self, rule_priorities):
# validate # validate
@ -846,6 +865,7 @@ class ELBv2Backend(BaseBackend):
return target_group return target_group
def modify_listener(self, arn, port=None, protocol=None, ssl_policy=None, certificates=None, default_actions=None): def modify_listener(self, arn, port=None, protocol=None, ssl_policy=None, certificates=None, default_actions=None):
default_actions = [FakeAction(action) for action in default_actions]
for load_balancer in self.load_balancers.values(): for load_balancer in self.load_balancers.values():
if arn in load_balancer.listeners: if arn in load_balancer.listeners:
break break
@ -912,7 +932,7 @@ class ELBv2Backend(BaseBackend):
for listener in load_balancer.listeners.values(): for listener in load_balancer.listeners.values():
for rule in listener.rules: for rule in listener.rules:
for action in rule.actions: for action in rule.actions:
if action.get('target_group_arn') == target_group_arn: if action.data.get('target_group_arn') == target_group_arn:
return True return True
return False return False

View File

@ -775,16 +775,7 @@ CREATE_LISTENER_TEMPLATE = """<CreateListenerResponse xmlns="http://elasticloadb
<DefaultActions> <DefaultActions>
{% for action in listener.default_actions %} {% for action in listener.default_actions %}
<member> <member>
<Type>{{ action.type }}</Type> {{ action.to_xml() }}
{% if action["type"] == "forward" %}
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
{% elif action["type"] == "redirect" %}
<RedirectConfig>
<Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
<Port>{{ action["redirect_config._port"] }}</Port>
<StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% endif %}
</member> </member>
{% endfor %} {% endfor %}
</DefaultActions> </DefaultActions>
@ -888,16 +879,7 @@ DESCRIBE_RULES_TEMPLATE = """<DescribeRulesResponse xmlns="http://elasticloadbal
<Actions> <Actions>
{% for action in rule.actions %} {% for action in rule.actions %}
<member> <member>
<Type>{{ action["type"] }}</Type> {{ action.to_xml() }}
{% if action["type"] == "forward" %}
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
{% elif action["type"] == "redirect" %}
<RedirectConfig>
<Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
<Port>{{ action["redirect_config._port"] }}</Port>
<StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% endif %}
</member> </member>
{% endfor %} {% endfor %}
</Actions> </Actions>
@ -989,16 +971,7 @@ DESCRIBE_LISTENERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http://el
<DefaultActions> <DefaultActions>
{% for action in listener.default_actions %} {% for action in listener.default_actions %}
<member> <member>
<Type>{{ action.type }}</Type> {{ action.to_xml() }}
{% if action["type"] == "forward" %}
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>m
{% elif action["type"] == "redirect" %}
<RedirectConfig>
<Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
<Port>{{ action["redirect_config._port"] }}</Port>
<StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% endif %}
</member> </member>
{% endfor %} {% endfor %}
</DefaultActions> </DefaultActions>
@ -1048,8 +1021,7 @@ MODIFY_RULE_TEMPLATE = """<ModifyRuleResponse xmlns="http://elasticloadbalancing
<Actions> <Actions>
{% for action in rule.actions %} {% for action in rule.actions %}
<member> <member>
<Type>{{ action["type"] }}</Type> {{ action.to_xml() }}
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
</member> </member>
{% endfor %} {% endfor %}
</Actions> </Actions>
@ -1208,6 +1180,12 @@ DESCRIBE_TARGET_HEALTH_TEMPLATE = """<DescribeTargetHealthResponse xmlns="http:/
<HealthCheckPort>{{ target_health.health_port }}</HealthCheckPort> <HealthCheckPort>{{ target_health.health_port }}</HealthCheckPort>
<TargetHealth> <TargetHealth>
<State>{{ target_health.status }}</State> <State>{{ target_health.status }}</State>
{% if target_health.reason %}
<Reason>{{ target_health.reason }}</Reason>
{% endif %}
{% if target_health.description %}
<Description>{{ target_health.description }}</Description>
{% endif %}
</TargetHealth> </TargetHealth>
<Target> <Target>
<Port>{{ target_health.port }}</Port> <Port>{{ target_health.port }}</Port>
@ -1426,16 +1404,7 @@ MODIFY_LISTENER_TEMPLATE = """<ModifyListenerResponse xmlns="http://elasticloadb
<DefaultActions> <DefaultActions>
{% for action in listener.default_actions %} {% for action in listener.default_actions %}
<member> <member>
<Type>{{ action.type }}</Type> {{ action.to_xml() }}
{% if action["type"] == "forward" %}
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
{% elif action["type"] == "redirect" %}
<RedirectConfig>
<Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
<Port>{{ action["redirect_config._port"] }}</Port>
<StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% endif %}
</member> </member>
{% endfor %} {% endfor %}
</DefaultActions> </DefaultActions>

View File

@ -138,6 +138,12 @@ class FakeTable(BaseModel):
raise PartitionAlreadyExistsException() raise PartitionAlreadyExistsException()
self.partitions[key] = partition self.partitions[key] = partition
def delete_partition(self, values):
try:
del self.partitions[str(values)]
except KeyError:
raise PartitionNotFoundException()
class FakePartition(BaseModel): class FakePartition(BaseModel):
def __init__(self, database_name, table_name, partiton_input): def __init__(self, database_name, table_name, partiton_input):

View File

@ -4,6 +4,11 @@ import json
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import glue_backend from .models import glue_backend
from .exceptions import (
PartitionAlreadyExistsException,
PartitionNotFoundException,
TableNotFoundException
)
class GlueResponse(BaseResponse): class GlueResponse(BaseResponse):
@ -90,6 +95,28 @@ class GlueResponse(BaseResponse):
resp = self.glue_backend.delete_table(database_name, table_name) resp = self.glue_backend.delete_table(database_name, table_name)
return json.dumps(resp) return json.dumps(resp)
def batch_delete_table(self):
database_name = self.parameters.get('DatabaseName')
errors = []
for table_name in self.parameters.get('TablesToDelete'):
try:
self.glue_backend.delete_table(database_name, table_name)
except TableNotFoundException:
errors.append({
"TableName": table_name,
"ErrorDetail": {
"ErrorCode": "EntityNotFoundException",
"ErrorMessage": "Table not found"
}
})
out = {}
if errors:
out["Errors"] = errors
return json.dumps(out)
def get_partitions(self): def get_partitions(self):
database_name = self.parameters.get('DatabaseName') database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName') table_name = self.parameters.get('TableName')
@ -114,6 +141,23 @@ class GlueResponse(BaseResponse):
return json.dumps({'Partition': p.as_dict()}) return json.dumps({'Partition': p.as_dict()})
def batch_get_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
partitions_to_get = self.parameters.get('PartitionsToGet')
table = self.glue_backend.get_table(database_name, table_name)
partitions = []
for values in partitions_to_get:
try:
p = table.get_partition(values=values["Values"])
partitions.append(p.as_dict())
except PartitionNotFoundException:
continue
return json.dumps({'Partitions': partitions})
def create_partition(self): def create_partition(self):
database_name = self.parameters.get('DatabaseName') database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName') table_name = self.parameters.get('TableName')
@ -124,6 +168,30 @@ class GlueResponse(BaseResponse):
return "" return ""
def batch_create_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
table = self.glue_backend.get_table(database_name, table_name)
errors_output = []
for part_input in self.parameters.get('PartitionInputList'):
try:
table.create_partition(part_input)
except PartitionAlreadyExistsException:
errors_output.append({
'PartitionValues': part_input['Values'],
'ErrorDetail': {
'ErrorCode': 'AlreadyExistsException',
'ErrorMessage': 'Partition already exists.'
}
})
out = {}
if errors_output:
out["Errors"] = errors_output
return json.dumps(out)
def update_partition(self): def update_partition(self):
database_name = self.parameters.get('DatabaseName') database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName') table_name = self.parameters.get('TableName')
@ -134,3 +202,38 @@ class GlueResponse(BaseResponse):
table.update_partition(part_to_update, part_input) table.update_partition(part_to_update, part_input)
return "" return ""
def delete_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
part_to_delete = self.parameters.get('PartitionValues')
table = self.glue_backend.get_table(database_name, table_name)
table.delete_partition(part_to_delete)
return ""
def batch_delete_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
table = self.glue_backend.get_table(database_name, table_name)
errors_output = []
for part_input in self.parameters.get('PartitionsToDelete'):
values = part_input.get('Values')
try:
table.delete_partition(values)
except PartitionNotFoundException:
errors_output.append({
'PartitionValues': values,
'ErrorDetail': {
'ErrorCode': 'EntityNotFoundException',
'ErrorMessage': 'Partition not found',
}
})
out = {}
if errors_output:
out['Errors'] = errors_output
return json.dumps(out)

File diff suppressed because it is too large Load Diff

View File

@ -26,6 +26,14 @@ class IAMReportNotPresentException(RESTError):
"ReportNotPresent", message) "ReportNotPresent", message)
class IAMLimitExceededException(RESTError):
code = 400
def __init__(self, message):
super(IAMLimitExceededException, self).__init__(
"LimitExceeded", message)
class MalformedCertificate(RESTError): class MalformedCertificate(RESTError):
code = 400 code = 400
@ -34,6 +42,14 @@ class MalformedCertificate(RESTError):
'MalformedCertificate', 'Certificate {cert} is malformed'.format(cert=cert)) 'MalformedCertificate', 'Certificate {cert} is malformed'.format(cert=cert))
class MalformedPolicyDocument(RESTError):
code = 400
def __init__(self, message=""):
super(MalformedPolicyDocument, self).__init__(
'MalformedPolicyDocument', message)
class DuplicateTags(RESTError): class DuplicateTags(RESTError):
code = 400 code = 400

View File

@ -8,14 +8,14 @@ import re
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
import pytz
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.core.utils import iso_8601_datetime_without_milliseconds, iso_8601_datetime_with_milliseconds
from moto.iam.policy_validation import IAMPolicyDocumentValidator
from .aws_managed_policies import aws_managed_policies_data from .aws_managed_policies import aws_managed_policies_data
from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException, MalformedCertificate, \ from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException, IAMLimitExceededException, \
DuplicateTags, TagKeyTooBig, InvalidTagCharacters, TooManyTags, TagValueTooBig MalformedCertificate, DuplicateTags, TagKeyTooBig, InvalidTagCharacters, TooManyTags, TagValueTooBig
from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id
ACCOUNT_ID = 123456789012 ACCOUNT_ID = 123456789012
@ -28,11 +28,15 @@ class MFADevice(object):
serial_number, serial_number,
authentication_code_1, authentication_code_1,
authentication_code_2): authentication_code_2):
self.enable_date = datetime.now(pytz.utc) self.enable_date = datetime.utcnow()
self.serial_number = serial_number self.serial_number = serial_number
self.authentication_code_1 = authentication_code_1 self.authentication_code_1 = authentication_code_1
self.authentication_code_2 = authentication_code_2 self.authentication_code_2 = authentication_code_2
@property
def enabled_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.enable_date)
class Policy(BaseModel): class Policy(BaseModel):
is_attachable = False is_attachable = False
@ -42,7 +46,9 @@ class Policy(BaseModel):
default_version_id=None, default_version_id=None,
description=None, description=None,
document=None, document=None,
path=None): path=None,
create_date=None,
update_date=None):
self.name = name self.name = name
self.attachment_count = 0 self.attachment_count = 0
@ -56,10 +62,25 @@ class Policy(BaseModel):
else: else:
self.default_version_id = 'v1' self.default_version_id = 'v1'
self.next_version_num = 2 self.next_version_num = 2
self.versions = [PolicyVersion(self.arn, document, True)] self.versions = [PolicyVersion(self.arn, document, True, self.default_version_id, update_date)]
self.create_datetime = datetime.now(pytz.utc) self.create_date = create_date if create_date is not None else datetime.utcnow()
self.update_datetime = datetime.now(pytz.utc) self.update_date = update_date if update_date is not None else datetime.utcnow()
def update_default_version(self, new_default_version_id):
for version in self.versions:
if version.version_id == self.default_version_id:
version.is_default = False
break
self.default_version_id = new_default_version_id
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
@property
def updated_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.update_date)
class SAMLProvider(BaseModel): class SAMLProvider(BaseModel):
@ -77,13 +98,19 @@ class PolicyVersion(object):
def __init__(self, def __init__(self,
policy_arn, policy_arn,
document, document,
is_default=False): is_default=False,
version_id='v1',
create_date=None):
self.policy_arn = policy_arn self.policy_arn = policy_arn
self.document = document or {} self.document = document or {}
self.is_default = is_default self.is_default = is_default
self.version_id = 'v1' self.version_id = version_id
self.create_datetime = datetime.now(pytz.utc) self.create_date = create_date if create_date is not None else datetime.utcnow()
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
class ManagedPolicy(Policy): class ManagedPolicy(Policy):
@ -112,7 +139,9 @@ class AWSManagedPolicy(ManagedPolicy):
return cls(name, return cls(name,
default_version_id=data.get('DefaultVersionId'), default_version_id=data.get('DefaultVersionId'),
path=data.get('Path'), path=data.get('Path'),
document=data.get('Document')) document=json.dumps(data.get('Document')),
create_date=datetime.strptime(data.get('CreateDate'), "%Y-%m-%dT%H:%M:%S+00:00"),
update_date=datetime.strptime(data.get('UpdateDate'), "%Y-%m-%dT%H:%M:%S+00:00"))
@property @property
def arn(self): def arn(self):
@ -132,18 +161,22 @@ class InlinePolicy(Policy):
class Role(BaseModel): class Role(BaseModel):
def __init__(self, role_id, name, assume_role_policy_document, path, permissions_boundary): def __init__(self, role_id, name, assume_role_policy_document, path, permissions_boundary, description, tags):
self.id = role_id self.id = role_id
self.name = name self.name = name
self.assume_role_policy_document = assume_role_policy_document self.assume_role_policy_document = assume_role_policy_document
self.path = path or '/' self.path = path or '/'
self.policies = {} self.policies = {}
self.managed_policies = {} self.managed_policies = {}
self.create_date = datetime.now(pytz.utc) self.create_date = datetime.utcnow()
self.tags = {} self.tags = tags
self.description = "" self.description = description
self.permissions_boundary = permissions_boundary self.permissions_boundary = permissions_boundary
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties'] properties = cloudformation_json['Properties']
@ -152,7 +185,9 @@ class Role(BaseModel):
role_name=resource_name, role_name=resource_name,
assume_role_policy_document=properties['AssumeRolePolicyDocument'], assume_role_policy_document=properties['AssumeRolePolicyDocument'],
path=properties.get('Path', '/'), path=properties.get('Path', '/'),
permissions_boundary=properties.get('PermissionsBoundary', '') permissions_boundary=properties.get('PermissionsBoundary', ''),
description=properties.get('Description', ''),
tags=properties.get('Tags', {})
) )
policies = properties.get('Policies', []) policies = properties.get('Policies', [])
@ -198,7 +233,11 @@ class InstanceProfile(BaseModel):
self.name = name self.name = name
self.path = path or '/' self.path = path or '/'
self.roles = roles if roles else [] self.roles = roles if roles else []
self.create_date = datetime.now(pytz.utc) self.create_date = datetime.utcnow()
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -250,25 +289,31 @@ class SigningCertificate(BaseModel):
self.id = id self.id = id
self.user_name = user_name self.user_name = user_name
self.body = body self.body = body
self.upload_date = datetime.strftime(datetime.utcnow(), "%Y-%m-%d-%H-%M-%S") self.upload_date = datetime.utcnow()
self.status = 'Active' self.status = 'Active'
@property
def uploaded_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.upload_date)
class AccessKey(BaseModel): class AccessKey(BaseModel):
def __init__(self, user_name): def __init__(self, user_name):
self.user_name = user_name self.user_name = user_name
self.access_key_id = random_access_key() self.access_key_id = "AKIA" + random_access_key()
self.secret_access_key = random_alphanumeric(32) self.secret_access_key = random_alphanumeric(40)
self.status = 'Active' self.status = 'Active'
self.create_date = datetime.strftime( self.create_date = datetime.utcnow()
datetime.utcnow(), self.last_used = datetime.utcnow()
"%Y-%m-%dT%H:%M:%SZ"
) @property
self.last_used = datetime.strftime( def created_iso_8601(self):
datetime.utcnow(), return iso_8601_datetime_without_milliseconds(self.create_date)
"%Y-%m-%dT%H:%M:%SZ"
) @property
def last_used_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.last_used)
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
@ -283,15 +328,16 @@ class Group(BaseModel):
self.name = name self.name = name
self.id = random_resource_id() self.id = random_resource_id()
self.path = path self.path = path
self.created = datetime.strftime( self.create_date = datetime.utcnow()
datetime.utcnow(),
"%Y-%m-%d-%H-%M-%S"
)
self.users = [] self.users = []
self.managed_policies = {} self.managed_policies = {}
self.policies = {} self.policies = {}
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'Arn': if attribute_name == 'Arn':
@ -306,10 +352,6 @@ class Group(BaseModel):
else: else:
return "arn:aws:iam::{0}:group/{1}/{2}".format(ACCOUNT_ID, self.path, self.name) return "arn:aws:iam::{0}:group/{1}/{2}".format(ACCOUNT_ID, self.path, self.name)
@property
def create_date(self):
return self.created
def get_policy(self, policy_name): def get_policy(self, policy_name):
try: try:
policy_json = self.policies[policy_name] policy_json = self.policies[policy_name]
@ -335,7 +377,7 @@ class User(BaseModel):
self.name = name self.name = name
self.id = random_resource_id() self.id = random_resource_id()
self.path = path if path else "/" self.path = path if path else "/"
self.created = datetime.utcnow() self.create_date = datetime.utcnow()
self.mfa_devices = {} self.mfa_devices = {}
self.policies = {} self.policies = {}
self.managed_policies = {} self.managed_policies = {}
@ -350,7 +392,7 @@ class User(BaseModel):
@property @property
def created_iso_8601(self): def created_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.created) return iso_8601_datetime_with_milliseconds(self.create_date)
def get_policy(self, policy_name): def get_policy(self, policy_name):
policy_json = None policy_json = None
@ -421,7 +463,7 @@ class User(BaseModel):
def to_csv(self): def to_csv(self):
date_format = '%Y-%m-%dT%H:%M:%S+00:00' date_format = '%Y-%m-%dT%H:%M:%S+00:00'
date_created = self.created date_created = self.create_date
# aagrawal,arn:aws:iam::509284790694:user/aagrawal,2014-09-01T22:28:48+00:00,true,2014-11-12T23:36:49+00:00,2014-09-03T18:59:00+00:00,N/A,false,true,2014-09-01T22:28:48+00:00,false,N/A,false,N/A,false,N/A # aagrawal,arn:aws:iam::509284790694:user/aagrawal,2014-09-01T22:28:48+00:00,true,2014-11-12T23:36:49+00:00,2014-09-03T18:59:00+00:00,N/A,false,true,2014-09-01T22:28:48+00:00,false,N/A,false,N/A,false,N/A
if not self.password: if not self.password:
password_enabled = 'false' password_enabled = 'false'
@ -478,7 +520,7 @@ class IAMBackend(BaseBackend):
super(IAMBackend, self).__init__() super(IAMBackend, self).__init__()
def _init_managed_policies(self): def _init_managed_policies(self):
return dict((p.name, p) for p in aws_managed_policies) return dict((p.arn, p) for p in aws_managed_policies)
def attach_role_policy(self, policy_arn, role_name): def attach_role_policy(self, policy_arn, role_name):
arns = dict((p.arn, p) for p in self.managed_policies.values()) arns = dict((p.arn, p) for p in self.managed_policies.values())
@ -536,6 +578,9 @@ class IAMBackend(BaseBackend):
policy.detach_from(self.get_user(user_name)) policy.detach_from(self.get_user(user_name))
def create_policy(self, description, path, policy_document, policy_name): def create_policy(self, description, path, policy_document, policy_name):
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_document)
iam_policy_document_validator.validate()
policy = ManagedPolicy( policy = ManagedPolicy(
policy_name, policy_name,
description=description, description=description,
@ -592,12 +637,13 @@ class IAMBackend(BaseBackend):
return policies, marker return policies, marker
def create_role(self, role_name, assume_role_policy_document, path, permissions_boundary): def create_role(self, role_name, assume_role_policy_document, path, permissions_boundary, description, tags):
role_id = random_resource_id() role_id = random_resource_id()
if permissions_boundary and not self.policy_arn_regex.match(permissions_boundary): if permissions_boundary and not self.policy_arn_regex.match(permissions_boundary):
raise RESTError('InvalidParameterValue', 'Value ({}) for parameter PermissionsBoundary is invalid.'.format(permissions_boundary)) raise RESTError('InvalidParameterValue', 'Value ({}) for parameter PermissionsBoundary is invalid.'.format(permissions_boundary))
role = Role(role_id, role_name, assume_role_policy_document, path, permissions_boundary) clean_tags = self._tag_verification(tags)
role = Role(role_id, role_name, assume_role_policy_document, path, permissions_boundary, description, clean_tags)
self.roles[role_id] = role self.roles[role_id] = role
return role return role
@ -628,6 +674,9 @@ class IAMBackend(BaseBackend):
def put_role_policy(self, role_name, policy_name, policy_json): def put_role_policy(self, role_name, policy_name, policy_json):
role = self.get_role(role_name) role = self.get_role(role_name)
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_json)
iam_policy_document_validator.validate()
role.put_policy(policy_name, policy_json) role.put_policy(policy_name, policy_json)
def delete_role_policy(self, role_name, policy_name): def delete_role_policy(self, role_name, policy_name):
@ -639,15 +688,32 @@ class IAMBackend(BaseBackend):
for p, d in role.policies.items(): for p, d in role.policies.items():
if p == policy_name: if p == policy_name:
return p, d return p, d
raise IAMNotFoundException("Policy Document {0} not attached to role {1}".format(policy_name, role_name))
def list_role_policies(self, role_name): def list_role_policies(self, role_name):
role = self.get_role(role_name) role = self.get_role(role_name)
return role.policies.keys() return role.policies.keys()
def _tag_verification(self, tags):
if len(tags) > 50:
raise TooManyTags(tags)
tag_keys = {}
for tag in tags:
# Need to index by the lowercase tag key since the keys are case insensitive, but their case is retained.
ref_key = tag['Key'].lower()
self._check_tag_duplicate(tag_keys, ref_key)
self._validate_tag_key(tag['Key'])
if len(tag['Value']) > 256:
raise TagValueTooBig(tag['Value'])
tag_keys[ref_key] = tag
return tag_keys
def _validate_tag_key(self, tag_key, exception_param='tags.X.member.key'): def _validate_tag_key(self, tag_key, exception_param='tags.X.member.key'):
"""Validates the tag key. """Validates the tag key.
:param all_tags: Dict to check if there is a duplicate tag.
:param tag_key: The tag key to check against. :param tag_key: The tag key to check against.
:param exception_param: The exception parameter to send over to help format the message. This is to reflect :param exception_param: The exception parameter to send over to help format the message. This is to reflect
the difference between the tag and untag APIs. the difference between the tag and untag APIs.
@ -694,23 +760,9 @@ class IAMBackend(BaseBackend):
return tags, marker return tags, marker
def tag_role(self, role_name, tags): def tag_role(self, role_name, tags):
if len(tags) > 50: clean_tags = self._tag_verification(tags)
raise TooManyTags(tags)
role = self.get_role(role_name) role = self.get_role(role_name)
role.tags.update(clean_tags)
tag_keys = {}
for tag in tags:
# Need to index by the lowercase tag key since the keys are case insensitive, but their case is retained.
ref_key = tag['Key'].lower()
self._check_tag_duplicate(tag_keys, ref_key)
self._validate_tag_key(tag['Key'])
if len(tag['Value']) > 256:
raise TagValueTooBig(tag['Value'])
tag_keys[ref_key] = tag
role.tags.update(tag_keys)
def untag_role(self, role_name, tag_keys): def untag_role(self, role_name, tag_keys):
if len(tag_keys) > 50: if len(tag_keys) > 50:
@ -725,15 +777,21 @@ class IAMBackend(BaseBackend):
role.tags.pop(ref_key, None) role.tags.pop(ref_key, None)
def create_policy_version(self, policy_arn, policy_document, set_as_default): def create_policy_version(self, policy_arn, policy_document, set_as_default):
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_document)
iam_policy_document_validator.validate()
policy = self.get_policy(policy_arn) policy = self.get_policy(policy_arn)
if not policy: if not policy:
raise IAMNotFoundException("Policy not found") raise IAMNotFoundException("Policy not found")
if len(policy.versions) >= 5:
raise IAMLimitExceededException("A managed policy can have up to 5 versions. Before you create a new version, you must delete an existing version.")
set_as_default = (set_as_default == "true") # convert it to python bool
version = PolicyVersion(policy_arn, policy_document, set_as_default) version = PolicyVersion(policy_arn, policy_document, set_as_default)
policy.versions.append(version) policy.versions.append(version)
version.version_id = 'v{0}'.format(policy.next_version_num) version.version_id = 'v{0}'.format(policy.next_version_num)
policy.next_version_num += 1 policy.next_version_num += 1
if set_as_default: if set_as_default:
policy.default_version_id = version.version_id policy.update_default_version(version.version_id)
return version return version
def get_policy_version(self, policy_arn, version_id): def get_policy_version(self, policy_arn, version_id):
@ -756,8 +814,8 @@ class IAMBackend(BaseBackend):
if not policy: if not policy:
raise IAMNotFoundException("Policy not found") raise IAMNotFoundException("Policy not found")
if version_id == policy.default_version_id: if version_id == policy.default_version_id:
raise IAMConflictException( raise IAMConflictException(code="DeleteConflict",
"Cannot delete the default version of a policy") message="Cannot delete the default version of a policy.")
for i, v in enumerate(policy.versions): for i, v in enumerate(policy.versions):
if v.version_id == version_id: if v.version_id == version_id:
del policy.versions[i] del policy.versions[i]
@ -869,6 +927,9 @@ class IAMBackend(BaseBackend):
def put_group_policy(self, group_name, policy_name, policy_json): def put_group_policy(self, group_name, policy_name, policy_json):
group = self.get_group(group_name) group = self.get_group(group_name)
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_json)
iam_policy_document_validator.validate()
group.put_policy(policy_name, policy_json) group.put_policy(policy_name, policy_json)
def list_group_policies(self, group_name, marker=None, max_items=None): def list_group_policies(self, group_name, marker=None, max_items=None):
@ -1029,6 +1090,9 @@ class IAMBackend(BaseBackend):
def put_user_policy(self, user_name, policy_name, policy_json): def put_user_policy(self, user_name, policy_name, policy_json):
user = self.get_user(user_name) user = self.get_user(user_name)
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_json)
iam_policy_document_validator.validate()
user.put_policy(policy_name, policy_json) user.put_policy(policy_name, policy_json)
def delete_user_policy(self, user_name, policy_name): def delete_user_policy(self, user_name, policy_name):
@ -1050,7 +1114,7 @@ class IAMBackend(BaseBackend):
if key.access_key_id == access_key_id: if key.access_key_id == access_key_id:
return { return {
'user_name': key.user_name, 'user_name': key.user_name,
'last_used': key.last_used 'last_used': key.last_used_iso_8601,
} }
else: else:
raise IAMNotFoundException( raise IAMNotFoundException(
@ -1189,5 +1253,13 @@ class IAMBackend(BaseBackend):
return saml_provider return saml_provider
raise IAMNotFoundException("SamlProvider {0} not found".format(saml_provider_arn)) raise IAMNotFoundException("SamlProvider {0} not found".format(saml_provider_arn))
def get_user_from_access_key_id(self, access_key_id):
for user_name, user in self.users.items():
access_keys = self.get_all_access_keys(user_name)
for access_key in access_keys:
if access_key.access_key_id == access_key_id:
return user
return None
iam_backend = IAMBackend() iam_backend = IAMBackend()

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import iam_backend, User from .models import iam_backend, User
@ -177,9 +178,11 @@ class IamResponse(BaseResponse):
'AssumeRolePolicyDocument') 'AssumeRolePolicyDocument')
permissions_boundary = self._get_param( permissions_boundary = self._get_param(
'PermissionsBoundary') 'PermissionsBoundary')
description = self._get_param('Description')
tags = self._get_multi_param('Tags.member')
role = iam_backend.create_role( role = iam_backend.create_role(
role_name, assume_role_policy_document, path, permissions_boundary) role_name, assume_role_policy_document, path, permissions_boundary, description, tags)
template = self.response_template(CREATE_ROLE_TEMPLATE) template = self.response_template(CREATE_ROLE_TEMPLATE)
return template.render(role=role) return template.render(role=role)
@ -425,11 +428,13 @@ class IamResponse(BaseResponse):
def get_user(self): def get_user(self):
user_name = self._get_param('UserName') user_name = self._get_param('UserName')
if user_name: if not user_name:
user = iam_backend.get_user(user_name) access_key_id = self.get_current_user()
user = iam_backend.get_user_from_access_key_id(access_key_id)
if user is None:
user = User("default_user")
else: else:
user = User(name='default_user') user = iam_backend.get_user(user_name)
# If no user is specific, IAM returns the current user
template = self.response_template(USER_TEMPLATE) template = self.response_template(USER_TEMPLATE)
return template.render(action='Get', user=user) return template.render(action='Get', user=user)
@ -457,7 +462,6 @@ class IamResponse(BaseResponse):
def create_login_profile(self): def create_login_profile(self):
user_name = self._get_param('UserName') user_name = self._get_param('UserName')
password = self._get_param('Password') password = self._get_param('Password')
password = self._get_param('Password')
user = iam_backend.create_login_profile(user_name, password) user = iam_backend.create_login_profile(user_name, password)
template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE) template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE)
@ -818,12 +822,12 @@ CREATE_POLICY_TEMPLATE = """<CreatePolicyResponse>
<Policy> <Policy>
<Arn>{{ policy.arn }}</Arn> <Arn>{{ policy.arn }}</Arn>
<AttachmentCount>{{ policy.attachment_count }}</AttachmentCount> <AttachmentCount>{{ policy.attachment_count }}</AttachmentCount>
<CreateDate>{{ policy.create_datetime.isoformat() }}</CreateDate> <CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<DefaultVersionId>{{ policy.default_version_id }}</DefaultVersionId> <DefaultVersionId>{{ policy.default_version_id }}</DefaultVersionId>
<Path>{{ policy.path }}</Path> <Path>{{ policy.path }}</Path>
<PolicyId>{{ policy.id }}</PolicyId> <PolicyId>{{ policy.id }}</PolicyId>
<PolicyName>{{ policy.name }}</PolicyName> <PolicyName>{{ policy.name }}</PolicyName>
<UpdateDate>{{ policy.update_datetime.isoformat() }}</UpdateDate> <UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
</Policy> </Policy>
</CreatePolicyResult> </CreatePolicyResult>
<ResponseMetadata> <ResponseMetadata>
@ -841,8 +845,8 @@ GET_POLICY_TEMPLATE = """<GetPolicyResponse>
<Path>{{ policy.path }}</Path> <Path>{{ policy.path }}</Path>
<Arn>{{ policy.arn }}</Arn> <Arn>{{ policy.arn }}</Arn>
<AttachmentCount>{{ policy.attachment_count }}</AttachmentCount> <AttachmentCount>{{ policy.attachment_count }}</AttachmentCount>
<CreateDate>{{ policy.create_datetime.isoformat() }}</CreateDate> <CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<UpdateDate>{{ policy.update_datetime.isoformat() }}</UpdateDate> <UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
</Policy> </Policy>
</GetPolicyResult> </GetPolicyResult>
<ResponseMetadata> <ResponseMetadata>
@ -929,12 +933,12 @@ LIST_POLICIES_TEMPLATE = """<ListPoliciesResponse>
<member> <member>
<Arn>{{ policy.arn }}</Arn> <Arn>{{ policy.arn }}</Arn>
<AttachmentCount>{{ policy.attachment_count }}</AttachmentCount> <AttachmentCount>{{ policy.attachment_count }}</AttachmentCount>
<CreateDate>{{ policy.create_datetime.isoformat() }}</CreateDate> <CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<DefaultVersionId>{{ policy.default_version_id }}</DefaultVersionId> <DefaultVersionId>{{ policy.default_version_id }}</DefaultVersionId>
<Path>{{ policy.path }}</Path> <Path>{{ policy.path }}</Path>
<PolicyId>{{ policy.id }}</PolicyId> <PolicyId>{{ policy.id }}</PolicyId>
<PolicyName>{{ policy.name }}</PolicyName> <PolicyName>{{ policy.name }}</PolicyName>
<UpdateDate>{{ policy.update_datetime.isoformat() }}</UpdateDate> <UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
</member> </member>
{% endfor %} {% endfor %}
</Policies> </Policies>
@ -958,7 +962,7 @@ CREATE_INSTANCE_PROFILE_TEMPLATE = """<CreateInstanceProfileResponse xmlns="http
<InstanceProfileName>{{ profile.name }}</InstanceProfileName> <InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path> <Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn> <Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.create_date }}</CreateDate> <CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
</InstanceProfile> </InstanceProfile>
</CreateInstanceProfileResult> </CreateInstanceProfileResult>
<ResponseMetadata> <ResponseMetadata>
@ -977,7 +981,7 @@ GET_INSTANCE_PROFILE_TEMPLATE = """<GetInstanceProfileResponse xmlns="https://ia
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate> <CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
</member> </member>
{% endfor %} {% endfor %}
@ -985,7 +989,7 @@ GET_INSTANCE_PROFILE_TEMPLATE = """<GetInstanceProfileResponse xmlns="https://ia
<InstanceProfileName>{{ profile.name }}</InstanceProfileName> <InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path> <Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn> <Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.create_date }}</CreateDate> <CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
</InstanceProfile> </InstanceProfile>
</GetInstanceProfileResult> </GetInstanceProfileResult>
<ResponseMetadata> <ResponseMetadata>
@ -1000,7 +1004,8 @@ CREATE_ROLE_TEMPLATE = """<CreateRoleResponse xmlns="https://iam.amazonaws.com/d
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate> <Description>{{role.description}}</Description>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
{% if role.permissions_boundary %} {% if role.permissions_boundary %}
<PermissionsBoundary> <PermissionsBoundary>
@ -1008,6 +1013,16 @@ CREATE_ROLE_TEMPLATE = """<CreateRoleResponse xmlns="https://iam.amazonaws.com/d
<PermissionsBoundaryArn>{{ role.permissions_boundary }}</PermissionsBoundaryArn> <PermissionsBoundaryArn>{{ role.permissions_boundary }}</PermissionsBoundaryArn>
</PermissionsBoundary> </PermissionsBoundary>
{% endif %} {% endif %}
{% if role.tags %}
<Tags>
{% for tag in role.get_tags() %}
<member>
<Key>{{ tag['Key'] }}</Key>
<Value>{{ tag['Value'] }}</Value>
</member>
{% endfor %}
</Tags>
{% endif %}
</Role> </Role>
</CreateRoleResult> </CreateRoleResult>
<ResponseMetadata> <ResponseMetadata>
@ -1041,7 +1056,8 @@ UPDATE_ROLE_DESCRIPTION_TEMPLATE = """<UpdateRoleDescriptionResponse xmlns="http
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date.isoformat() }}</CreateDate> <Description>{{role.description}}</Description>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
{% if role.tags %} {% if role.tags %}
<Tags> <Tags>
@ -1067,7 +1083,8 @@ GET_ROLE_TEMPLATE = """<GetRoleResponse xmlns="https://iam.amazonaws.com/doc/201
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate> <Description>{{role.description}}</Description>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
{% if role.tags %} {% if role.tags %}
<Tags> <Tags>
@ -1108,7 +1125,7 @@ LIST_ROLES_TEMPLATE = """<ListRolesResponse xmlns="https://iam.amazonaws.com/doc
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate> <CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
{% if role.permissions_boundary %} {% if role.permissions_boundary %}
<PermissionsBoundary> <PermissionsBoundary>
@ -1144,8 +1161,8 @@ CREATE_POLICY_VERSION_TEMPLATE = """<CreatePolicyVersionResponse xmlns="https://
<PolicyVersion> <PolicyVersion>
<Document>{{ policy_version.document }}</Document> <Document>{{ policy_version.document }}</Document>
<VersionId>{{ policy_version.version_id }}</VersionId> <VersionId>{{ policy_version.version_id }}</VersionId>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion> <IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<CreateDate>{{ policy_version.create_datetime }}</CreateDate> <CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</PolicyVersion> </PolicyVersion>
</CreatePolicyVersionResult> </CreatePolicyVersionResult>
<ResponseMetadata> <ResponseMetadata>
@ -1158,8 +1175,8 @@ GET_POLICY_VERSION_TEMPLATE = """<GetPolicyVersionResponse xmlns="https://iam.am
<PolicyVersion> <PolicyVersion>
<Document>{{ policy_version.document }}</Document> <Document>{{ policy_version.document }}</Document>
<VersionId>{{ policy_version.version_id }}</VersionId> <VersionId>{{ policy_version.version_id }}</VersionId>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion> <IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<CreateDate>{{ policy_version.create_datetime }}</CreateDate> <CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</PolicyVersion> </PolicyVersion>
</GetPolicyVersionResult> </GetPolicyVersionResult>
<ResponseMetadata> <ResponseMetadata>
@ -1175,8 +1192,8 @@ LIST_POLICY_VERSIONS_TEMPLATE = """<ListPolicyVersionsResponse xmlns="https://ia
<member> <member>
<Document>{{ policy_version.document }}</Document> <Document>{{ policy_version.document }}</Document>
<VersionId>{{ policy_version.version_id }}</VersionId> <VersionId>{{ policy_version.version_id }}</VersionId>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion> <IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<CreateDate>{{ policy_version.create_datetime }}</CreateDate> <CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</member> </member>
{% endfor %} {% endfor %}
</Versions> </Versions>
@ -1200,7 +1217,7 @@ LIST_INSTANCE_PROFILES_TEMPLATE = """<ListInstanceProfilesResponse xmlns="https:
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate> <CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
</member> </member>
{% endfor %} {% endfor %}
@ -1208,7 +1225,7 @@ LIST_INSTANCE_PROFILES_TEMPLATE = """<ListInstanceProfilesResponse xmlns="https:
<InstanceProfileName>{{ instance.name }}</InstanceProfileName> <InstanceProfileName>{{ instance.name }}</InstanceProfileName>
<Path>{{ instance.path }}</Path> <Path>{{ instance.path }}</Path>
<Arn>{{ instance.arn }}</Arn> <Arn>{{ instance.arn }}</Arn>
<CreateDate>{{ instance.create_date }}</CreateDate> <CreateDate>{{ instance.created_iso_8601 }}</CreateDate>
</member> </member>
{% endfor %} {% endfor %}
</InstanceProfiles> </InstanceProfiles>
@ -1287,7 +1304,7 @@ CREATE_GROUP_TEMPLATE = """<CreateGroupResponse>
<GroupName>{{ group.name }}</GroupName> <GroupName>{{ group.name }}</GroupName>
<GroupId>{{ group.id }}</GroupId> <GroupId>{{ group.id }}</GroupId>
<Arn>{{ group.arn }}</Arn> <Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.create_date }}</CreateDate> <CreateDate>{{ group.created_iso_8601 }}</CreateDate>
</Group> </Group>
</CreateGroupResult> </CreateGroupResult>
<ResponseMetadata> <ResponseMetadata>
@ -1302,7 +1319,7 @@ GET_GROUP_TEMPLATE = """<GetGroupResponse>
<GroupName>{{ group.name }}</GroupName> <GroupName>{{ group.name }}</GroupName>
<GroupId>{{ group.id }}</GroupId> <GroupId>{{ group.id }}</GroupId>
<Arn>{{ group.arn }}</Arn> <Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.create_date }}</CreateDate> <CreateDate>{{ group.created_iso_8601 }}</CreateDate>
</Group> </Group>
<Users> <Users>
{% for user in group.users %} {% for user in group.users %}
@ -1349,6 +1366,7 @@ LIST_GROUPS_FOR_USER_TEMPLATE = """<ListGroupsForUserResponse>
<GroupName>{{ group.name }}</GroupName> <GroupName>{{ group.name }}</GroupName>
<GroupId>{{ group.id }}</GroupId> <GroupId>{{ group.id }}</GroupId>
<Arn>{{ group.arn }}</Arn> <Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.created_iso_8601 }}</CreateDate>
</member> </member>
{% endfor %} {% endfor %}
</Groups> </Groups>
@ -1493,6 +1511,7 @@ CREATE_ACCESS_KEY_TEMPLATE = """<CreateAccessKeyResponse>
<AccessKeyId>{{ key.access_key_id }}</AccessKeyId> <AccessKeyId>{{ key.access_key_id }}</AccessKeyId>
<Status>{{ key.status }}</Status> <Status>{{ key.status }}</Status>
<SecretAccessKey>{{ key.secret_access_key }}</SecretAccessKey> <SecretAccessKey>{{ key.secret_access_key }}</SecretAccessKey>
<CreateDate>{{ key.created_iso_8601 }}</CreateDate>
</AccessKey> </AccessKey>
</CreateAccessKeyResult> </CreateAccessKeyResult>
<ResponseMetadata> <ResponseMetadata>
@ -1509,7 +1528,7 @@ LIST_ACCESS_KEYS_TEMPLATE = """<ListAccessKeysResponse>
<UserName>{{ user_name }}</UserName> <UserName>{{ user_name }}</UserName>
<AccessKeyId>{{ key.access_key_id }}</AccessKeyId> <AccessKeyId>{{ key.access_key_id }}</AccessKeyId>
<Status>{{ key.status }}</Status> <Status>{{ key.status }}</Status>
<CreateDate>{{ key.create_date }}</CreateDate> <CreateDate>{{ key.created_iso_8601 }}</CreateDate>
</member> </member>
{% endfor %} {% endfor %}
</AccessKeyMetadata> </AccessKeyMetadata>
@ -1577,7 +1596,7 @@ LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE = """<ListInstanceProfilesForRoleRespon
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate> <CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
</member> </member>
{% endfor %} {% endfor %}
@ -1585,7 +1604,7 @@ LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE = """<ListInstanceProfilesForRoleRespon
<InstanceProfileName>{{ profile.name }}</InstanceProfileName> <InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path> <Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn> <Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.create_date }}</CreateDate> <CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
</member> </member>
{% endfor %} {% endfor %}
</InstanceProfiles> </InstanceProfiles>
@ -1651,6 +1670,7 @@ LIST_GROUPS_FOR_USER_TEMPLATE = """<ListGroupsForUserResponse>
<GroupName>{{ group.name }}</GroupName> <GroupName>{{ group.name }}</GroupName>
<GroupId>{{ group.id }}</GroupId> <GroupId>{{ group.id }}</GroupId>
<Arn>{{ group.arn }}</Arn> <Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.created_iso_8601 }}</CreateDate>
</member> </member>
{% endfor %} {% endfor %}
</Groups> </Groups>
@ -1704,7 +1724,7 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
<GroupName>{{ group.name }}</GroupName> <GroupName>{{ group.name }}</GroupName>
<Path>{{ group.path }}</Path> <Path>{{ group.path }}</Path>
<Arn>{{ group.arn }}</Arn> <Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.create_date }}</CreateDate> <CreateDate>{{ group.created_iso_8601 }}</CreateDate>
<GroupPolicyList> <GroupPolicyList>
{% for policy in group.policies %} {% for policy in group.policies %}
<member> <member>
@ -1754,15 +1774,22 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate> <Description>{{role.description}}</Description>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
{% if role.permissions_boundary %}
<PermissionsBoundary>
<PermissionsBoundaryType>PermissionsBoundaryPolicy</PermissionsBoundaryType>
<PermissionsBoundaryArn>{{ role.permissions_boundary }}</PermissionsBoundaryArn>
</PermissionsBoundary>
{% endif %}
</member> </member>
{% endfor %} {% endfor %}
</Roles> </Roles>
<InstanceProfileName>{{ profile.name }}</InstanceProfileName> <InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path> <Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn> <Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.create_date }}</CreateDate> <CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
</member> </member>
{% endfor %} {% endfor %}
</InstanceProfileList> </InstanceProfileList>
@ -1770,7 +1797,7 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate> <CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
</member> </member>
{% endfor %} {% endfor %}
@ -1786,17 +1813,17 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
{% for policy_version in policy.versions %} {% for policy_version in policy.versions %}
<member> <member>
<Document>{{ policy_version.document }}</Document> <Document>{{ policy_version.document }}</Document>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion> <IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<VersionId>{{ policy_version.version_id }}</VersionId> <VersionId>{{ policy_version.version_id }}</VersionId>
<CreateDate>{{ policy_version.create_datetime }}</CreateDate> <CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</member> </member>
{% endfor %} {% endfor %}
</PolicyVersionList> </PolicyVersionList>
<Arn>{{ policy.arn }}</Arn> <Arn>{{ policy.arn }}</Arn>
<AttachmentCount>1</AttachmentCount> <AttachmentCount>1</AttachmentCount>
<CreateDate>{{ policy.create_datetime }}</CreateDate> <CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<IsAttachable>true</IsAttachable> <IsAttachable>true</IsAttachable>
<UpdateDate>{{ policy.update_datetime }}</UpdateDate> <UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
</member> </member>
{% endfor %} {% endfor %}
</Policies> </Policies>

View File

@ -7,7 +7,7 @@ import six
def random_alphanumeric(length): def random_alphanumeric(length):
return ''.join(six.text_type( return ''.join(six.text_type(
random.choice( random.choice(
string.ascii_letters + string.digits string.ascii_letters + string.digits + "+" + "/"
)) for _ in range(length) )) for _ in range(length)
) )

View File

@ -123,17 +123,12 @@ class Stream(BaseModel):
self.tags = {} self.tags = {}
self.status = "ACTIVE" self.status = "ACTIVE"
if six.PY3: step = 2**128 // shard_count
izip_longest = itertools.zip_longest hash_ranges = itertools.chain(map(lambda i: (i, i * step, (i + 1) * step),
else: range(shard_count - 1)),
izip_longest = itertools.izip_longest [(shard_count - 1, (shard_count - 1) * step, 2**128)])
for index, start, end in hash_ranges:
for index, start, end in izip_longest(range(shard_count),
range(0, 2**128, 2 **
128 // shard_count),
range(2**128 // shard_count, 2 **
128, 2**128 // shard_count),
fillvalue=2**128):
shard = Shard(index, start, end) shard = Shard(index, start, end)
self.shards[shard.shard_id] = shard self.shards[shard.shard_id] = shard

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import os import os
import boto.kms import boto.kms
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds, unix_time from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import generate_key_id from .utils import generate_key_id
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -11,7 +11,7 @@ from datetime import datetime, timedelta
class Key(BaseModel): class Key(BaseModel):
def __init__(self, policy, key_usage, description, region): def __init__(self, policy, key_usage, description, tags, region):
self.id = generate_key_id() self.id = generate_key_id()
self.policy = policy self.policy = policy
self.key_usage = key_usage self.key_usage = key_usage
@ -22,7 +22,7 @@ class Key(BaseModel):
self.account_id = "0123456789012" self.account_id = "0123456789012"
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date = None
self.tags = {} self.tags = tags or {}
@property @property
def physical_resource_id(self): def physical_resource_id(self):
@ -37,7 +37,7 @@ class Key(BaseModel):
"KeyMetadata": { "KeyMetadata": {
"AWSAccountId": self.account_id, "AWSAccountId": self.account_id,
"Arn": self.arn, "Arn": self.arn,
"CreationDate": "%d" % unix_time(), "CreationDate": iso_8601_datetime_without_milliseconds(datetime.now()),
"Description": self.description, "Description": self.description,
"Enabled": self.enabled, "Enabled": self.enabled,
"KeyId": self.id, "KeyId": self.id,
@ -61,6 +61,7 @@ class Key(BaseModel):
policy=properties['KeyPolicy'], policy=properties['KeyPolicy'],
key_usage='ENCRYPT_DECRYPT', key_usage='ENCRYPT_DECRYPT',
description=properties['Description'], description=properties['Description'],
tags=properties.get('Tags'),
region=region_name, region=region_name,
) )
key.key_rotation_status = properties['EnableKeyRotation'] key.key_rotation_status = properties['EnableKeyRotation']
@ -80,8 +81,8 @@ class KmsBackend(BaseBackend):
self.keys = {} self.keys = {}
self.key_to_aliases = defaultdict(set) self.key_to_aliases = defaultdict(set)
def create_key(self, policy, key_usage, description, region): def create_key(self, policy, key_usage, description, tags, region):
key = Key(policy, key_usage, description, region) key = Key(policy, key_usage, description, tags, region)
self.keys[key.id] = key self.keys[key.id] = key
return key return key

View File

@ -31,9 +31,10 @@ class KmsResponse(BaseResponse):
policy = self.parameters.get('Policy') policy = self.parameters.get('Policy')
key_usage = self.parameters.get('KeyUsage') key_usage = self.parameters.get('KeyUsage')
description = self.parameters.get('Description') description = self.parameters.get('Description')
tags = self.parameters.get('Tags')
key = self.kms_backend.create_key( key = self.kms_backend.create_key(
policy, key_usage, description, self.region) policy, key_usage, description, tags, self.region)
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def update_key_description(self): def update_key_description(self):
@ -237,7 +238,7 @@ class KmsResponse(BaseResponse):
value = self.parameters.get("CiphertextBlob") value = self.parameters.get("CiphertextBlob")
try: try:
return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8")}) return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8"), 'KeyId': 'key_id'})
except UnicodeDecodeError: except UnicodeDecodeError:
# Generate data key will produce random bytes which when decrypted is still returned as base64 # Generate data key will produce random bytes which when decrypted is still returned as base64
return json.dumps({"Plaintext": value}) return json.dumps({"Plaintext": value})

View File

@ -98,17 +98,29 @@ class LogStream:
return True return True
def get_paging_token_from_index(index, back=False):
if index is not None:
return "b/{:056d}".format(index) if back else "f/{:056d}".format(index)
return 0
def get_index_from_paging_token(token):
if token is not None:
return int(token[2:])
return 0
events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head) events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head)
back_token = next_token next_index = get_index_from_paging_token(next_token)
if next_token is None: back_index = next_index
next_token = 0
events_page = [event.to_response_dict() for event in events[next_token: next_token + limit]] events_page = [event.to_response_dict() for event in events[next_index: next_index + limit]]
next_token += limit if next_index + limit < len(self.events):
if next_token >= len(self.events): next_index += limit
next_token = None
return events_page, back_token, next_token back_index -= limit
if back_index <= 0:
back_index = 0
return events_page, get_paging_token_from_index(back_index, True), get_paging_token_from_index(next_index)
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved): def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
def filter_func(event): def filter_func(event):

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import datetime import datetime
import re import re
import json
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
@ -151,7 +152,6 @@ class FakeRoot(FakeOrganizationalUnit):
class FakeServiceControlPolicy(BaseModel): class FakeServiceControlPolicy(BaseModel):
def __init__(self, organization, **kwargs): def __init__(self, organization, **kwargs):
self.type = 'POLICY'
self.content = kwargs.get('Content') self.content = kwargs.get('Content')
self.description = kwargs.get('Description') self.description = kwargs.get('Description')
self.name = kwargs.get('Name') self.name = kwargs.get('Name')
@ -197,7 +197,38 @@ class OrganizationsBackend(BaseBackend):
def create_organization(self, **kwargs): def create_organization(self, **kwargs):
self.org = FakeOrganization(kwargs['FeatureSet']) self.org = FakeOrganization(kwargs['FeatureSet'])
self.ou.append(FakeRoot(self.org)) root_ou = FakeRoot(self.org)
self.ou.append(root_ou)
master_account = FakeAccount(
self.org,
AccountName='master',
Email=self.org.master_account_email,
)
master_account.id = self.org.master_account_id
self.accounts.append(master_account)
default_policy = FakeServiceControlPolicy(
self.org,
Name='FullAWSAccess',
Description='Allows access to every operation',
Type='SERVICE_CONTROL_POLICY',
Content=json.dumps(
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "*",
"Resource": "*"
}
]
}
)
)
default_policy.id = utils.DEFAULT_POLICY_ID
default_policy.aws_managed = True
self.policies.append(default_policy)
self.attach_policy(PolicyId=default_policy.id, TargetId=root_ou.id)
self.attach_policy(PolicyId=default_policy.id, TargetId=master_account.id)
return self.org.describe() return self.org.describe()
def describe_organization(self): def describe_organization(self):
@ -216,6 +247,7 @@ class OrganizationsBackend(BaseBackend):
def create_organizational_unit(self, **kwargs): def create_organizational_unit(self, **kwargs):
new_ou = FakeOrganizationalUnit(self.org, **kwargs) new_ou = FakeOrganizationalUnit(self.org, **kwargs)
self.ou.append(new_ou) self.ou.append(new_ou)
self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_ou.id)
return new_ou.describe() return new_ou.describe()
def get_organizational_unit_by_id(self, ou_id): def get_organizational_unit_by_id(self, ou_id):
@ -258,6 +290,7 @@ class OrganizationsBackend(BaseBackend):
def create_account(self, **kwargs): def create_account(self, **kwargs):
new_account = FakeAccount(self.org, **kwargs) new_account = FakeAccount(self.org, **kwargs)
self.accounts.append(new_account) self.accounts.append(new_account)
self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_account.id)
return new_account.create_account_status return new_account.create_account_status
def get_account_by_id(self, account_id): def get_account_by_id(self, account_id):
@ -358,8 +391,7 @@ class OrganizationsBackend(BaseBackend):
def attach_policy(self, **kwargs): def attach_policy(self, **kwargs):
policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None) policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None)
if (re.compile(utils.ROOT_ID_REGEX).match(kwargs['TargetId']) or if (re.compile(utils.ROOT_ID_REGEX).match(kwargs['TargetId']) or re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId'])):
re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId'])):
ou = next((ou for ou in self.ou if ou.id == kwargs['TargetId']), None) ou = next((ou for ou in self.ou if ou.id == kwargs['TargetId']), None)
if ou is not None: if ou is not None:
if ou not in ou.attached_policies: if ou not in ou.attached_policies:

View File

@ -4,7 +4,8 @@ import random
import string import string
MASTER_ACCOUNT_ID = '123456789012' MASTER_ACCOUNT_ID = '123456789012'
MASTER_ACCOUNT_EMAIL = 'fakeorg@moto-example.com' MASTER_ACCOUNT_EMAIL = 'master@example.com'
DEFAULT_POLICY_ID = 'p-FullAWSAccess'
ORGANIZATION_ARN_FORMAT = 'arn:aws:organizations::{0}:organization/{1}' ORGANIZATION_ARN_FORMAT = 'arn:aws:organizations::{0}:organization/{1}'
MASTER_ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{0}' MASTER_ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{0}'
ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{2}' ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{2}'
@ -26,7 +27,7 @@ ROOT_ID_REGEX = r'r-[a-z0-9]{%s}' % ROOT_ID_SIZE
OU_ID_REGEX = r'ou-[a-z0-9]{%s}-[a-z0-9]{%s}' % (ROOT_ID_SIZE, OU_ID_SUFFIX_SIZE) OU_ID_REGEX = r'ou-[a-z0-9]{%s}-[a-z0-9]{%s}' % (ROOT_ID_SIZE, OU_ID_SUFFIX_SIZE)
ACCOUNT_ID_REGEX = r'[0-9]{%s}' % ACCOUNT_ID_SIZE ACCOUNT_ID_REGEX = r'[0-9]{%s}' % ACCOUNT_ID_SIZE
CREATE_ACCOUNT_STATUS_ID_REGEX = r'car-[a-z0-9]{%s}' % CREATE_ACCOUNT_STATUS_ID_SIZE CREATE_ACCOUNT_STATUS_ID_REGEX = r'car-[a-z0-9]{%s}' % CREATE_ACCOUNT_STATUS_ID_SIZE
SCP_ID_REGEX = r'p-[a-z0-9]{%s}' % SCP_ID_SIZE SCP_ID_REGEX = r'%s|p-[a-z0-9]{%s}' % (DEFAULT_POLICY_ID, SCP_ID_SIZE)
def make_random_org_id(): def make_random_org_id():

View File

@ -268,10 +268,26 @@ class fakesock(object):
_sent_data = [] _sent_data = []
def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM, def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM,
protocol=0): proto=0, fileno=None, _sock=None):
self.truesock = (old_socket(family, type, protocol) """
if httpretty.allow_net_connect Matches both the Python 2 API:
else None) def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None):
https://github.com/python/cpython/blob/2.7/Lib/socket.py
and the Python 3 API:
def __init__(self, family=-1, type=-1, proto=-1, fileno=None):
https://github.com/python/cpython/blob/3.5/Lib/socket.py
"""
if httpretty.allow_net_connect:
if PY3:
self.truesock = old_socket(family, type, proto, fileno)
else:
# If Python 2, if parameters are passed as arguments, instead of kwargs,
# the 4th argument `_sock` will be interpreted as the `fileno`.
# Check if _sock is none, and if so, pass fileno.
self.truesock = old_socket(family, type, proto, fileno or _sock)
else:
self.truesock = None
self._closed = True self._closed = True
self.fd = FakeSockFile() self.fd = FakeSockFile()
self.fd.socket = self self.fd.socket = self

View File

@ -95,7 +95,7 @@ class RDSResponse(BaseResponse):
start = all_ids.index(marker) + 1 start = all_ids.index(marker) + 1
else: else:
start = 0 start = 0
page_size = self._get_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier page_size = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
instances_resp = all_instances[start:start + page_size] instances_resp = all_instances[start:start + page_size]
next_marker = None next_marker = None
if len(all_instances) > start + page_size: if len(all_instances) > start + page_size:

View File

@ -60,6 +60,15 @@ class DBParameterGroupNotFoundError(RDSClientError):
'DB Parameter Group {0} not found.'.format(db_parameter_group_name)) 'DB Parameter Group {0} not found.'.format(db_parameter_group_name))
class OptionGroupNotFoundFaultError(RDSClientError):
def __init__(self, option_group_name):
super(OptionGroupNotFoundFaultError, self).__init__(
'OptionGroupNotFoundFault',
'Specified OptionGroupName: {0} not found.'.format(option_group_name)
)
class InvalidDBClusterStateFaultError(RDSClientError): class InvalidDBClusterStateFaultError(RDSClientError):
def __init__(self, database_identifier): def __init__(self, database_identifier):

View File

@ -20,6 +20,7 @@ from .exceptions import (RDSClientError,
DBSecurityGroupNotFoundError, DBSecurityGroupNotFoundError,
DBSubnetGroupNotFoundError, DBSubnetGroupNotFoundError,
DBParameterGroupNotFoundError, DBParameterGroupNotFoundError,
OptionGroupNotFoundFaultError,
InvalidDBClusterStateFaultError, InvalidDBClusterStateFaultError,
InvalidDBInstanceStateError, InvalidDBInstanceStateError,
SnapshotQuotaExceededError, SnapshotQuotaExceededError,
@ -70,6 +71,7 @@ class Database(BaseModel):
self.port = Database.default_port(self.engine) self.port = Database.default_port(self.engine)
self.db_instance_identifier = kwargs.get('db_instance_identifier') self.db_instance_identifier = kwargs.get('db_instance_identifier')
self.db_name = kwargs.get("db_name") self.db_name = kwargs.get("db_name")
self.instance_create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
self.publicly_accessible = kwargs.get("publicly_accessible") self.publicly_accessible = kwargs.get("publicly_accessible")
if self.publicly_accessible is None: if self.publicly_accessible is None:
self.publicly_accessible = True self.publicly_accessible = True
@ -99,6 +101,8 @@ class Database(BaseModel):
'preferred_backup_window', '13:14-13:44') 'preferred_backup_window', '13:14-13:44')
self.license_model = kwargs.get('license_model', 'general-public-license') self.license_model = kwargs.get('license_model', 'general-public-license')
self.option_group_name = kwargs.get('option_group_name', None) self.option_group_name = kwargs.get('option_group_name', None)
if self.option_group_name and self.option_group_name not in rds2_backends[self.region].option_groups:
raise OptionGroupNotFoundFaultError(self.option_group_name)
self.default_option_groups = {"MySQL": "default.mysql5.6", self.default_option_groups = {"MySQL": "default.mysql5.6",
"mysql": "default.mysql5.6", "mysql": "default.mysql5.6",
"postgres": "default.postgres9.3" "postgres": "default.postgres9.3"
@ -145,9 +149,17 @@ class Database(BaseModel):
<DBInstanceStatus>{{ database.status }}</DBInstanceStatus> <DBInstanceStatus>{{ database.status }}</DBInstanceStatus>
{% if database.db_name %}<DBName>{{ database.db_name }}</DBName>{% endif %} {% if database.db_name %}<DBName>{{ database.db_name }}</DBName>{% endif %}
<MultiAZ>{{ database.multi_az }}</MultiAZ> <MultiAZ>{{ database.multi_az }}</MultiAZ>
<VpcSecurityGroups/> <VpcSecurityGroups>
{% for vpc_security_group_id in database.vpc_security_group_ids %}
<VpcSecurityGroupMembership>
<Status>active</Status>
<VpcSecurityGroupId>{{ vpc_security_group_id }}</VpcSecurityGroupId>
</VpcSecurityGroupMembership>
{% endfor %}
</VpcSecurityGroups>
<DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier> <DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier>
<DbiResourceId>{{ database.dbi_resource_id }}</DbiResourceId> <DbiResourceId>{{ database.dbi_resource_id }}</DbiResourceId>
<InstanceCreateTime>{{ database.instance_create_time }}</InstanceCreateTime>
<PreferredBackupWindow>03:50-04:20</PreferredBackupWindow> <PreferredBackupWindow>03:50-04:20</PreferredBackupWindow>
<PreferredMaintenanceWindow>wed:06:38-wed:07:08</PreferredMaintenanceWindow> <PreferredMaintenanceWindow>wed:06:38-wed:07:08</PreferredMaintenanceWindow>
<ReadReplicaDBInstanceIdentifiers> <ReadReplicaDBInstanceIdentifiers>
@ -173,6 +185,10 @@ class Database(BaseModel):
<LicenseModel>{{ database.license_model }}</LicenseModel> <LicenseModel>{{ database.license_model }}</LicenseModel>
<EngineVersion>{{ database.engine_version }}</EngineVersion> <EngineVersion>{{ database.engine_version }}</EngineVersion>
<OptionGroupMemberships> <OptionGroupMemberships>
<OptionGroupMembership>
<OptionGroupName>{{ database.option_group_name }}</OptionGroupName>
<Status>in-sync</Status>
</OptionGroupMembership>
</OptionGroupMemberships> </OptionGroupMemberships>
<DBParameterGroups> <DBParameterGroups>
{% for db_parameter_group in database.db_parameter_groups() %} {% for db_parameter_group in database.db_parameter_groups() %}
@ -314,6 +330,7 @@ class Database(BaseModel):
"storage_encrypted": properties.get("StorageEncrypted"), "storage_encrypted": properties.get("StorageEncrypted"),
"storage_type": properties.get("StorageType"), "storage_type": properties.get("StorageType"),
"tags": properties.get("Tags"), "tags": properties.get("Tags"),
"vpc_security_group_ids": properties.get('VpcSecurityGroupIds', []),
} }
rds2_backend = rds2_backends[region_name] rds2_backend = rds2_backends[region_name]
@ -373,7 +390,7 @@ class Database(BaseModel):
"Address": "{{ database.address }}", "Address": "{{ database.address }}",
"Port": "{{ database.port }}" "Port": "{{ database.port }}"
}, },
"InstanceCreateTime": null, "InstanceCreateTime": "{{ database.instance_create_time }}",
"Iops": null, "Iops": null,
"ReadReplicaDBInstanceIdentifiers": [{%- for replica in database.replicas -%} "ReadReplicaDBInstanceIdentifiers": [{%- for replica in database.replicas -%}
{%- if not loop.first -%},{%- endif -%} {%- if not loop.first -%},{%- endif -%}
@ -388,10 +405,12 @@ class Database(BaseModel):
"SecondaryAvailabilityZone": null, "SecondaryAvailabilityZone": null,
"StatusInfos": null, "StatusInfos": null,
"VpcSecurityGroups": [ "VpcSecurityGroups": [
{% for vpc_security_group_id in database.vpc_security_group_ids %}
{ {
"Status": "active", "Status": "active",
"VpcSecurityGroupId": "sg-123456" "VpcSecurityGroupId": "{{ vpc_security_group_id }}"
} }
{% endfor %}
], ],
"DBInstanceArn": "{{ database.db_instance_arn }}" "DBInstanceArn": "{{ database.db_instance_arn }}"
}""") }""")
@ -873,13 +892,16 @@ class RDS2Backend(BaseBackend):
def create_option_group(self, option_group_kwargs): def create_option_group(self, option_group_kwargs):
option_group_id = option_group_kwargs['name'] option_group_id = option_group_kwargs['name']
valid_option_group_engines = {'mysql': ['5.6'], valid_option_group_engines = {'mariadb': ['10.0', '10.1', '10.2', '10.3'],
'oracle-se1': ['11.2'], 'mysql': ['5.5', '5.6', '5.7', '8.0'],
'oracle-se': ['11.2'], 'oracle-se2': ['11.2', '12.1', '12.2'],
'oracle-ee': ['11.2'], 'oracle-se1': ['11.2', '12.1', '12.2'],
'oracle-se': ['11.2', '12.1', '12.2'],
'oracle-ee': ['11.2', '12.1', '12.2'],
'sqlserver-se': ['10.50', '11.00'], 'sqlserver-se': ['10.50', '11.00'],
'sqlserver-ee': ['10.50', '11.00'] 'sqlserver-ee': ['10.50', '11.00'],
} 'sqlserver-ex': ['10.50', '11.00'],
'sqlserver-web': ['10.50', '11.00']}
if option_group_kwargs['name'] in self.option_groups: if option_group_kwargs['name'] in self.option_groups:
raise RDSClientError('OptionGroupAlreadyExistsFault', raise RDSClientError('OptionGroupAlreadyExistsFault',
'An option group named {0} already exists.'.format(option_group_kwargs['name'])) 'An option group named {0} already exists.'.format(option_group_kwargs['name']))
@ -905,8 +927,7 @@ class RDS2Backend(BaseBackend):
if option_group_name in self.option_groups: if option_group_name in self.option_groups:
return self.option_groups.pop(option_group_name) return self.option_groups.pop(option_group_name)
else: else:
raise RDSClientError( raise OptionGroupNotFoundFaultError(option_group_name)
'OptionGroupNotFoundFault', 'Specified OptionGroupName: {0} not found.'.format(option_group_name))
def describe_option_groups(self, option_group_kwargs): def describe_option_groups(self, option_group_kwargs):
option_group_list = [] option_group_list = []
@ -935,8 +956,7 @@ class RDS2Backend(BaseBackend):
else: else:
option_group_list.append(option_group) option_group_list.append(option_group)
if not len(option_group_list): if not len(option_group_list):
raise RDSClientError('OptionGroupNotFoundFault', raise OptionGroupNotFoundFaultError(option_group_kwargs['name'])
'Specified OptionGroupName: {0} not found.'.format(option_group_kwargs['name']))
return option_group_list[marker:max_records + marker] return option_group_list[marker:max_records + marker]
@staticmethod @staticmethod
@ -965,8 +985,7 @@ class RDS2Backend(BaseBackend):
def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None): def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None):
if option_group_name not in self.option_groups: if option_group_name not in self.option_groups:
raise RDSClientError('OptionGroupNotFoundFault', raise OptionGroupNotFoundFaultError(option_group_name)
'Specified OptionGroupName: {0} not found.'.format(option_group_name))
if not options_to_include and not options_to_remove: if not options_to_include and not options_to_remove:
raise RDSClientError('InvalidParameterValue', raise RDSClientError('InvalidParameterValue',
'At least one option must be added, modified, or removed.') 'At least one option must be added, modified, or removed.')

View File

@ -34,7 +34,7 @@ class RDS2Response(BaseResponse):
"master_user_password": self._get_param('MasterUserPassword'), "master_user_password": self._get_param('MasterUserPassword'),
"master_username": self._get_param('MasterUsername'), "master_username": self._get_param('MasterUsername'),
"multi_az": self._get_bool_param("MultiAZ"), "multi_az": self._get_bool_param("MultiAZ"),
# OptionGroupName "option_group_name": self._get_param("OptionGroupName"),
"port": self._get_param('Port'), "port": self._get_param('Port'),
# PreferredBackupWindow # PreferredBackupWindow
# PreferredMaintenanceWindow # PreferredMaintenanceWindow
@ -43,7 +43,7 @@ class RDS2Response(BaseResponse):
"security_groups": self._get_multi_param('DBSecurityGroups.DBSecurityGroupName'), "security_groups": self._get_multi_param('DBSecurityGroups.DBSecurityGroupName'),
"storage_encrypted": self._get_param("StorageEncrypted"), "storage_encrypted": self._get_param("StorageEncrypted"),
"storage_type": self._get_param("StorageType", 'standard'), "storage_type": self._get_param("StorageType", 'standard'),
# VpcSecurityGroupIds.member.N "vpc_security_group_ids": self._get_multi_param("VpcSecurityGroupIds.VpcSecurityGroupId"),
"tags": list(), "tags": list(),
} }
args['tags'] = self.unpack_complex_list_params( args['tags'] = self.unpack_complex_list_params(
@ -280,7 +280,7 @@ class RDS2Response(BaseResponse):
def describe_option_groups(self): def describe_option_groups(self):
kwargs = self._get_option_group_kwargs() kwargs = self._get_option_group_kwargs()
kwargs['max_records'] = self._get_param('MaxRecords') kwargs['max_records'] = self._get_int_param('MaxRecords')
kwargs['marker'] = self._get_param('Marker') kwargs['marker'] = self._get_param('Marker')
option_groups = self.backend.describe_option_groups(kwargs) option_groups = self.backend.describe_option_groups(kwargs)
template = self.response_template(DESCRIBE_OPTION_GROUP_TEMPLATE) template = self.response_template(DESCRIBE_OPTION_GROUP_TEMPLATE)
@ -329,7 +329,7 @@ class RDS2Response(BaseResponse):
def describe_db_parameter_groups(self): def describe_db_parameter_groups(self):
kwargs = self._get_db_parameter_group_kwargs() kwargs = self._get_db_parameter_group_kwargs()
kwargs['max_records'] = self._get_param('MaxRecords') kwargs['max_records'] = self._get_int_param('MaxRecords')
kwargs['marker'] = self._get_param('Marker') kwargs['marker'] = self._get_param('Marker')
db_parameter_groups = self.backend.describe_db_parameter_groups(kwargs) db_parameter_groups = self.backend.describe_db_parameter_groups(kwargs)
template = self.response_template( template = self.response_template(

View File

@ -78,7 +78,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
super(Cluster, self).__init__(region_name, tags) super(Cluster, self).__init__(region_name, tags)
self.redshift_backend = redshift_backend self.redshift_backend = redshift_backend
self.cluster_identifier = cluster_identifier self.cluster_identifier = cluster_identifier
self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow())
self.status = 'available' self.status = 'available'
self.node_type = node_type self.node_type = node_type
self.master_username = master_username self.master_username = master_username

View File

@ -10,6 +10,7 @@ from moto.ec2 import ec2_backends
from moto.elb import elb_backends from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends from moto.elbv2 import elbv2_backends
from moto.kinesis import kinesis_backends from moto.kinesis import kinesis_backends
from moto.kms import kms_backends
from moto.rds2 import rds2_backends from moto.rds2 import rds2_backends
from moto.glacier import glacier_backends from moto.glacier import glacier_backends
from moto.redshift import redshift_backends from moto.redshift import redshift_backends
@ -71,6 +72,13 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
""" """
return kinesis_backends[self.region_name] return kinesis_backends[self.region_name]
@property
def kms_backend(self):
"""
:rtype: moto.kms.models.KmsBackend
"""
return kms_backends[self.region_name]
@property @property
def rds_backend(self): def rds_backend(self):
""" """
@ -221,9 +229,6 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
if not resource_type_filters or 'elasticloadbalancer' in resource_type_filters or 'elasticloadbalancer:loadbalancer' in resource_type_filters: if not resource_type_filters or 'elasticloadbalancer' in resource_type_filters or 'elasticloadbalancer:loadbalancer' in resource_type_filters:
for elb in self.elbv2_backend.load_balancers.values(): for elb in self.elbv2_backend.load_balancers.values():
tags = get_elbv2_tags(elb.arn) tags = get_elbv2_tags(elb.arn)
# if 'elasticloadbalancer:loadbalancer' in resource_type_filters:
# from IPython import embed
# embed()
if not tag_filter(tags): # Skip if no tags, or invalid filter if not tag_filter(tags): # Skip if no tags, or invalid filter
continue continue
@ -235,6 +240,21 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
# Kinesis # Kinesis
# KMS
def get_kms_tags(kms_key_id):
result = []
for tag in self.kms_backend.list_resource_tags(kms_key_id):
result.append({'Key': tag['TagKey'], 'Value': tag['TagValue']})
return result
if not resource_type_filters or 'kms' in resource_type_filters:
for kms_key in self.kms_backend.list_keys():
tags = get_kms_tags(kms_key.id)
if not tag_filter(tags): # Skip if no tags, or invalid filter
continue
yield {'ResourceARN': '{0}'.format(kms_key.arn), 'Tags': tags}
# RDS Instance # RDS Instance
# RDS Reserved Database Instance # RDS Reserved Database Instance
# RDS Option Group # RDS Option Group
@ -370,7 +390,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
def get_resources(self, pagination_token=None, def get_resources(self, pagination_token=None,
resources_per_page=50, tags_per_page=100, resources_per_page=50, tags_per_page=100,
tag_filters=None, resource_type_filters=None): tag_filters=None, resource_type_filters=None):
# Simple range checning # Simple range checking
if 100 >= tags_per_page >= 500: if 100 >= tags_per_page >= 500:
raise RESTError('InvalidParameterException', 'TagsPerPage must be between 100 and 500') raise RESTError('InvalidParameterException', 'TagsPerPage must be between 100 and 500')
if 1 >= resources_per_page >= 50: if 1 >= resources_per_page >= 50:

View File

@ -85,6 +85,7 @@ class RecordSet(BaseModel):
self.health_check = kwargs.get('HealthCheckId') self.health_check = kwargs.get('HealthCheckId')
self.hosted_zone_name = kwargs.get('HostedZoneName') self.hosted_zone_name = kwargs.get('HostedZoneName')
self.hosted_zone_id = kwargs.get('HostedZoneId') self.hosted_zone_id = kwargs.get('HostedZoneId')
self.alias_target = kwargs.get('AliasTarget')
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -119,7 +120,7 @@ class RecordSet(BaseModel):
properties["HostedZoneId"]) properties["HostedZoneId"])
try: try:
hosted_zone.delete_rrset_by_name(resource_name) hosted_zone.delete_rrset({'Name': resource_name})
except KeyError: except KeyError:
pass pass
@ -143,6 +144,13 @@ class RecordSet(BaseModel):
{% if record_set.ttl %} {% if record_set.ttl %}
<TTL>{{ record_set.ttl }}</TTL> <TTL>{{ record_set.ttl }}</TTL>
{% endif %} {% endif %}
{% if record_set.alias_target %}
<AliasTarget>
<HostedZoneId>{{ record_set.alias_target['HostedZoneId'] }}</HostedZoneId>
<DNSName>{{ record_set.alias_target['DNSName'] }}</DNSName>
<EvaluateTargetHealth>{{ record_set.alias_target['EvaluateTargetHealth'] }}</EvaluateTargetHealth>
</AliasTarget>
{% else %}
<ResourceRecords> <ResourceRecords>
{% for record in record_set.records %} {% for record in record_set.records %}
<ResourceRecord> <ResourceRecord>
@ -150,6 +158,7 @@ class RecordSet(BaseModel):
</ResourceRecord> </ResourceRecord>
{% endfor %} {% endfor %}
</ResourceRecords> </ResourceRecords>
{% endif %}
{% if record_set.health_check %} {% if record_set.health_check %}
<HealthCheckId>{{ record_set.health_check }}</HealthCheckId> <HealthCheckId>{{ record_set.health_check }}</HealthCheckId>
{% endif %} {% endif %}
@ -162,7 +171,13 @@ class RecordSet(BaseModel):
self.hosted_zone_name) self.hosted_zone_name)
if not hosted_zone: if not hosted_zone:
hosted_zone = route53_backend.get_hosted_zone(self.hosted_zone_id) hosted_zone = route53_backend.get_hosted_zone(self.hosted_zone_id)
hosted_zone.delete_rrset_by_name(self.name) hosted_zone.delete_rrset({'Name': self.name, 'Type': self.type_})
def reverse_domain_name(domain_name):
if domain_name.endswith('.'): # normalize without trailing dot
domain_name = domain_name[:-1]
return '.'.join(reversed(domain_name.split('.')))
class FakeZone(BaseModel): class FakeZone(BaseModel):
@ -183,16 +198,20 @@ class FakeZone(BaseModel):
def upsert_rrset(self, record_set): def upsert_rrset(self, record_set):
new_rrset = RecordSet(record_set) new_rrset = RecordSet(record_set)
for i, rrset in enumerate(self.rrsets): for i, rrset in enumerate(self.rrsets):
if rrset.name == new_rrset.name and rrset.type_ == new_rrset.type_: if rrset.name == new_rrset.name and rrset.type_ == new_rrset.type_ and rrset.set_identifier == new_rrset.set_identifier:
self.rrsets[i] = new_rrset self.rrsets[i] = new_rrset
break break
else: else:
self.rrsets.append(new_rrset) self.rrsets.append(new_rrset)
return new_rrset return new_rrset
def delete_rrset_by_name(self, name): def delete_rrset(self, rrset):
self.rrsets = [ self.rrsets = [
record_set for record_set in self.rrsets if record_set.name != name] record_set
for record_set in self.rrsets
if record_set.name != rrset['Name'] or
(rrset.get('Type') is not None and record_set.type_ != rrset['Type'])
]
def delete_rrset_by_id(self, set_identifier): def delete_rrset_by_id(self, set_identifier):
self.rrsets = [ self.rrsets = [
@ -200,12 +219,15 @@ class FakeZone(BaseModel):
def get_record_sets(self, start_type, start_name): def get_record_sets(self, start_type, start_name):
record_sets = list(self.rrsets) # Copy the list record_sets = list(self.rrsets) # Copy the list
if start_name:
record_sets = [
record_set
for record_set in record_sets
if reverse_domain_name(record_set.name) >= reverse_domain_name(start_name)
]
if start_type: if start_type:
record_sets = [ record_sets = [
record_set for record_set in record_sets if record_set.type_ >= start_type] record_set for record_set in record_sets if record_set.type_ >= start_type]
if start_name:
record_sets = [
record_set for record_set in record_sets if record_set.name >= start_name]
return record_sets return record_sets

View File

@ -134,10 +134,7 @@ class Route53(BaseResponse):
# Depending on how many records there are, this may # Depending on how many records there are, this may
# or may not be a list # or may not be a list
resource_records = [resource_records] resource_records = [resource_records]
record_values = [x['Value'] for x in resource_records] record_set['ResourceRecords'] = [x['Value'] for x in resource_records]
elif 'AliasTarget' in record_set:
record_values = [record_set['AliasTarget']['DNSName']]
record_set['ResourceRecords'] = record_values
if action == 'CREATE': if action == 'CREATE':
the_zone.add_rrset(record_set) the_zone.add_rrset(record_set)
else: else:
@ -147,7 +144,7 @@ class Route53(BaseResponse):
the_zone.delete_rrset_by_id( the_zone.delete_rrset_by_id(
record_set["SetIdentifier"]) record_set["SetIdentifier"])
else: else:
the_zone.delete_rrset_by_name(record_set["Name"]) the_zone.delete_rrset(record_set)
return 200, headers, CHANGE_RRSET_RESPONSE return 200, headers, CHANGE_RRSET_RESPONSE

View File

@ -60,6 +60,17 @@ class MissingKey(S3ClientError):
) )
class ObjectNotInActiveTierError(S3ClientError):
code = 403
def __init__(self, key_name):
super(ObjectNotInActiveTierError, self).__init__(
"ObjectNotInActiveTierError",
"The source object of the COPY operation is not in the active tier and is only stored in Amazon Glacier.",
Key=key_name,
)
class InvalidPartOrder(S3ClientError): class InvalidPartOrder(S3ClientError):
code = 400 code = 400
@ -199,3 +210,67 @@ class DuplicateTagKeys(S3ClientError):
"InvalidTag", "InvalidTag",
"Cannot provide multiple Tags with the same key", "Cannot provide multiple Tags with the same key",
*args, **kwargs) *args, **kwargs)
class S3AccessDeniedError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(S3AccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs)
class BucketAccessDeniedError(BucketError):
code = 403
def __init__(self, *args, **kwargs):
super(BucketAccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs)
class S3InvalidTokenError(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
super(S3InvalidTokenError, self).__init__('InvalidToken', 'The provided token is malformed or otherwise invalid.', *args, **kwargs)
class BucketInvalidTokenError(BucketError):
code = 400
def __init__(self, *args, **kwargs):
super(BucketInvalidTokenError, self).__init__('InvalidToken', 'The provided token is malformed or otherwise invalid.', *args, **kwargs)
class S3InvalidAccessKeyIdError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(S3InvalidAccessKeyIdError, self).__init__(
'InvalidAccessKeyId',
"The AWS Access Key Id you provided does not exist in our records.", *args, **kwargs)
class BucketInvalidAccessKeyIdError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(BucketInvalidAccessKeyIdError, self).__init__(
'InvalidAccessKeyId',
"The AWS Access Key Id you provided does not exist in our records.", *args, **kwargs)
class S3SignatureDoesNotMatchError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(S3SignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch',
"The request signature we calculated does not match the signature you provided. Check your key and signing method.", *args, **kwargs)
class BucketSignatureDoesNotMatchError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(BucketSignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch',
"The request signature we calculated does not match the signature you provided. Check your key and signing method.", *args, **kwargs)

View File

@ -28,7 +28,8 @@ MAX_BUCKET_NAME_LENGTH = 63
MIN_BUCKET_NAME_LENGTH = 3 MIN_BUCKET_NAME_LENGTH = 3
UPLOAD_ID_BYTES = 43 UPLOAD_ID_BYTES = 43
UPLOAD_PART_MIN_SIZE = 5242880 UPLOAD_PART_MIN_SIZE = 5242880
STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA"] STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA",
"INTELLIGENT_TIERING", "GLACIER", "DEEP_ARCHIVE"]
DEFAULT_KEY_BUFFER_SIZE = 16 * 1024 * 1024 DEFAULT_KEY_BUFFER_SIZE = 16 * 1024 * 1024
DEFAULT_TEXT_ENCODING = sys.getdefaultencoding() DEFAULT_TEXT_ENCODING = sys.getdefaultencoding()
@ -52,8 +53,17 @@ class FakeDeleteMarker(BaseModel):
class FakeKey(BaseModel): class FakeKey(BaseModel):
def __init__(self, name, value, storage="STANDARD", etag=None, is_versioned=False, version_id=0, def __init__(
max_buffer_size=DEFAULT_KEY_BUFFER_SIZE): self,
name,
value,
storage="STANDARD",
etag=None,
is_versioned=False,
version_id=0,
max_buffer_size=DEFAULT_KEY_BUFFER_SIZE,
multipart=None
):
self.name = name self.name = name
self.last_modified = datetime.datetime.utcnow() self.last_modified = datetime.datetime.utcnow()
self.acl = get_canned_acl('private') self.acl = get_canned_acl('private')
@ -65,6 +75,7 @@ class FakeKey(BaseModel):
self._version_id = version_id self._version_id = version_id
self._is_versioned = is_versioned self._is_versioned = is_versioned
self._tagging = FakeTagging() self._tagging = FakeTagging()
self.multipart = multipart
self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size) self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size)
self._max_buffer_size = max_buffer_size self._max_buffer_size = max_buffer_size
@ -754,7 +765,7 @@ class S3Backend(BaseBackend):
prefix=''): prefix=''):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
if any((delimiter, encoding_type, key_marker, version_id_marker)): if any((delimiter, key_marker, version_id_marker)):
raise NotImplementedError( raise NotImplementedError(
"Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker") "Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker")
@ -782,7 +793,15 @@ class S3Backend(BaseBackend):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
return bucket.website_configuration return bucket.website_configuration
def set_key(self, bucket_name, key_name, value, storage=None, etag=None): def set_key(
self,
bucket_name,
key_name,
value,
storage=None,
etag=None,
multipart=None,
):
key_name = clean_key_name(key_name) key_name = clean_key_name(key_name)
if storage is not None and storage not in STORAGE_CLASS: if storage is not None and storage not in STORAGE_CLASS:
raise InvalidStorageClass(storage=storage) raise InvalidStorageClass(storage=storage)
@ -795,7 +814,9 @@ class S3Backend(BaseBackend):
storage=storage, storage=storage,
etag=etag, etag=etag,
is_versioned=bucket.is_versioned, is_versioned=bucket.is_versioned,
version_id=str(uuid.uuid4()) if bucket.is_versioned else None) version_id=str(uuid.uuid4()) if bucket.is_versioned else None,
multipart=multipart,
)
keys = [ keys = [
key for key in bucket.keys.getlist(key_name, []) key for key in bucket.keys.getlist(key_name, [])
@ -812,7 +833,7 @@ class S3Backend(BaseBackend):
key.append_to_value(value) key.append_to_value(value)
return key return key
def get_key(self, bucket_name, key_name, version_id=None): def get_key(self, bucket_name, key_name, version_id=None, part_number=None):
key_name = clean_key_name(key_name) key_name = clean_key_name(key_name)
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
key = None key = None
@ -827,6 +848,9 @@ class S3Backend(BaseBackend):
key = key_version key = key_version
break break
if part_number and key.multipart:
key = key.multipart.parts[part_number]
if isinstance(key, FakeKey): if isinstance(key, FakeKey):
return key return key
else: else:
@ -890,7 +914,12 @@ class S3Backend(BaseBackend):
return return
del bucket.multiparts[multipart_id] del bucket.multiparts[multipart_id]
key = self.set_key(bucket_name, multipart.key_name, value, etag=etag) key = self.set_key(
bucket_name,
multipart.key_name,
value, etag=etag,
multipart=multipart
)
key.set_metadata(multipart.metadata) key.set_metadata(multipart.metadata)
return key return key

View File

@ -3,20 +3,21 @@ from __future__ import unicode_literals
import re import re
import six import six
from moto.core.utils import str_to_rfc_1123_datetime from moto.core.utils import str_to_rfc_1123_datetime
from six.moves.urllib.parse import parse_qs, urlparse, unquote from six.moves.urllib.parse import parse_qs, urlparse, unquote
import xmltodict import xmltodict
from moto.packages.httpretty.core import HTTPrettyRequest from moto.packages.httpretty.core import HTTPrettyRequest
from moto.core.responses import _TemplateEnvironmentMixin from moto.core.responses import _TemplateEnvironmentMixin, ActionAuthenticatorMixin
from moto.core.utils import path_url from moto.core.utils import path_url
from moto.s3bucket_path.utils import bucket_name_from_url as bucketpath_bucket_name_from_url, \ from moto.s3bucket_path.utils import bucket_name_from_url as bucketpath_bucket_name_from_url, \
parse_key_name as bucketpath_parse_key_name, is_delete_keys as bucketpath_is_delete_keys parse_key_name as bucketpath_parse_key_name, is_delete_keys as bucketpath_is_delete_keys
from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, MissingKey, InvalidPartOrder, MalformedXML, \ from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, MissingKey, InvalidPartOrder, MalformedXML, \
MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent, ObjectNotInActiveTierError
from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \ from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \
FakeTag FakeTag
from .utils import bucket_name_from_url, clean_key_name, metadata_from_headers, parse_region_from_url from .utils import bucket_name_from_url, clean_key_name, metadata_from_headers, parse_region_from_url
@ -25,6 +26,72 @@ from xml.dom import minidom
DEFAULT_REGION_NAME = 'us-east-1' DEFAULT_REGION_NAME = 'us-east-1'
ACTION_MAP = {
"BUCKET": {
"GET": {
"uploads": "ListBucketMultipartUploads",
"location": "GetBucketLocation",
"lifecycle": "GetLifecycleConfiguration",
"versioning": "GetBucketVersioning",
"policy": "GetBucketPolicy",
"website": "GetBucketWebsite",
"acl": "GetBucketAcl",
"tagging": "GetBucketTagging",
"logging": "GetBucketLogging",
"cors": "GetBucketCORS",
"notification": "GetBucketNotification",
"accelerate": "GetAccelerateConfiguration",
"versions": "ListBucketVersions",
"DEFAULT": "ListBucket"
},
"PUT": {
"lifecycle": "PutLifecycleConfiguration",
"versioning": "PutBucketVersioning",
"policy": "PutBucketPolicy",
"website": "PutBucketWebsite",
"acl": "PutBucketAcl",
"tagging": "PutBucketTagging",
"logging": "PutBucketLogging",
"cors": "PutBucketCORS",
"notification": "PutBucketNotification",
"accelerate": "PutAccelerateConfiguration",
"DEFAULT": "CreateBucket"
},
"DELETE": {
"lifecycle": "PutLifecycleConfiguration",
"policy": "DeleteBucketPolicy",
"tagging": "PutBucketTagging",
"cors": "PutBucketCORS",
"DEFAULT": "DeleteBucket"
}
},
"KEY": {
"GET": {
"uploadId": "ListMultipartUploadParts",
"acl": "GetObjectAcl",
"tagging": "GetObjectTagging",
"versionId": "GetObjectVersion",
"DEFAULT": "GetObject"
},
"PUT": {
"acl": "PutObjectAcl",
"tagging": "PutObjectTagging",
"DEFAULT": "PutObject"
},
"DELETE": {
"uploadId": "AbortMultipartUpload",
"versionId": "DeleteObjectVersion",
"DEFAULT": " DeleteObject"
},
"POST": {
"uploads": "PutObject",
"restore": "RestoreObject",
"uploadId": "PutObject"
}
}
}
def parse_key_name(pth): def parse_key_name(pth):
return pth.lstrip("/") return pth.lstrip("/")
@ -37,17 +104,24 @@ def is_delete_keys(request, path, bucket_name):
) )
class ResponseObject(_TemplateEnvironmentMixin): class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def __init__(self, backend): def __init__(self, backend):
super(ResponseObject, self).__init__() super(ResponseObject, self).__init__()
self.backend = backend self.backend = backend
self.method = ""
self.path = ""
self.data = {}
self.headers = {}
@property @property
def should_autoescape(self): def should_autoescape(self):
return True return True
def all_buckets(self): def all_buckets(self):
self.data["Action"] = "ListAllMyBuckets"
self._authenticate_and_authorize_s3_action()
# No bucket specified. Listing all buckets # No bucket specified. Listing all buckets
all_buckets = self.backend.get_all_buckets() all_buckets = self.backend.get_all_buckets()
template = self.response_template(S3_ALL_BUCKETS) template = self.response_template(S3_ALL_BUCKETS)
@ -112,11 +186,20 @@ class ResponseObject(_TemplateEnvironmentMixin):
return self.bucket_response(request, full_url, headers) return self.bucket_response(request, full_url, headers)
def bucket_response(self, request, full_url, headers): def bucket_response(self, request, full_url, headers):
self.method = request.method
self.path = self._get_path(request)
self.headers = request.headers
if 'host' not in self.headers:
self.headers['host'] = urlparse(full_url).netloc
try: try:
response = self._bucket_response(request, full_url, headers) response = self._bucket_response(request, full_url, headers)
except S3ClientError as s3error: except S3ClientError as s3error:
response = s3error.code, {}, s3error.description response = s3error.code, {}, s3error.description
return self._send_response(response)
@staticmethod
def _send_response(response):
if isinstance(response, six.string_types): if isinstance(response, six.string_types):
return 200, {}, response.encode("utf-8") return 200, {}, response.encode("utf-8")
else: else:
@ -127,8 +210,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
return status_code, headers, response_content return status_code, headers, response_content
def _bucket_response(self, request, full_url, headers): def _bucket_response(self, request, full_url, headers):
parsed_url = urlparse(full_url) querystring = self._get_querystring(full_url)
querystring = parse_qs(parsed_url.query, keep_blank_values=True)
method = request.method method = request.method
region_name = parse_region_from_url(full_url) region_name = parse_region_from_url(full_url)
@ -137,6 +219,8 @@ class ResponseObject(_TemplateEnvironmentMixin):
# If no bucket specified, list all buckets # If no bucket specified, list all buckets
return self.all_buckets() return self.all_buckets()
self.data["BucketName"] = bucket_name
if hasattr(request, 'body'): if hasattr(request, 'body'):
# Boto # Boto
body = request.body body = request.body
@ -150,20 +234,26 @@ class ResponseObject(_TemplateEnvironmentMixin):
body = u'{0}'.format(body).encode('utf-8') body = u'{0}'.format(body).encode('utf-8')
if method == 'HEAD': if method == 'HEAD':
return self._bucket_response_head(bucket_name, headers) return self._bucket_response_head(bucket_name)
elif method == 'GET': elif method == 'GET':
return self._bucket_response_get(bucket_name, querystring, headers) return self._bucket_response_get(bucket_name, querystring)
elif method == 'PUT': elif method == 'PUT':
return self._bucket_response_put(request, body, region_name, bucket_name, querystring, headers) return self._bucket_response_put(request, body, region_name, bucket_name, querystring)
elif method == 'DELETE': elif method == 'DELETE':
return self._bucket_response_delete(body, bucket_name, querystring, headers) return self._bucket_response_delete(body, bucket_name, querystring)
elif method == 'POST': elif method == 'POST':
return self._bucket_response_post(request, body, bucket_name, headers) return self._bucket_response_post(request, body, bucket_name)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Method {0} has not been impelemented in the S3 backend yet".format(method)) "Method {0} has not been impelemented in the S3 backend yet".format(method))
def _bucket_response_head(self, bucket_name, headers): @staticmethod
def _get_querystring(full_url):
parsed_url = urlparse(full_url)
querystring = parse_qs(parsed_url.query, keep_blank_values=True)
return querystring
def _bucket_response_head(self, bucket_name):
try: try:
self.backend.get_bucket(bucket_name) self.backend.get_bucket(bucket_name)
except MissingBucket: except MissingBucket:
@ -174,7 +264,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 404, {}, "" return 404, {}, ""
return 200, {}, "" return 200, {}, ""
def _bucket_response_get(self, bucket_name, querystring, headers): def _bucket_response_get(self, bucket_name, querystring):
self._set_action("BUCKET", "GET", querystring)
self._authenticate_and_authorize_s3_action()
if 'uploads' in querystring: if 'uploads' in querystring:
for unsup in ('delimiter', 'max-uploads'): for unsup in ('delimiter', 'max-uploads'):
if unsup in querystring: if unsup in querystring:
@ -333,6 +426,15 @@ class ResponseObject(_TemplateEnvironmentMixin):
max_keys=max_keys max_keys=max_keys
) )
def _set_action(self, action_resource_type, method, querystring):
action_set = False
for action_in_querystring, action in ACTION_MAP[action_resource_type][method].items():
if action_in_querystring in querystring:
self.data["Action"] = action
action_set = True
if not action_set:
self.data["Action"] = ACTION_MAP[action_resource_type][method]["DEFAULT"]
def _handle_list_objects_v2(self, bucket_name, querystring): def _handle_list_objects_v2(self, bucket_name, querystring):
template = self.response_template(S3_BUCKET_GET_RESPONSE_V2) template = self.response_template(S3_BUCKET_GET_RESPONSE_V2)
bucket = self.backend.get_bucket(bucket_name) bucket = self.backend.get_bucket(bucket_name)
@ -361,10 +463,13 @@ class ResponseObject(_TemplateEnvironmentMixin):
else: else:
result_folders, is_truncated, next_continuation_token = self._truncate_result(result_folders, max_keys) result_folders, is_truncated, next_continuation_token = self._truncate_result(result_folders, max_keys)
key_count = len(result_keys) + len(result_folders)
return template.render( return template.render(
bucket=bucket, bucket=bucket,
prefix=prefix or '', prefix=prefix or '',
delimiter=delimiter, delimiter=delimiter,
key_count=key_count,
result_keys=result_keys, result_keys=result_keys,
result_folders=result_folders, result_folders=result_folders,
fetch_owner=fetch_owner, fetch_owner=fetch_owner,
@ -393,9 +498,13 @@ class ResponseObject(_TemplateEnvironmentMixin):
next_continuation_token = None next_continuation_token = None
return result_keys, is_truncated, next_continuation_token return result_keys, is_truncated, next_continuation_token
def _bucket_response_put(self, request, body, region_name, bucket_name, querystring, headers): def _bucket_response_put(self, request, body, region_name, bucket_name, querystring):
if not request.headers.get('Content-Length'): if not request.headers.get('Content-Length'):
return 411, {}, "Content-Length required" return 411, {}, "Content-Length required"
self._set_action("BUCKET", "PUT", querystring)
self._authenticate_and_authorize_s3_action()
if 'versioning' in querystring: if 'versioning' in querystring:
ver = re.search('<Status>([A-Za-z]+)</Status>', body.decode()) ver = re.search('<Status>([A-Za-z]+)</Status>', body.decode())
if ver: if ver:
@ -494,7 +603,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
template = self.response_template(S3_BUCKET_CREATE_RESPONSE) template = self.response_template(S3_BUCKET_CREATE_RESPONSE)
return 200, {}, template.render(bucket=new_bucket) return 200, {}, template.render(bucket=new_bucket)
def _bucket_response_delete(self, body, bucket_name, querystring, headers): def _bucket_response_delete(self, body, bucket_name, querystring):
self._set_action("BUCKET", "DELETE", querystring)
self._authenticate_and_authorize_s3_action()
if 'policy' in querystring: if 'policy' in querystring:
self.backend.delete_bucket_policy(bucket_name, body) self.backend.delete_bucket_policy(bucket_name, body)
return 204, {}, "" return 204, {}, ""
@ -521,17 +633,20 @@ class ResponseObject(_TemplateEnvironmentMixin):
S3_DELETE_BUCKET_WITH_ITEMS_ERROR) S3_DELETE_BUCKET_WITH_ITEMS_ERROR)
return 409, {}, template.render(bucket=removed_bucket) return 409, {}, template.render(bucket=removed_bucket)
def _bucket_response_post(self, request, body, bucket_name, headers): def _bucket_response_post(self, request, body, bucket_name):
if not request.headers.get('Content-Length'): if not request.headers.get('Content-Length'):
return 411, {}, "Content-Length required" return 411, {}, "Content-Length required"
if isinstance(request, HTTPrettyRequest): path = self._get_path(request)
path = request.path
else:
path = request.full_path if hasattr(request, 'full_path') else path_url(request.url)
if self.is_delete_keys(request, path, bucket_name): if self.is_delete_keys(request, path, bucket_name):
return self._bucket_response_delete_keys(request, body, bucket_name, headers) self.data["Action"] = "DeleteObject"
self._authenticate_and_authorize_s3_action()
return self._bucket_response_delete_keys(request, body, bucket_name)
self.data["Action"] = "PutObject"
self._authenticate_and_authorize_s3_action()
# POST to bucket-url should create file from form # POST to bucket-url should create file from form
if hasattr(request, 'form'): if hasattr(request, 'form'):
@ -560,12 +675,22 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 200, {}, "" return 200, {}, ""
def _bucket_response_delete_keys(self, request, body, bucket_name, headers): @staticmethod
def _get_path(request):
if isinstance(request, HTTPrettyRequest):
path = request.path
else:
path = request.full_path if hasattr(request, 'full_path') else path_url(request.url)
return path
def _bucket_response_delete_keys(self, request, body, bucket_name):
template = self.response_template(S3_DELETE_KEYS_RESPONSE) template = self.response_template(S3_DELETE_KEYS_RESPONSE)
keys = minidom.parseString(body).getElementsByTagName('Key') keys = minidom.parseString(body).getElementsByTagName('Key')
deleted_names = [] deleted_names = []
error_names = [] error_names = []
if len(keys) == 0:
raise MalformedXML()
for k in keys: for k in keys:
key_name = k.firstChild.nodeValue key_name = k.firstChild.nodeValue
@ -604,6 +729,11 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 206, response_headers, response_content[begin:end + 1] return 206, response_headers, response_content[begin:end + 1]
def key_response(self, request, full_url, headers): def key_response(self, request, full_url, headers):
self.method = request.method
self.path = self._get_path(request)
self.headers = request.headers
if 'host' not in self.headers:
self.headers['host'] = urlparse(full_url).netloc
response_headers = {} response_headers = {}
try: try:
response = self._key_response(request, full_url, headers) response = self._key_response(request, full_url, headers)
@ -657,20 +787,23 @@ class ResponseObject(_TemplateEnvironmentMixin):
body = b'' body = b''
if method == 'GET': if method == 'GET':
return self._key_response_get(bucket_name, query, key_name, headers) return self._key_response_get(bucket_name, query, key_name, headers=request.headers)
elif method == 'PUT': elif method == 'PUT':
return self._key_response_put(request, body, bucket_name, query, key_name, headers) return self._key_response_put(request, body, bucket_name, query, key_name, headers)
elif method == 'HEAD': elif method == 'HEAD':
return self._key_response_head(bucket_name, query, key_name, headers=request.headers) return self._key_response_head(bucket_name, query, key_name, headers=request.headers)
elif method == 'DELETE': elif method == 'DELETE':
return self._key_response_delete(bucket_name, query, key_name, headers) return self._key_response_delete(bucket_name, query, key_name)
elif method == 'POST': elif method == 'POST':
return self._key_response_post(request, body, bucket_name, query, key_name, headers) return self._key_response_post(request, body, bucket_name, query, key_name)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(method)) "Method {0} has not been implemented in the S3 backend yet".format(method))
def _key_response_get(self, bucket_name, query, key_name, headers): def _key_response_get(self, bucket_name, query, key_name, headers):
self._set_action("KEY", "GET", query)
self._authenticate_and_authorize_s3_action()
response_headers = {} response_headers = {}
if query.get('uploadId'): if query.get('uploadId'):
upload_id = query['uploadId'][0] upload_id = query['uploadId'][0]
@ -684,10 +817,15 @@ class ResponseObject(_TemplateEnvironmentMixin):
parts=parts parts=parts
) )
version_id = query.get('versionId', [None])[0] version_id = query.get('versionId', [None])[0]
if_modified_since = headers.get('If-Modified-Since', None)
key = self.backend.get_key( key = self.backend.get_key(
bucket_name, key_name, version_id=version_id) bucket_name, key_name, version_id=version_id)
if key is None: if key is None:
raise MissingKey(key_name) raise MissingKey(key_name)
if if_modified_since:
if_modified_since = str_to_rfc_1123_datetime(if_modified_since)
if if_modified_since and key.last_modified < if_modified_since:
return 304, response_headers, 'Not Modified'
if 'acl' in query: if 'acl' in query:
template = self.response_template(S3_OBJECT_ACL_RESPONSE) template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return 200, response_headers, template.render(obj=key) return 200, response_headers, template.render(obj=key)
@ -700,6 +838,9 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 200, response_headers, key.value return 200, response_headers, key.value
def _key_response_put(self, request, body, bucket_name, query, key_name, headers): def _key_response_put(self, request, body, bucket_name, query, key_name, headers):
self._set_action("KEY", "PUT", query)
self._authenticate_and_authorize_s3_action()
response_headers = {} response_headers = {}
if query.get('uploadId') and query.get('partNumber'): if query.get('uploadId') and query.get('partNumber'):
upload_id = query['uploadId'][0] upload_id = query['uploadId'][0]
@ -764,7 +905,11 @@ class ResponseObject(_TemplateEnvironmentMixin):
src_version_id = parse_qs(src_key_parsed.query).get( src_version_id = parse_qs(src_key_parsed.query).get(
'versionId', [None])[0] 'versionId', [None])[0]
if self.backend.get_key(src_bucket, src_key, version_id=src_version_id): key = self.backend.get_key(src_bucket, src_key, version_id=src_version_id)
if key is not None:
if key.storage_class in ["GLACIER", "DEEP_ARCHIVE"]:
raise ObjectNotInActiveTierError(key)
self.backend.copy_key(src_bucket, src_key, bucket_name, key_name, self.backend.copy_key(src_bucket, src_key, bucket_name, key_name,
storage=storage_class, acl=acl, src_version_id=src_version_id) storage=storage_class, acl=acl, src_version_id=src_version_id)
else: else:
@ -804,13 +949,20 @@ class ResponseObject(_TemplateEnvironmentMixin):
def _key_response_head(self, bucket_name, query, key_name, headers): def _key_response_head(self, bucket_name, query, key_name, headers):
response_headers = {} response_headers = {}
version_id = query.get('versionId', [None])[0] version_id = query.get('versionId', [None])[0]
part_number = query.get('partNumber', [None])[0]
if part_number:
part_number = int(part_number)
if_modified_since = headers.get('If-Modified-Since', None) if_modified_since = headers.get('If-Modified-Since', None)
if if_modified_since: if if_modified_since:
if_modified_since = str_to_rfc_1123_datetime(if_modified_since) if_modified_since = str_to_rfc_1123_datetime(if_modified_since)
key = self.backend.get_key( key = self.backend.get_key(
bucket_name, key_name, version_id=version_id) bucket_name,
key_name,
version_id=version_id,
part_number=part_number
)
if key: if key:
response_headers.update(key.metadata) response_headers.update(key.metadata)
response_headers.update(key.response_dict) response_headers.update(key.response_dict)
@ -1066,7 +1218,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
config = parsed_xml['AccelerateConfiguration'] config = parsed_xml['AccelerateConfiguration']
return config['Status'] return config['Status']
def _key_response_delete(self, bucket_name, query, key_name, headers): def _key_response_delete(self, bucket_name, query, key_name):
self._set_action("KEY", "DELETE", query)
self._authenticate_and_authorize_s3_action()
if query.get('uploadId'): if query.get('uploadId'):
upload_id = query['uploadId'][0] upload_id = query['uploadId'][0]
self.backend.cancel_multipart(bucket_name, upload_id) self.backend.cancel_multipart(bucket_name, upload_id)
@ -1086,7 +1241,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
raise InvalidPartOrder() raise InvalidPartOrder()
yield (pn, p.getElementsByTagName('ETag')[0].firstChild.wholeText) yield (pn, p.getElementsByTagName('ETag')[0].firstChild.wholeText)
def _key_response_post(self, request, body, bucket_name, query, key_name, headers): def _key_response_post(self, request, body, bucket_name, query, key_name):
self._set_action("KEY", "POST", query)
self._authenticate_and_authorize_s3_action()
if body == b'' and 'uploads' in query: if body == b'' and 'uploads' in query:
metadata = metadata_from_headers(request.headers) metadata = metadata_from_headers(request.headers)
multipart = self.backend.initiate_multipart( multipart = self.backend.initiate_multipart(
@ -1175,7 +1333,7 @@ S3_BUCKET_GET_RESPONSE_V2 = """<?xml version="1.0" encoding="UTF-8"?>
<Name>{{ bucket.name }}</Name> <Name>{{ bucket.name }}</Name>
<Prefix>{{ prefix }}</Prefix> <Prefix>{{ prefix }}</Prefix>
<MaxKeys>{{ max_keys }}</MaxKeys> <MaxKeys>{{ max_keys }}</MaxKeys>
<KeyCount>{{ result_keys | length }}</KeyCount> <KeyCount>{{ key_count }}</KeyCount>
{% if delimiter %} {% if delimiter %}
<Delimiter>{{ delimiter }}</Delimiter> <Delimiter>{{ delimiter }}</Delimiter>
{% endif %} {% endif %}

View File

@ -7,15 +7,6 @@ url_bases = [
r"https?://(?P<bucket_name>[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com" r"https?://(?P<bucket_name>[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com"
] ]
def ambiguous_response1(*args, **kwargs):
return S3ResponseInstance.ambiguous_response(*args, **kwargs)
def ambiguous_response2(*args, **kwargs):
return S3ResponseInstance.ambiguous_response(*args, **kwargs)
url_paths = { url_paths = {
# subdomain bucket # subdomain bucket
'{0}/$': S3ResponseInstance.bucket_response, '{0}/$': S3ResponseInstance.bucket_response,

View File

@ -70,24 +70,31 @@ class SecretsManagerBackend(BaseBackend):
secret_version = secret['versions'][version_id] secret_version = secret['versions'][version_id]
response = json.dumps({ response_data = {
"ARN": secret_arn(self.region, secret['secret_id']), "ARN": secret_arn(self.region, secret['secret_id']),
"Name": secret['name'], "Name": secret['name'],
"VersionId": secret_version['version_id'], "VersionId": secret_version['version_id'],
"SecretString": secret_version['secret_string'],
"VersionStages": secret_version['version_stages'], "VersionStages": secret_version['version_stages'],
"CreatedDate": secret_version['createdate'], "CreatedDate": secret_version['createdate'],
}) }
if 'secret_string' in secret_version:
response_data["SecretString"] = secret_version['secret_string']
if 'secret_binary' in secret_version:
response_data["SecretBinary"] = secret_version['secret_binary']
response = json.dumps(response_data)
return response return response
def create_secret(self, name, secret_string, tags, **kwargs): def create_secret(self, name, secret_string=None, secret_binary=None, tags=[], **kwargs):
# error if secret exists # error if secret exists
if name in self.secrets.keys(): if name in self.secrets.keys():
raise ResourceExistsException('A resource with the ID you requested already exists.') raise ResourceExistsException('A resource with the ID you requested already exists.')
version_id = self._add_secret(name, secret_string, tags=tags) version_id = self._add_secret(name, secret_string=secret_string, secret_binary=secret_binary, tags=tags)
response = json.dumps({ response = json.dumps({
"ARN": secret_arn(self.region, name), "ARN": secret_arn(self.region, name),
@ -97,7 +104,7 @@ class SecretsManagerBackend(BaseBackend):
return response return response
def _add_secret(self, secret_id, secret_string, tags=[], version_id=None, version_stages=None): def _add_secret(self, secret_id, secret_string=None, secret_binary=None, tags=[], version_id=None, version_stages=None):
if version_stages is None: if version_stages is None:
version_stages = ['AWSCURRENT'] version_stages = ['AWSCURRENT']
@ -106,12 +113,17 @@ class SecretsManagerBackend(BaseBackend):
version_id = str(uuid.uuid4()) version_id = str(uuid.uuid4())
secret_version = { secret_version = {
'secret_string': secret_string,
'createdate': int(time.time()), 'createdate': int(time.time()),
'version_id': version_id, 'version_id': version_id,
'version_stages': version_stages, 'version_stages': version_stages,
} }
if secret_string is not None:
secret_version['secret_string'] = secret_string
if secret_binary is not None:
secret_version['secret_binary'] = secret_binary
if secret_id in self.secrets: if secret_id in self.secrets:
# remove all old AWSPREVIOUS stages # remove all old AWSPREVIOUS stages
for secret_verion_to_look_at in self.secrets[secret_id]['versions'].values(): for secret_verion_to_look_at in self.secrets[secret_id]['versions'].values():

View File

@ -21,10 +21,12 @@ class SecretsManagerResponse(BaseResponse):
def create_secret(self): def create_secret(self):
name = self._get_param('Name') name = self._get_param('Name')
secret_string = self._get_param('SecretString') secret_string = self._get_param('SecretString')
secret_binary = self._get_param('SecretBinary')
tags = self._get_param('Tags', if_none=[]) tags = self._get_param('Tags', if_none=[])
return secretsmanager_backends[self.region].create_secret( return secretsmanager_backends[self.region].create_secret(
name=name, name=name,
secret_string=secret_string, secret_string=secret_string,
secret_binary=secret_binary,
tags=tags tags=tags
) )

View File

@ -21,6 +21,16 @@ from moto.core.utils import convert_flask_to_httpretty_response
HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"] HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"]
DEFAULT_SERVICE_REGION = ('s3', 'us-east-1')
# Map of unsigned calls to service-region as per AWS API docs
# https://docs.aws.amazon.com/cognito/latest/developerguide/resource-permissions.html#amazon-cognito-signed-versus-unsigned-apis
UNSIGNED_REQUESTS = {
'AWSCognitoIdentityService': ('cognito-identity', 'us-east-1'),
'AWSCognitoIdentityProviderService': ('cognito-idp', 'us-east-1'),
}
class DomainDispatcherApplication(object): class DomainDispatcherApplication(object):
""" """
Dispatch requests to different applications based on the "Host:" header Dispatch requests to different applications based on the "Host:" header
@ -48,7 +58,45 @@ class DomainDispatcherApplication(object):
if re.match(url_base, 'http://%s' % host): if re.match(url_base, 'http://%s' % host):
return backend_name return backend_name
raise RuntimeError('Invalid host: "%s"' % host) def infer_service_region_host(self, environ):
auth = environ.get('HTTP_AUTHORIZATION')
if auth:
# Signed request
# Parse auth header to find service assuming a SigV4 request
# https://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html
# ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request']
try:
credential_scope = auth.split(",")[0].split()[1]
_, _, region, service, _ = credential_scope.split("/")
except ValueError:
# Signature format does not match, this is exceptional and we can't
# infer a service-region. A reduced set of services still use
# the deprecated SigV2, ergo prefer S3 as most likely default.
# https://docs.aws.amazon.com/general/latest/gr/signature-version-2.html
service, region = DEFAULT_SERVICE_REGION
else:
# Unsigned request
target = environ.get('HTTP_X_AMZ_TARGET')
if target:
service, _ = target.split('.', 1)
service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION)
else:
# S3 is the last resort when the target is also unknown
service, region = DEFAULT_SERVICE_REGION
if service == 'dynamodb':
if environ['HTTP_X_AMZ_TARGET'].startswith('DynamoDBStreams'):
host = 'dynamodbstreams'
else:
dynamo_api_version = environ['HTTP_X_AMZ_TARGET'].split("_")[1].split(".")[0]
# If Newer API version, use dynamodb2
if dynamo_api_version > "20111205":
host = "dynamodb2"
else:
host = "{service}.{region}.amazonaws.com".format(
service=service, region=region)
return host
def get_application(self, environ): def get_application(self, environ):
path_info = environ.get('PATH_INFO', '') path_info = environ.get('PATH_INFO', '')
@ -65,34 +113,14 @@ class DomainDispatcherApplication(object):
host = "instance_metadata" host = "instance_metadata"
else: else:
host = environ['HTTP_HOST'].split(':')[0] host = environ['HTTP_HOST'].split(':')[0]
if host in {'localhost', 'motoserver'} or host.startswith("192.168."):
# Fall back to parsing auth header to find service
# ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request']
try:
_, _, region, service, _ = environ['HTTP_AUTHORIZATION'].split(",")[0].split()[
1].split("/")
except (KeyError, ValueError):
# Some cognito-idp endpoints (e.g. change password) do not receive an auth header.
if environ.get('HTTP_X_AMZ_TARGET', '').startswith('AWSCognitoIdentityProviderService'):
service = 'cognito-idp'
else:
service = 's3'
region = 'us-east-1'
if service == 'dynamodb':
if environ['HTTP_X_AMZ_TARGET'].startswith('DynamoDBStreams'):
host = 'dynamodbstreams'
else:
dynamo_api_version = environ['HTTP_X_AMZ_TARGET'].split("_")[1].split(".")[0]
# If Newer API version, use dynamodb2
if dynamo_api_version > "20111205":
host = "dynamodb2"
else:
host = "{service}.{region}.amazonaws.com".format(
service=service, region=region)
with self.lock: with self.lock:
backend = self.get_backend_for_host(host) backend = self.get_backend_for_host(host)
if not backend:
# No regular backend found; try parsing other headers
host = self.infer_service_region_host(environ)
backend = self.get_backend_for_host(host)
app = self.app_instances.get(backend, None) app = self.app_instances.get(backend, None)
if app is None: if app is None:
app = self.create_app(backend) app = self.create_app(backend)

View File

@ -4,13 +4,41 @@ import email
from email.utils import parseaddr from email.utils import parseaddr
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.sns.models import sns_backends
from .exceptions import MessageRejectedError from .exceptions import MessageRejectedError
from .utils import get_random_message_id from .utils import get_random_message_id
from .feedback import COMMON_MAIL, BOUNCE, COMPLAINT, DELIVERY
RECIPIENT_LIMIT = 50 RECIPIENT_LIMIT = 50
class SESFeedback(BaseModel):
BOUNCE = "Bounce"
COMPLAINT = "Complaint"
DELIVERY = "Delivery"
SUCCESS_ADDR = "success"
BOUNCE_ADDR = "bounce"
COMPLAINT_ADDR = "complaint"
FEEDBACK_SUCCESS_MSG = {"test": "success"}
FEEDBACK_BOUNCE_MSG = {"test": "bounce"}
FEEDBACK_COMPLAINT_MSG = {"test": "complaint"}
@staticmethod
def generate_message(msg_type):
msg = dict(COMMON_MAIL)
if msg_type == SESFeedback.BOUNCE:
msg["bounce"] = BOUNCE
elif msg_type == SESFeedback.COMPLAINT:
msg["complaint"] = COMPLAINT
elif msg_type == SESFeedback.DELIVERY:
msg["delivery"] = DELIVERY
return msg
class Message(BaseModel): class Message(BaseModel):
def __init__(self, message_id, source, subject, body, destinations): def __init__(self, message_id, source, subject, body, destinations):
@ -48,6 +76,7 @@ class SESBackend(BaseBackend):
self.domains = [] self.domains = []
self.sent_messages = [] self.sent_messages = []
self.sent_message_count = 0 self.sent_message_count = 0
self.sns_topics = {}
def _is_verified_address(self, source): def _is_verified_address(self, source):
_, address = parseaddr(source) _, address = parseaddr(source)
@ -77,7 +106,7 @@ class SESBackend(BaseBackend):
else: else:
self.domains.remove(identity) self.domains.remove(identity)
def send_email(self, source, subject, body, destinations): def send_email(self, source, subject, body, destinations, region):
recipient_count = sum(map(len, destinations.values())) recipient_count = sum(map(len, destinations.values()))
if recipient_count > RECIPIENT_LIMIT: if recipient_count > RECIPIENT_LIMIT:
raise MessageRejectedError('Too many recipients.') raise MessageRejectedError('Too many recipients.')
@ -86,13 +115,46 @@ class SESBackend(BaseBackend):
"Email address not verified %s" % source "Email address not verified %s" % source
) )
self.__process_sns_feedback__(source, destinations, region)
message_id = get_random_message_id() message_id = get_random_message_id()
message = Message(message_id, source, subject, body, destinations) message = Message(message_id, source, subject, body, destinations)
self.sent_messages.append(message) self.sent_messages.append(message)
self.sent_message_count += recipient_count self.sent_message_count += recipient_count
return message return message
def send_raw_email(self, source, destinations, raw_data): def __type_of_message__(self, destinations):
"""Checks the destination for any special address that could indicate delivery, complaint or bounce
like in SES simualtor"""
alladdress = destinations.get("ToAddresses", []) + destinations.get("CcAddresses", []) + destinations.get("BccAddresses", [])
for addr in alladdress:
if SESFeedback.SUCCESS_ADDR in addr:
return SESFeedback.DELIVERY
elif SESFeedback.COMPLAINT_ADDR in addr:
return SESFeedback.COMPLAINT
elif SESFeedback.BOUNCE_ADDR in addr:
return SESFeedback.BOUNCE
return None
def __generate_feedback__(self, msg_type):
"""Generates the SNS message for the feedback"""
return SESFeedback.generate_message(msg_type)
def __process_sns_feedback__(self, source, destinations, region):
domain = str(source)
if "@" in domain:
domain = domain.split("@")[1]
if domain in self.sns_topics:
msg_type = self.__type_of_message__(destinations)
if msg_type is not None:
sns_topic = self.sns_topics[domain].get(msg_type, None)
if sns_topic is not None:
message = self.__generate_feedback__(msg_type)
if message:
sns_backends[region].publish(sns_topic, message)
def send_raw_email(self, source, destinations, raw_data, region):
if source is not None: if source is not None:
_, source_email_address = parseaddr(source) _, source_email_address = parseaddr(source)
if source_email_address not in self.addresses: if source_email_address not in self.addresses:
@ -122,6 +184,8 @@ class SESBackend(BaseBackend):
if recipient_count > RECIPIENT_LIMIT: if recipient_count > RECIPIENT_LIMIT:
raise MessageRejectedError('Too many recipients.') raise MessageRejectedError('Too many recipients.')
self.__process_sns_feedback__(source, destinations, region)
self.sent_message_count += recipient_count self.sent_message_count += recipient_count
message_id = get_random_message_id() message_id = get_random_message_id()
message = RawMessage(message_id, source, destinations, raw_data) message = RawMessage(message_id, source, destinations, raw_data)
@ -131,5 +195,16 @@ class SESBackend(BaseBackend):
def get_send_quota(self): def get_send_quota(self):
return SESQuota(self.sent_message_count) return SESQuota(self.sent_message_count)
def set_identity_notification_topic(self, identity, notification_type, sns_topic):
identity_sns_topics = self.sns_topics.get(identity, {})
if sns_topic is None:
del identity_sns_topics[notification_type]
else:
identity_sns_topics[notification_type] = sns_topic
self.sns_topics[identity] = identity_sns_topics
return {}
ses_backend = SESBackend() ses_backend = SESBackend()

View File

@ -70,7 +70,7 @@ class EmailResponse(BaseResponse):
break break
destinations[dest_type].append(address[0]) destinations[dest_type].append(address[0])
message = ses_backend.send_email(source, subject, body, destinations) message = ses_backend.send_email(source, subject, body, destinations, self.region)
template = self.response_template(SEND_EMAIL_RESPONSE) template = self.response_template(SEND_EMAIL_RESPONSE)
return template.render(message=message) return template.render(message=message)
@ -92,7 +92,7 @@ class EmailResponse(BaseResponse):
break break
destinations.append(address[0]) destinations.append(address[0])
message = ses_backend.send_raw_email(source, destinations, raw_data) message = ses_backend.send_raw_email(source, destinations, raw_data, self.region)
template = self.response_template(SEND_RAW_EMAIL_RESPONSE) template = self.response_template(SEND_RAW_EMAIL_RESPONSE)
return template.render(message=message) return template.render(message=message)
@ -101,6 +101,18 @@ class EmailResponse(BaseResponse):
template = self.response_template(GET_SEND_QUOTA_RESPONSE) template = self.response_template(GET_SEND_QUOTA_RESPONSE)
return template.render(quota=quota) return template.render(quota=quota)
def set_identity_notification_topic(self):
identity = self.querystring.get("Identity")[0]
not_type = self.querystring.get("NotificationType")[0]
sns_topic = self.querystring.get("SnsTopic")
if sns_topic:
sns_topic = sns_topic[0]
ses_backend.set_identity_notification_topic(identity, not_type, sns_topic)
template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE)
return template.render()
VERIFY_EMAIL_IDENTITY = """<VerifyEmailIdentityResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/"> VERIFY_EMAIL_IDENTITY = """<VerifyEmailIdentityResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<VerifyEmailIdentityResult/> <VerifyEmailIdentityResult/>
@ -200,3 +212,10 @@ GET_SEND_QUOTA_RESPONSE = """<GetSendQuotaResponse xmlns="http://ses.amazonaws.c
<RequestId>273021c6-c866-11e0-b926-699e21c3af9e</RequestId> <RequestId>273021c6-c866-11e0-b926-699e21c3af9e</RequestId>
</ResponseMetadata> </ResponseMetadata>
</GetSendQuotaResponse>""" </GetSendQuotaResponse>"""
SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE = """<SetIdentityNotificationTopicResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<SetIdentityNotificationTopicResult/>
<ResponseMetadata>
<RequestId>47e0ef1a-9bf2-11e1-9279-0100e8cf109a</RequestId>
</ResponseMetadata>
</SetIdentityNotificationTopicResponse>"""

View File

@ -1,3 +1,4 @@
import os import os
TEST_SERVER_MODE = os.environ.get('TEST_SERVER_MODE', '0').lower() == 'true' TEST_SERVER_MODE = os.environ.get('TEST_SERVER_MODE', '0').lower() == 'true'
INITIAL_NO_AUTH_ACTION_COUNT = float(os.environ.get('INITIAL_NO_AUTH_ACTION_COUNT', float('inf')))

View File

@ -12,7 +12,7 @@ from boto3 import Session
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds, camelcase_to_underscores
from moto.sqs import sqs_backends from moto.sqs import sqs_backends
from moto.awslambda import lambda_backends from moto.awslambda import lambda_backends
@ -119,7 +119,7 @@ class Subscription(BaseModel):
else: else:
assert False assert False
lambda_backends[region].send_message(function_name, message, subject=subject, qualifier=qualifier) lambda_backends[region].send_sns_message(function_name, message, subject=subject, qualifier=qualifier)
def _matches_filter_policy(self, message_attributes): def _matches_filter_policy(self, message_attributes):
# TODO: support Anything-but matching, prefix matching and # TODO: support Anything-but matching, prefix matching and
@ -243,11 +243,14 @@ class SNSBackend(BaseBackend):
def update_sms_attributes(self, attrs): def update_sms_attributes(self, attrs):
self.sms_attributes.update(attrs) self.sms_attributes.update(attrs)
def create_topic(self, name): def create_topic(self, name, attributes=None):
fails_constraints = not re.match(r'^[a-zA-Z0-9_-]{1,256}$', name) fails_constraints = not re.match(r'^[a-zA-Z0-9_-]{1,256}$', name)
if fails_constraints: if fails_constraints:
raise InvalidParameterValue("Topic names must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long.") raise InvalidParameterValue("Topic names must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long.")
candidate_topic = Topic(name, self) candidate_topic = Topic(name, self)
if attributes:
for attribute in attributes:
setattr(candidate_topic, camelcase_to_underscores(attribute), attributes[attribute])
if candidate_topic.arn in self.topics: if candidate_topic.arn in self.topics:
return self.topics[candidate_topic.arn] return self.topics[candidate_topic.arn]
else: else:

View File

@ -75,7 +75,8 @@ class SNSResponse(BaseResponse):
def create_topic(self): def create_topic(self):
name = self._get_param('Name') name = self._get_param('Name')
topic = self.backend.create_topic(name) attributes = self._get_attributes()
topic = self.backend.create_topic(name, attributes)
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps({

View File

@ -189,6 +189,8 @@ class Queue(BaseModel):
self.name) self.name)
self.dead_letter_queue = None self.dead_letter_queue = None
self.lambda_event_source_mappings = {}
# default settings for a non fifo queue # default settings for a non fifo queue
defaults = { defaults = {
'ContentBasedDeduplication': 'false', 'ContentBasedDeduplication': 'false',
@ -360,6 +362,33 @@ class Queue(BaseModel):
def add_message(self, message): def add_message(self, message):
self._messages.append(message) self._messages.append(message)
from moto.awslambda import lambda_backends
for arn, esm in self.lambda_event_source_mappings.items():
backend = sqs_backends[self.region]
"""
Lambda polls the queue and invokes your function synchronously with an event
that contains queue messages. Lambda reads messages in batches and invokes
your function once for each batch. When your function successfully processes
a batch, Lambda deletes its messages from the queue.
"""
messages = backend.receive_messages(
self.name,
esm.batch_size,
self.receive_message_wait_time_seconds,
self.visibility_timeout,
)
result = lambda_backends[self.region].send_sqs_batch(
arn,
messages,
self.queue_arn,
)
if result:
[backend.delete_message(self.name, m.receipt_handle) for m in messages]
else:
[backend.change_message_visibility(self.name, m.receipt_handle, 0) for m in messages]
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
@ -379,6 +408,7 @@ class SQSBackend(BaseBackend):
def reset(self): def reset(self):
region_name = self.region_name region_name = self.region_name
self._reset_model_refs()
self.__dict__ = {} self.__dict__ = {}
self.__init__(region_name) self.__init__(region_name)

View File

@ -2,6 +2,8 @@ from __future__ import unicode_literals
import datetime import datetime
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.iam.models import ACCOUNT_ID
from moto.sts.utils import random_access_key_id, random_secret_access_key, random_session_token, random_assumed_role_id
class Token(BaseModel): class Token(BaseModel):
@ -21,19 +23,38 @@ class AssumedRole(BaseModel):
def __init__(self, role_session_name, role_arn, policy, duration, external_id): def __init__(self, role_session_name, role_arn, policy, duration, external_id):
self.session_name = role_session_name self.session_name = role_session_name
self.arn = role_arn self.role_arn = role_arn
self.policy = policy self.policy = policy
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
self.expiration = now + datetime.timedelta(seconds=duration) self.expiration = now + datetime.timedelta(seconds=duration)
self.external_id = external_id self.external_id = external_id
self.access_key_id = "ASIA" + random_access_key_id()
self.secret_access_key = random_secret_access_key()
self.session_token = random_session_token()
self.assumed_role_id = "AROA" + random_assumed_role_id()
@property @property
def expiration_ISO8601(self): def expiration_ISO8601(self):
return iso_8601_datetime_with_milliseconds(self.expiration) return iso_8601_datetime_with_milliseconds(self.expiration)
@property
def user_id(self):
return self.assumed_role_id + ":" + self.session_name
@property
def arn(self):
return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format(
account_id=ACCOUNT_ID,
role_name=self.role_arn.split("/")[-1],
session_name=self.session_name
)
class STSBackend(BaseBackend): class STSBackend(BaseBackend):
def __init__(self):
self.assumed_roles = []
def get_session_token(self, duration): def get_session_token(self, duration):
token = Token(duration=duration) token = Token(duration=duration)
return token return token
@ -44,7 +65,17 @@ class STSBackend(BaseBackend):
def assume_role(self, **kwargs): def assume_role(self, **kwargs):
role = AssumedRole(**kwargs) role = AssumedRole(**kwargs)
self.assumed_roles.append(role)
return role return role
def get_assumed_role_from_access_key(self, access_key_id):
for assumed_role in self.assumed_roles:
if assumed_role.access_key_id == access_key_id:
return assumed_role
return None
def assume_role_with_web_identity(self, **kwargs):
return self.assume_role(**kwargs)
sts_backend = STSBackend() sts_backend = STSBackend()

View File

@ -1,8 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.iam.models import ACCOUNT_ID
from moto.iam import iam_backend
from .exceptions import STSValidationError
from .models import sts_backend from .models import sts_backend
MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048
class TokenResponse(BaseResponse): class TokenResponse(BaseResponse):
@ -15,11 +20,20 @@ class TokenResponse(BaseResponse):
def get_federation_token(self): def get_federation_token(self):
duration = int(self.querystring.get('DurationSeconds', [43200])[0]) duration = int(self.querystring.get('DurationSeconds', [43200])[0])
policy = self.querystring.get('Policy', [None])[0] policy = self.querystring.get('Policy', [None])[0]
if policy is not None and len(policy) > MAX_FEDERATION_TOKEN_POLICY_LENGTH:
raise STSValidationError(
"1 validation error detected: Value "
"'{\"Version\": \"2012-10-17\", \"Statement\": [...]}' "
"at 'policy' failed to satisfy constraint: Member must have length less than or "
" equal to %s" % MAX_FEDERATION_TOKEN_POLICY_LENGTH
)
name = self.querystring.get('Name')[0] name = self.querystring.get('Name')[0]
token = sts_backend.get_federation_token( token = sts_backend.get_federation_token(
duration=duration, name=name, policy=policy) duration=duration, name=name, policy=policy)
template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE) template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE)
return template.render(token=token) return template.render(token=token, account_id=ACCOUNT_ID)
def assume_role(self): def assume_role(self):
role_session_name = self.querystring.get('RoleSessionName')[0] role_session_name = self.querystring.get('RoleSessionName')[0]
@ -39,9 +53,43 @@ class TokenResponse(BaseResponse):
template = self.response_template(ASSUME_ROLE_RESPONSE) template = self.response_template(ASSUME_ROLE_RESPONSE)
return template.render(role=role) return template.render(role=role)
def assume_role_with_web_identity(self):
role_session_name = self.querystring.get('RoleSessionName')[0]
role_arn = self.querystring.get('RoleArn')[0]
policy = self.querystring.get('Policy', [None])[0]
duration = int(self.querystring.get('DurationSeconds', [3600])[0])
external_id = self.querystring.get('ExternalId', [None])[0]
role = sts_backend.assume_role_with_web_identity(
role_session_name=role_session_name,
role_arn=role_arn,
policy=policy,
duration=duration,
external_id=external_id,
)
template = self.response_template(ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE)
return template.render(role=role)
def get_caller_identity(self): def get_caller_identity(self):
template = self.response_template(GET_CALLER_IDENTITY_RESPONSE) template = self.response_template(GET_CALLER_IDENTITY_RESPONSE)
return template.render()
# Default values in case the request does not use valid credentials generated by moto
user_id = "AKIAIOSFODNN7EXAMPLE"
arn = "arn:aws:sts::{account_id}:user/moto".format(account_id=ACCOUNT_ID)
access_key_id = self.get_current_user()
assumed_role = sts_backend.get_assumed_role_from_access_key(access_key_id)
if assumed_role:
user_id = assumed_role.user_id
arn = assumed_role.arn
user = iam_backend.get_user_from_access_key_id(access_key_id)
if user:
user_id = user.id
arn = user.arn
return template.render(account_id=ACCOUNT_ID, user_id=user_id, arn=arn)
GET_SESSION_TOKEN_RESPONSE = """<GetSessionTokenResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/"> GET_SESSION_TOKEN_RESPONSE = """<GetSessionTokenResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
@ -69,8 +117,8 @@ GET_FEDERATION_TOKEN_RESPONSE = """<GetFederationTokenResponse xmlns="https://st
<AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId> <AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId>
</Credentials> </Credentials>
<FederatedUser> <FederatedUser>
<Arn>arn:aws:sts::123456789012:federated-user/{{ token.name }}</Arn> <Arn>arn:aws:sts::{{ account_id }}:federated-user/{{ token.name }}</Arn>
<FederatedUserId>123456789012:{{ token.name }}</FederatedUserId> <FederatedUserId>{{ account_id }}:{{ token.name }}</FederatedUserId>
</FederatedUser> </FederatedUser>
<PackedPolicySize>6</PackedPolicySize> <PackedPolicySize>6</PackedPolicySize>
</GetFederationTokenResult> </GetFederationTokenResult>
@ -84,14 +132,14 @@ ASSUME_ROLE_RESPONSE = """<AssumeRoleResponse xmlns="https://sts.amazonaws.com/d
2011-06-15/"> 2011-06-15/">
<AssumeRoleResult> <AssumeRoleResult>
<Credentials> <Credentials>
<SessionToken>BQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/LTo6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3zrkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtpZ3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE</SessionToken> <SessionToken>{{ role.session_token }}</SessionToken>
<SecretAccessKey>aJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY</SecretAccessKey> <SecretAccessKey>{{ role.secret_access_key }}</SecretAccessKey>
<Expiration>{{ role.expiration_ISO8601 }}</Expiration> <Expiration>{{ role.expiration_ISO8601 }}</Expiration>
<AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId> <AccessKeyId>{{ role.access_key_id }}</AccessKeyId>
</Credentials> </Credentials>
<AssumedRoleUser> <AssumedRoleUser>
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<AssumedRoleId>ARO123EXAMPLE123:{{ role.session_name }}</AssumedRoleId> <AssumedRoleId>{{ role.user_id }}</AssumedRoleId>
</AssumedRoleUser> </AssumedRoleUser>
<PackedPolicySize>6</PackedPolicySize> <PackedPolicySize>6</PackedPolicySize>
</AssumeRoleResult> </AssumeRoleResult>
@ -100,11 +148,32 @@ ASSUME_ROLE_RESPONSE = """<AssumeRoleResponse xmlns="https://sts.amazonaws.com/d
</ResponseMetadata> </ResponseMetadata>
</AssumeRoleResponse>""" </AssumeRoleResponse>"""
ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE = """<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<Credentials>
<SessionToken>{{ role.session_token }}</SessionToken>
<SecretAccessKey>{{ role.secret_access_key }}</SecretAccessKey>
<Expiration>{{ role.expiration_ISO8601 }}</Expiration>
<AccessKeyId>{{ role.access_key_id }}</AccessKeyId>
</Credentials>
<AssumedRoleUser>
<Arn>{{ role.arn }}</Arn>
<AssumedRoleId>ARO123EXAMPLE123:{{ role.session_name }}</AssumedRoleId>
</AssumedRoleUser>
<PackedPolicySize>6</PackedPolicySize>
</AssumeRoleWithWebIdentityResult>
<ResponseMetadata>
<RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId>
</ResponseMetadata>
</AssumeRoleWithWebIdentityResponse>"""
GET_CALLER_IDENTITY_RESPONSE = """<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/"> GET_CALLER_IDENTITY_RESPONSE = """<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult> <GetCallerIdentityResult>
<Arn>arn:aws:sts::123456789012:user/moto</Arn> <Arn>{{ arn }}</Arn>
<UserId>AKIAIOSFODNN7EXAMPLE</UserId> <UserId>{{ user_id }}</UserId>
<Account>123456789012</Account> <Account>{{ account_id }}</Account>
</GetCallerIdentityResult> </GetCallerIdentityResult>
<ResponseMetadata> <ResponseMetadata>
<RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId> <RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId>

View File

@ -61,7 +61,8 @@ def print_implementation_coverage(coverage):
percentage_implemented = 0 percentage_implemented = 0
print("") print("")
print("## {} - {}% implemented".format(service_name, percentage_implemented)) print("## {}\n".format(service_name))
print("{}% implemented\n".format(percentage_implemented))
for op in operations: for op in operations:
if op in implemented: if op in implemented:
print("- [X] {}".format(op)) print("- [X] {}".format(op))
@ -93,7 +94,8 @@ def write_implementation_coverage_to_file(coverage):
percentage_implemented = 0 percentage_implemented = 0
file.write("\n") file.write("\n")
file.write("## {} - {}% implemented\n".format(service_name, percentage_implemented)) file.write("## {}\n".format(service_name))
file.write("{}% implemented\n".format(percentage_implemented))
for op in operations: for op in operations:
if op in implemented: if op in implemented:
file.write("- [X] {}\n".format(op)) file.write("- [X] {}\n".format(op))

View File

@ -48,7 +48,8 @@ for policy_name in policies:
PolicyArn=policies[policy_name]['Arn'], PolicyArn=policies[policy_name]['Arn'],
VersionId=policies[policy_name]['DefaultVersionId']) VersionId=policies[policy_name]['DefaultVersionId'])
for key in response['PolicyVersion']: for key in response['PolicyVersion']:
policies[policy_name][key] = response['PolicyVersion'][key] if key != "CreateDate": # the policy's CreateDate should not be overwritten by its version's CreateDate
policies[policy_name][key] = response['PolicyVersion'][key]
with open(output_file, 'w') as f: with open(output_file, 'w') as f:
triple_quote = '\"\"\"' triple_quote = '\"\"\"'

View File

@ -18,17 +18,26 @@ def read(*parts):
return fp.read() return fp.read()
def get_version():
version_file = read('moto', '__init__.py')
version_match = re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]',
version_file, re.MULTILINE)
if version_match:
return version_match.group(1)
raise RuntimeError('Unable to find version string.')
install_requires = [ install_requires = [
"Jinja2>=2.10.1", "Jinja2>=2.10.1",
"boto>=2.36.0", "boto>=2.36.0",
"boto3>=1.9.86", "boto3>=1.9.201",
"botocore>=1.12.86", "botocore>=1.12.201",
"cryptography>=2.3.0", "cryptography>=2.3.0",
"requests>=2.5", "requests>=2.5",
"xmltodict", "xmltodict",
"six>1.9", "six>1.9",
"werkzeug", "werkzeug",
"PyYAML==3.13", "PyYAML>=5.1",
"pytz", "pytz",
"python-dateutil<3.0.0,>=2.1", "python-dateutil<3.0.0,>=2.1",
"python-jose<4.0.0", "python-jose<4.0.0",
@ -38,7 +47,7 @@ install_requires = [
"aws-xray-sdk!=0.96,>=0.93", "aws-xray-sdk!=0.96,>=0.93",
"responses>=0.9.0", "responses>=0.9.0",
"idna<2.9,>=2.5", "idna<2.9,>=2.5",
"cfn-lint", "cfn-lint>=0.4.0",
"sshpubkeys>=3.1.0,<4.0" "sshpubkeys>=3.1.0,<4.0"
] ]
@ -56,7 +65,7 @@ else:
setup( setup(
name='moto', name='moto',
version='1.3.8', version=get_version(),
description='A library that allows your python tests to easily' description='A library that allows your python tests to easily'
' mock out the boto library', ' mock out the boto library',
long_description=read('README.md'), long_description=read('README.md'),
@ -79,10 +88,9 @@ setup(
"Programming Language :: Python :: 2", "Programming Language :: Python :: 2",
"Programming Language :: Python :: 2.7", "Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.3",
"Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: Apache Software License",
"Topic :: Software Development :: Testing", "Topic :: Software Development :: Testing",
], ],

View File

@ -74,6 +74,31 @@ def test_list_certificates():
resp['CertificateSummaryList'][0]['DomainName'].should.equal(SERVER_COMMON_NAME) resp['CertificateSummaryList'][0]['DomainName'].should.equal(SERVER_COMMON_NAME)
@mock_acm
def test_list_certificates_by_status():
client = boto3.client('acm', region_name='eu-central-1')
issued_arn = _import_cert(client)
pending_arn = client.request_certificate(DomainName='google.com')['CertificateArn']
resp = client.list_certificates()
len(resp['CertificateSummaryList']).should.equal(2)
resp = client.list_certificates(CertificateStatuses=['EXPIRED', 'INACTIVE'])
len(resp['CertificateSummaryList']).should.equal(0)
resp = client.list_certificates(CertificateStatuses=['PENDING_VALIDATION'])
len(resp['CertificateSummaryList']).should.equal(1)
resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(pending_arn)
resp = client.list_certificates(CertificateStatuses=['ISSUED'])
len(resp['CertificateSummaryList']).should.equal(1)
resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(issued_arn)
resp = client.list_certificates(CertificateStatuses=['ISSUED', 'PENDING_VALIDATION'])
len(resp['CertificateSummaryList']).should.equal(2)
arns = {cert['CertificateArn'] for cert in resp['CertificateSummaryList']}
arns.should.contain(issued_arn)
arns.should.contain(pending_arn)
@mock_acm @mock_acm
def test_get_invalid_certificate(): def test_get_invalid_certificate():
client = boto3.client('acm', region_name='eu-central-1') client = boto3.client('acm', region_name='eu-central-1')
@ -291,6 +316,7 @@ def test_request_certificate():
) )
resp.should.contain('CertificateArn') resp.should.contain('CertificateArn')
arn = resp['CertificateArn'] arn = resp['CertificateArn']
arn.should.match(r"arn:aws:acm:eu-central-1:\d{12}:certificate/")
resp = client.request_certificate( resp = client.request_certificate(
DomainName='google.com', DomainName='google.com',

View File

@ -988,13 +988,30 @@ def test_api_keys():
apikey['name'].should.equal(apikey_name) apikey['name'].should.equal(apikey_name)
len(apikey['value']).should.equal(40) len(apikey['value']).should.equal(40)
apikey_name = 'TESTKEY3'
payload = {'name': apikey_name }
response = client.create_api_key(**payload)
apikey_id = response['id']
patch_operations = [
{'op': 'replace', 'path': '/name', 'value': 'TESTKEY3_CHANGE'},
{'op': 'replace', 'path': '/customerId', 'value': '12345'},
{'op': 'replace', 'path': '/description', 'value': 'APIKEY UPDATE TEST'},
{'op': 'replace', 'path': '/enabled', 'value': 'false'},
]
response = client.update_api_key(apiKey=apikey_id, patchOperations=patch_operations)
response['name'].should.equal('TESTKEY3_CHANGE')
response['customerId'].should.equal('12345')
response['description'].should.equal('APIKEY UPDATE TEST')
response['enabled'].should.equal(False)
response = client.get_api_keys() response = client.get_api_keys()
len(response['items']).should.equal(2) len(response['items']).should.equal(3)
client.delete_api_key(apiKey=apikey_id) client.delete_api_key(apiKey=apikey_id)
response = client.get_api_keys() response = client.get_api_keys()
len(response['items']).should.equal(1) len(response['items']).should.equal(2)
@mock_apigateway @mock_apigateway
def test_usage_plans(): def test_usage_plans():

View File

@ -7,11 +7,13 @@ from boto.ec2.autoscale.group import AutoScalingGroup
from boto.ec2.autoscale import Tag from boto.ec2.autoscale import Tag
import boto.ec2.elb import boto.ec2.elb
import sure # noqa import sure # noqa
from botocore.exceptions import ClientError
from nose.tools import assert_raises
from moto import mock_autoscaling, mock_ec2_deprecated, mock_elb_deprecated, mock_elb, mock_autoscaling_deprecated, mock_ec2 from moto import mock_autoscaling, mock_ec2_deprecated, mock_elb_deprecated, mock_elb, mock_autoscaling_deprecated, mock_ec2
from tests.helpers import requires_boto_gte from tests.helpers import requires_boto_gte
from utils import setup_networking, setup_networking_deprecated from utils import setup_networking, setup_networking_deprecated, setup_instance_with_networking
@mock_autoscaling_deprecated @mock_autoscaling_deprecated
@ -724,6 +726,67 @@ def test_create_autoscaling_group_boto3():
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
@mock_autoscaling
def test_create_autoscaling_group_from_instance():
autoscaling_group_name = 'test_asg'
image_id = 'ami-0cc293023f983ed53'
instance_type = 't2.micro'
mocked_instance_with_networking = setup_instance_with_networking(image_id, instance_type)
client = boto3.client('autoscaling', region_name='us-east-1')
response = client.create_auto_scaling_group(
AutoScalingGroupName=autoscaling_group_name,
InstanceId=mocked_instance_with_networking['instance'],
MinSize=1,
MaxSize=3,
DesiredCapacity=2,
Tags=[
{'ResourceId': 'test_asg',
'ResourceType': 'auto-scaling-group',
'Key': 'propogated-tag-key',
'Value': 'propogate-tag-value',
'PropagateAtLaunch': True
},
{'ResourceId': 'test_asg',
'ResourceType': 'auto-scaling-group',
'Key': 'not-propogated-tag-key',
'Value': 'not-propogate-tag-value',
'PropagateAtLaunch': False
}],
VPCZoneIdentifier=mocked_instance_with_networking['subnet1'],
NewInstancesProtectedFromScaleIn=False,
)
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
describe_launch_configurations_response = client.describe_launch_configurations()
describe_launch_configurations_response['LaunchConfigurations'].should.have.length_of(1)
launch_configuration_from_instance = describe_launch_configurations_response['LaunchConfigurations'][0]
launch_configuration_from_instance['LaunchConfigurationName'].should.equal('test_asg')
launch_configuration_from_instance['ImageId'].should.equal(image_id)
launch_configuration_from_instance['InstanceType'].should.equal(instance_type)
@mock_autoscaling
def test_create_autoscaling_group_from_invalid_instance_id():
invalid_instance_id = 'invalid_instance'
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
with assert_raises(ClientError) as ex:
client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
InstanceId=invalid_instance_id,
MinSize=9,
MaxSize=15,
DesiredCapacity=12,
VPCZoneIdentifier=mocked_networking['subnet1'],
NewInstancesProtectedFromScaleIn=False,
)
ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400)
ex.exception.response['Error']['Code'].should.equal('ValidationError')
ex.exception.response['Error']['Message'].should.equal('Instance [{0}] is invalid.'.format(invalid_instance_id))
@mock_autoscaling @mock_autoscaling
def test_describe_autoscaling_groups_boto3(): def test_describe_autoscaling_groups_boto3():
mocked_networking = setup_networking() mocked_networking = setup_networking()
@ -823,6 +886,62 @@ def test_update_autoscaling_group_boto3():
group['NewInstancesProtectedFromScaleIn'].should.equal(False) group['NewInstancesProtectedFromScaleIn'].should.equal(False)
@mock_autoscaling
def test_update_autoscaling_group_min_size_desired_capacity_change():
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
client.create_launch_configuration(
LaunchConfigurationName='test_launch_configuration'
)
client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
LaunchConfigurationName='test_launch_configuration',
MinSize=2,
MaxSize=20,
DesiredCapacity=3,
VPCZoneIdentifier=mocked_networking['subnet1'],
)
client.update_auto_scaling_group(
AutoScalingGroupName='test_asg',
MinSize=5,
)
response = client.describe_auto_scaling_groups(
AutoScalingGroupNames=['test_asg'])
group = response['AutoScalingGroups'][0]
group['DesiredCapacity'].should.equal(5)
group['MinSize'].should.equal(5)
group['Instances'].should.have.length_of(5)
@mock_autoscaling
def test_update_autoscaling_group_max_size_desired_capacity_change():
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
client.create_launch_configuration(
LaunchConfigurationName='test_launch_configuration'
)
client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
LaunchConfigurationName='test_launch_configuration',
MinSize=2,
MaxSize=20,
DesiredCapacity=10,
VPCZoneIdentifier=mocked_networking['subnet1'],
)
client.update_auto_scaling_group(
AutoScalingGroupName='test_asg',
MaxSize=5,
)
response = client.describe_auto_scaling_groups(
AutoScalingGroupNames=['test_asg'])
group = response['AutoScalingGroups'][0]
group['DesiredCapacity'].should.equal(5)
group['MaxSize'].should.equal(5)
group['Instances'].should.have.length_of(5)
@mock_autoscaling @mock_autoscaling
def test_autoscaling_taqs_update_boto3(): def test_autoscaling_taqs_update_boto3():
mocked_networking = setup_networking() mocked_networking = setup_networking()
@ -1269,3 +1388,36 @@ def test_set_desired_capacity_down_boto3():
instance_ids = {instance['InstanceId'] for instance in group['Instances']} instance_ids = {instance['InstanceId'] for instance in group['Instances']}
set(protected).should.equal(instance_ids) set(protected).should.equal(instance_ids)
set(unprotected).should_not.be.within(instance_ids) # only unprotected killed set(unprotected).should_not.be.within(instance_ids) # only unprotected killed
@mock_autoscaling
@mock_ec2
def test_terminate_instance_in_autoscaling_group():
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
_ = client.create_launch_configuration(
LaunchConfigurationName='test_launch_configuration'
)
_ = client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
LaunchConfigurationName='test_launch_configuration',
MinSize=1,
MaxSize=20,
VPCZoneIdentifier=mocked_networking['subnet1'],
NewInstancesProtectedFromScaleIn=False
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg'])
original_instance_id = next(
instance['InstanceId']
for instance in response['AutoScalingGroups'][0]['Instances']
)
ec2_client = boto3.client('ec2', region_name='us-east-1')
ec2_client.terminate_instances(InstanceIds=[original_instance_id])
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg'])
replaced_instance_id = next(
instance['InstanceId']
for instance in response['AutoScalingGroups'][0]['Instances']
)
replaced_instance_id.should_not.equal(original_instance_id)

View File

@ -31,3 +31,18 @@ def setup_networking_deprecated():
"10.11.2.0/24", "10.11.2.0/24",
availability_zone='us-east-1b') availability_zone='us-east-1b')
return {'vpc': vpc.id, 'subnet1': subnet1.id, 'subnet2': subnet2.id} return {'vpc': vpc.id, 'subnet1': subnet1.id, 'subnet2': subnet2.id}
@mock_ec2
def setup_instance_with_networking(image_id, instance_type):
mock_data = setup_networking()
ec2 = boto3.resource('ec2', region_name='us-east-1')
instances = ec2.create_instances(
ImageId=image_id,
InstanceType=instance_type,
MaxCount=1,
MinCount=1,
SubnetId=mock_data['subnet1']
)
mock_data['instance'] = instances[0].id
return mock_data

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import base64 import base64
import uuid
import botocore.client import botocore.client
import boto3 import boto3
import hashlib import hashlib
@ -11,11 +12,12 @@ import zipfile
import sure # noqa import sure # noqa
from freezegun import freeze_time from freezegun import freeze_time
from moto import mock_lambda, mock_s3, mock_ec2, mock_sns, mock_logs, settings from moto import mock_lambda, mock_s3, mock_ec2, mock_sns, mock_logs, settings, mock_sqs
from nose.tools import assert_raises from nose.tools import assert_raises
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
_lambda_region = 'us-west-2' _lambda_region = 'us-west-2'
boto3.setup_default_session(region_name=_lambda_region)
def _process_lambda(func_str): def _process_lambda(func_str):
@ -59,6 +61,13 @@ def lambda_handler(event, context):
""" """
return _process_lambda(pfunc) return _process_lambda(pfunc)
def get_test_zip_file4():
pfunc = """
def lambda_handler(event, context):
raise Exception('I failed!')
"""
return _process_lambda(pfunc)
@mock_lambda @mock_lambda
def test_list_functions(): def test_list_functions():
@ -933,3 +942,306 @@ def test_list_versions_by_function_for_nonexistent_function():
versions = conn.list_versions_by_function(FunctionName='testFunction') versions = conn.list_versions_by_function(FunctionName='testFunction')
assert len(versions['Versions']) == 0 assert len(versions['Versions']) == 0
@mock_logs
@mock_lambda
@mock_sqs
def test_create_event_source_mapping():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
assert response['EventSourceArn'] == queue.attributes['QueueArn']
assert response['FunctionArn'] == func['FunctionArn']
assert response['State'] == 'Enabled'
@mock_logs
@mock_lambda
@mock_sqs
def test_invoke_function_from_sqs():
logs_conn = boto3.client("logs")
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
assert response['EventSourceArn'] == queue.attributes['QueueArn']
assert response['State'] == 'Enabled'
sqs_client = boto3.client('sqs')
sqs_client.send_message(QueueUrl=queue.url, MessageBody='test')
start = time.time()
while (time.time() - start) < 30:
result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction')
log_streams = result.get('logStreams')
if not log_streams:
time.sleep(1)
continue
assert len(log_streams) == 1
result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName'])
for event in result.get('events'):
if event['message'] == 'get_test_zip_file3 success':
return
time.sleep(1)
assert False, "Test Failed"
@mock_logs
@mock_lambda
@mock_sqs
def test_invoke_function_from_sqs_exception():
logs_conn = boto3.client("logs")
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file4(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
assert response['EventSourceArn'] == queue.attributes['QueueArn']
assert response['State'] == 'Enabled'
entries = []
for i in range(3):
body = {
"uuid": str(uuid.uuid4()),
"test": "test_{}".format(i),
}
entry = {
'Id': str(i),
'MessageBody': json.dumps(body)
}
entries.append(entry)
queue.send_messages(Entries=entries)
start = time.time()
while (time.time() - start) < 30:
result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction')
log_streams = result.get('logStreams')
if not log_streams:
time.sleep(1)
continue
assert len(log_streams) >= 1
result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName'])
for event in result.get('events'):
if 'I failed!' in event['message']:
messages = queue.receive_messages(MaxNumberOfMessages=10)
# Verify messages are still visible and unprocessed
assert len(messages) is 3
return
time.sleep(1)
assert False, "Test Failed"
@mock_logs
@mock_lambda
@mock_sqs
def test_list_event_source_mappings():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
mappings = conn.list_event_source_mappings(EventSourceArn='123')
assert len(mappings['EventSourceMappings']) == 0
mappings = conn.list_event_source_mappings(EventSourceArn=queue.attributes['QueueArn'])
assert len(mappings['EventSourceMappings']) == 1
assert mappings['EventSourceMappings'][0]['UUID'] == response['UUID']
assert mappings['EventSourceMappings'][0]['FunctionArn'] == func['FunctionArn']
@mock_lambda
@mock_sqs
def test_get_event_source_mapping():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
mapping = conn.get_event_source_mapping(UUID=response['UUID'])
assert mapping['UUID'] == response['UUID']
assert mapping['FunctionArn'] == func['FunctionArn']
conn.get_event_source_mapping.when.called_with(UUID='1')\
.should.throw(botocore.client.ClientError)
@mock_lambda
@mock_sqs
def test_update_event_source_mapping():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func1 = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
func2 = conn.create_function(
FunctionName='testFunction2',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func1['FunctionArn'],
)
assert response['FunctionArn'] == func1['FunctionArn']
assert response['BatchSize'] == 10
assert response['State'] == 'Enabled'
mapping = conn.update_event_source_mapping(
UUID=response['UUID'],
Enabled=False,
BatchSize=15,
FunctionName='testFunction2'
)
assert mapping['UUID'] == response['UUID']
assert mapping['FunctionArn'] == func2['FunctionArn']
assert mapping['State'] == 'Disabled'
@mock_lambda
@mock_sqs
def test_delete_event_source_mapping():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func1 = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func1['FunctionArn'],
)
assert response['FunctionArn'] == func1['FunctionArn']
assert response['BatchSize'] == 10
assert response['State'] == 'Enabled'
response = conn.delete_event_source_mapping(UUID=response['UUID'])
assert response['State'] == 'Deleting'
conn.get_event_source_mapping.when.called_with(UUID=response['UUID'])\
.should.throw(botocore.client.ClientError)

View File

@ -642,6 +642,87 @@ def test_describe_task_definition():
len(resp['jobDefinitions']).should.equal(3) len(resp['jobDefinitions']).should.equal(3)
@mock_logs
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_submit_job_by_name():
ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
compute_name = 'test_compute_env'
resp = batch_client.create_compute_environment(
computeEnvironmentName=compute_name,
type='UNMANAGED',
state='ENABLED',
serviceRole=iam_arn
)
arn = resp['computeEnvironmentArn']
resp = batch_client.create_job_queue(
jobQueueName='test_job_queue',
state='ENABLED',
priority=123,
computeEnvironmentOrder=[
{
'order': 123,
'computeEnvironment': arn
},
]
)
queue_arn = resp['jobQueueArn']
job_definition_name = 'sleep10'
batch_client.register_job_definition(
jobDefinitionName=job_definition_name,
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 128,
'command': ['sleep', '10']
}
)
batch_client.register_job_definition(
jobDefinitionName=job_definition_name,
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 256,
'command': ['sleep', '10']
}
)
resp = batch_client.register_job_definition(
jobDefinitionName=job_definition_name,
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 512,
'command': ['sleep', '10']
}
)
job_definition_arn = resp['jobDefinitionArn']
resp = batch_client.submit_job(
jobName='test1',
jobQueue=queue_arn,
jobDefinition=job_definition_name
)
job_id = resp['jobId']
resp_jobs = batch_client.describe_jobs(jobs=[job_id])
# batch_client.terminate_job(jobId=job_id)
len(resp_jobs['jobs']).should.equal(1)
resp_jobs['jobs'][0]['jobId'].should.equal(job_id)
resp_jobs['jobs'][0]['jobQueue'].should.equal(queue_arn)
resp_jobs['jobs'][0]['jobDefinition'].should.equal(job_definition_arn)
# SLOW TESTS # SLOW TESTS
@expected_failure @expected_failure
@mock_logs @mock_logs

View File

@ -593,9 +593,11 @@ def test_create_stack_lambda_and_dynamodb():
} }
}, },
"func1version": { "func1version": {
"Type": "AWS::Lambda::LambdaVersion", "Type": "AWS::Lambda::Version",
"Properties" : { "Properties": {
"Version": "v1.2.3" "FunctionName": {
"Ref": "func1"
}
} }
}, },
"tab1": { "tab1": {
@ -618,8 +620,10 @@ def test_create_stack_lambda_and_dynamodb():
}, },
"func1mapping": { "func1mapping": {
"Type": "AWS::Lambda::EventSourceMapping", "Type": "AWS::Lambda::EventSourceMapping",
"Properties" : { "Properties": {
"FunctionName": "v1.2.3", "FunctionName": {
"Ref": "func1"
},
"EventSourceArn": "arn:aws:dynamodb:region:XXXXXX:table/tab1/stream/2000T00:00:00.000", "EventSourceArn": "arn:aws:dynamodb:region:XXXXXX:table/tab1/stream/2000T00:00:00.000",
"StartingPosition": "0", "StartingPosition": "0",
"BatchSize": 100, "BatchSize": 100,

Some files were not shown because too many files have changed in this diff Show More