Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Stephan Huber 2019-08-30 14:39:29 +02:00
commit 5a9c921d97
132 changed files with 28345 additions and 4011 deletions

1
.gitignore vendored
View File

@ -15,6 +15,7 @@ python_env
.ropeproject/
.pytest_cache/
venv/
env/
.python-version
.vscode/
tests/file.tmp

View File

@ -2,36 +2,56 @@ dist: xenial
language: python
sudo: false
services:
- docker
- docker
python:
- 2.7
- 3.6
- 3.7
- 2.7
- 3.6
- 3.7
env:
- TEST_SERVER_MODE=false
- TEST_SERVER_MODE=true
- TEST_SERVER_MODE=false
- TEST_SERVER_MODE=true
before_install:
- export BOTO_CONFIG=/dev/null
- export BOTO_CONFIG=/dev/null
install:
# We build moto first so the docker container doesn't try to compile it as well, also note we don't use
# -d for docker run so the logs show up in travis
# Python images come from here: https://hub.docker.com/_/python/
- |
python setup.py sdist
- |
python setup.py sdist
if [ "$TEST_SERVER_MODE" = "true" ]; then
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh &
fi
travis_retry pip install boto==2.45.0
travis_retry pip install boto3
travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1
travis_retry pip install -r requirements-dev.txt
if [ "$TEST_SERVER_MODE" = "true" ]; then
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh &
fi
travis_retry pip install boto==2.45.0
travis_retry pip install boto3
travis_retry pip install dist/moto*.gz
travis_retry pip install coveralls==1.1
travis_retry pip install -r requirements-dev.txt
if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py
fi
if [ "$TEST_SERVER_MODE" = "true" ]; then
python wait_for.py
fi
script:
- make test
- make test
after_success:
- coveralls
- coveralls
before_deploy:
- git checkout $TRAVIS_BRANCH
- git fetch --unshallow
- python update_version_from_git.py
deploy:
- provider: pypi
distributions: sdist bdist_wheel
user: spulec
password:
secure: NxnPylnTfekJmGyoufCw0lMoYRskSMJzvAIyAlJJVYKwEhmiCPOrdy5qV8i8mRZ1AkUsqU3jBZ/PD56n96clHW0E3d080UleRDj6JpyALVdeLfMqZl9kLmZ8bqakWzYq3VSJKw2zGP/L4tPGf8wTK1SUv9yl/YNDsBdCkjDverw=
on:
branch:
- master
skip_cleanup: true
skip_existing: true
# - provider: pypi
# distributions: sdist bdist_wheel
# user: spulec
# password:
# secure: NxnPylnTfekJmGyoufCw0lMoYRskSMJzvAIyAlJJVYKwEhmiCPOrdy5qV8i8mRZ1AkUsqU3jBZ/PD56n96clHW0E3d080UleRDj6JpyALVdeLfMqZl9kLmZ8bqakWzYq3VSJKw2zGP/L4tPGf8wTK1SUv9yl/YNDsBdCkjDverw=
# on:
# tags: true
# skip_existing: true

View File

@ -54,5 +54,6 @@ Moto is written by Steve Pulec with contributions from:
* [William Richard](https://github.com/william-richard)
* [Alex Casalboni](https://github.com/alexcasalboni)
* [Jon Beilke](https://github.com/jrbeilke)
* [Bendeguz Acs](https://github.com/acsbendi)
* [Craig Anderson](https://github.com/craiga)
* [Robert Lewis](https://github.com/ralewis85)

View File

@ -2,6 +2,10 @@
Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project.
## Running the tests locally
Moto has a Makefile which has some helpful commands for getting setup. You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
## Is there a missing feature?
Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services.

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,7 @@ endif
init:
@python setup.py develop
@pip install -r requirements.txt
@pip install -r requirements-dev.txt
lint:
flake8 moto

321
README.md
View File

@ -5,6 +5,9 @@
[![Build Status](https://travis-ci.org/spulec/moto.svg?branch=master)](https://travis-ci.org/spulec/moto)
[![Coverage Status](https://coveralls.io/repos/spulec/moto/badge.svg?branch=master)](https://coveralls.io/r/spulec/moto)
[![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org)
![PyPI](https://img.shields.io/pypi/v/moto.svg)
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/moto.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg)
# In a nutshell
@ -55,95 +58,96 @@ With the decorator wrapping the test, all the calls to s3 are automatically mock
It gets even better! Moto isn't just for Python code and it isn't just for S3. Look at the [standalone server mode](https://github.com/spulec/moto#stand-alone-server-mode) for more information about running Moto with other languages. Here's the status of the other AWS services implemented:
```gherkin
|------------------------------------------------------------------------------|
| Service Name | Decorator | Development Status |
|------------------------------------------------------------------------------|
| ACM | @mock_acm | all endpoints done |
|------------------------------------------------------------------------------|
| API Gateway | @mock_apigateway | core endpoints done |
|------------------------------------------------------------------------------|
| Autoscaling | @mock_autoscaling| core endpoints done |
|------------------------------------------------------------------------------|
| Cloudformation | @mock_cloudformation| core endpoints done |
|------------------------------------------------------------------------------|
| Cloudwatch | @mock_cloudwatch | basic endpoints done |
|------------------------------------------------------------------------------|
| CloudwatchEvents | @mock_events | all endpoints done |
|------------------------------------------------------------------------------|
| Cognito Identity | @mock_cognitoidentity| basic endpoints done |
|------------------------------------------------------------------------------|
| Cognito Identity Provider | @mock_cognitoidp| basic endpoints done |
|------------------------------------------------------------------------------|
| Config | @mock_config | basic endpoints done |
|------------------------------------------------------------------------------|
| Data Pipeline | @mock_datapipeline| basic endpoints done |
|------------------------------------------------------------------------------|
| DynamoDB | @mock_dynamodb | core endpoints done |
| DynamoDB2 | @mock_dynamodb2 | all endpoints + partial indexes |
|------------------------------------------------------------------------------|
| EC2 | @mock_ec2 | core endpoints done |
| - AMI | | core endpoints done |
| - EBS | | core endpoints done |
| - Instances | | all endpoints done |
| - Security Groups | | core endpoints done |
| - Tags | | all endpoints done |
|------------------------------------------------------------------------------|
| ECR | @mock_ecr | basic endpoints done |
|------------------------------------------------------------------------------|
| ECS | @mock_ecs | basic endpoints done |
|------------------------------------------------------------------------------|
| ELB | @mock_elb | core endpoints done |
|------------------------------------------------------------------------------|
| ELBv2 | @mock_elbv2 | all endpoints done |
|------------------------------------------------------------------------------|
| EMR | @mock_emr | core endpoints done |
|------------------------------------------------------------------------------|
| Glacier | @mock_glacier | core endpoints done |
|------------------------------------------------------------------------------|
| IAM | @mock_iam | core endpoints done |
|------------------------------------------------------------------------------|
| IoT | @mock_iot | core endpoints done |
| | @mock_iotdata | core endpoints done |
|------------------------------------------------------------------------------|
| Lambda | @mock_lambda | basic endpoints done, requires |
| | | docker |
|------------------------------------------------------------------------------|
| Logs | @mock_logs | basic endpoints done |
|------------------------------------------------------------------------------|
| Kinesis | @mock_kinesis | core endpoints done |
|------------------------------------------------------------------------------|
| KMS | @mock_kms | basic endpoints done |
|------------------------------------------------------------------------------|
| Organizations | @mock_organizations | some core endpoints done |
|------------------------------------------------------------------------------|
| Polly | @mock_polly | all endpoints done |
|------------------------------------------------------------------------------|
| RDS | @mock_rds | core endpoints done |
|------------------------------------------------------------------------------|
| RDS2 | @mock_rds2 | core endpoints done |
|------------------------------------------------------------------------------|
| Redshift | @mock_redshift | core endpoints done |
|------------------------------------------------------------------------------|
| Route53 | @mock_route53 | core endpoints done |
|------------------------------------------------------------------------------|
| S3 | @mock_s3 | core endpoints done |
|------------------------------------------------------------------------------|
| SecretsManager | @mock_secretsmanager | basic endpoints done
|------------------------------------------------------------------------------|
| SES | @mock_ses | all endpoints done |
|------------------------------------------------------------------------------|
| SNS | @mock_sns | all endpoints done |
|------------------------------------------------------------------------------|
| SQS | @mock_sqs | core endpoints done |
|------------------------------------------------------------------------------|
| SSM | @mock_ssm | core endpoints done |
|------------------------------------------------------------------------------|
| STS | @mock_sts | core endpoints done |
|------------------------------------------------------------------------------|
| SWF | @mock_swf | basic endpoints done |
|------------------------------------------------------------------------------|
| X-Ray | @mock_xray | all endpoints done |
|------------------------------------------------------------------------------|
|-------------------------------------------------------------------------------------|
| Service Name | Decorator | Development Status |
|-------------------------------------------------------------------------------------|
| ACM | @mock_acm | all endpoints done |
|-------------------------------------------------------------------------------------|
| API Gateway | @mock_apigateway | core endpoints done |
|-------------------------------------------------------------------------------------|
| Autoscaling | @mock_autoscaling | core endpoints done |
|-------------------------------------------------------------------------------------|
| Cloudformation | @mock_cloudformation | core endpoints done |
|-------------------------------------------------------------------------------------|
| Cloudwatch | @mock_cloudwatch | basic endpoints done |
|-------------------------------------------------------------------------------------|
| CloudwatchEvents | @mock_events | all endpoints done |
|-------------------------------------------------------------------------------------|
| Cognito Identity | @mock_cognitoidentity | basic endpoints done |
|-------------------------------------------------------------------------------------|
| Cognito Identity Provider | @mock_cognitoidp | basic endpoints done |
|-------------------------------------------------------------------------------------|
| Config | @mock_config | basic endpoints done |
| | | core endpoints done |
|-------------------------------------------------------------------------------------|
| Data Pipeline | @mock_datapipeline | basic endpoints done |
|-------------------------------------------------------------------------------------|
| DynamoDB | @mock_dynamodb | core endpoints done |
| DynamoDB2 | @mock_dynamodb2 | all endpoints + partial indexes |
|-------------------------------------------------------------------------------------|
| EC2 | @mock_ec2 | core endpoints done |
| - AMI | | core endpoints done |
| - EBS | | core endpoints done |
| - Instances | | all endpoints done |
| - Security Groups | | core endpoints done |
| - Tags | | all endpoints done |
|-------------------------------------------------------------------------------------|
| ECR | @mock_ecr | basic endpoints done |
|-------------------------------------------------------------------------------------|
| ECS | @mock_ecs | basic endpoints done |
|-------------------------------------------------------------------------------------|
| ELB | @mock_elb | core endpoints done |
|-------------------------------------------------------------------------------------|
| ELBv2 | @mock_elbv2 | all endpoints done |
|-------------------------------------------------------------------------------------|
| EMR | @mock_emr | core endpoints done |
|-------------------------------------------------------------------------------------|
| Glacier | @mock_glacier | core endpoints done |
|-------------------------------------------------------------------------------------|
| IAM | @mock_iam | core endpoints done |
|-------------------------------------------------------------------------------------|
| IoT | @mock_iot | core endpoints done |
| | @mock_iotdata | core endpoints done |
|-------------------------------------------------------------------------------------|
| Kinesis | @mock_kinesis | core endpoints done |
|-------------------------------------------------------------------------------------|
| KMS | @mock_kms | basic endpoints done |
|-------------------------------------------------------------------------------------|
| Lambda | @mock_lambda | basic endpoints done, requires |
| | | docker |
|-------------------------------------------------------------------------------------|
| Logs | @mock_logs | basic endpoints done |
|-------------------------------------------------------------------------------------|
| Organizations | @mock_organizations | some core endpoints done |
|-------------------------------------------------------------------------------------|
| Polly | @mock_polly | all endpoints done |
|-------------------------------------------------------------------------------------|
| RDS | @mock_rds | core endpoints done |
|-------------------------------------------------------------------------------------|
| RDS2 | @mock_rds2 | core endpoints done |
|-------------------------------------------------------------------------------------|
| Redshift | @mock_redshift | core endpoints done |
|-------------------------------------------------------------------------------------|
| Route53 | @mock_route53 | core endpoints done |
|-------------------------------------------------------------------------------------|
| S3 | @mock_s3 | core endpoints done |
|-------------------------------------------------------------------------------------|
| SecretsManager | @mock_secretsmanager | basic endpoints done |
|-------------------------------------------------------------------------------------|
| SES | @mock_ses | all endpoints done |
|-------------------------------------------------------------------------------------|
| SNS | @mock_sns | all endpoints done |
|-------------------------------------------------------------------------------------|
| SQS | @mock_sqs | core endpoints done |
|-------------------------------------------------------------------------------------|
| SSM | @mock_ssm | core endpoints done |
|-------------------------------------------------------------------------------------|
| STS | @mock_sts | core endpoints done |
|-------------------------------------------------------------------------------------|
| SWF | @mock_swf | basic endpoints done |
|-------------------------------------------------------------------------------------|
| X-Ray | @mock_xray | all endpoints done |
|-------------------------------------------------------------------------------------|
```
For a full list of endpoint [implementation coverage](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md)
@ -252,6 +256,137 @@ def test_my_model_save():
mock.stop()
```
## IAM-like Access Control
Moto also has the ability to authenticate and authorize actions, just like it's done by IAM in AWS. This functionality can be enabled by either setting the `INITIAL_NO_AUTH_ACTION_COUNT` environment variable or using the `set_initial_no_auth_action_count` decorator. Note that the current implementation is very basic, see [this file](https://github.com/spulec/moto/blob/master/moto/core/access_control.py) for more information.
### `INITIAL_NO_AUTH_ACTION_COUNT`
If this environment variable is set, moto will skip performing any authentication as many times as the variable's value, and only starts authenticating requests afterwards. If it is not set, it defaults to infinity, thus moto will never perform any authentication at all.
### `set_initial_no_auth_action_count`
This is a decorator that works similarly to the environment variable, but the settings are only valid in the function's scope. When the function returns, everything is restored.
```python
@set_initial_no_auth_action_count(4)
@mock_ec2
def test_describe_instances_allowed():
policy_document = {
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "ec2:Describe*",
"Resource": "*"
}
]
}
access_key = ...
# create access key for an IAM user/assumed role that has the policy above.
# this part should call __exactly__ 4 AWS actions, so that authentication and authorization starts exactly after this
client = boto3.client('ec2', region_name='us-east-1',
aws_access_key_id=access_key['AccessKeyId'],
aws_secret_access_key=access_key['SecretAccessKey'])
# if the IAM principal whose access key is used, does not have the permission to describe instances, this will fail
instances = client.describe_instances()['Reservations'][0]['Instances']
assert len(instances) == 0
```
See [the related test suite](https://github.com/spulec/moto/blob/master/tests/test_core/test_auth.py) for more examples.
## Very Important -- Recommended Usage
There are some important caveats to be aware of when using moto:
*Failure to follow these guidelines could result in your tests mutating your __REAL__ infrastructure!*
### How do I avoid tests from mutating my real infrastructure?
You need to ensure that the mocks are actually in place. Changes made to recent versions of `botocore`
have altered some of the mock behavior. In short, you need to ensure that you _always_ do the following:
1. Ensure that your tests have dummy environment variables set up:
export AWS_ACCESS_KEY_ID='testing'
export AWS_SECRET_ACCESS_KEY='testing'
export AWS_SECURITY_TOKEN='testing'
export AWS_SESSION_TOKEN='testing'
1. __VERY IMPORTANT__: ensure that you have your mocks set up __BEFORE__ your `boto3` client is established.
This can typically happen if you import a module that has a `boto3` client instantiated outside of a function.
See the pesky imports section below on how to work around this.
### Example on usage?
If you are a user of [pytest](https://pytest.org/en/latest/), you can leverage [pytest fixtures](https://pytest.org/en/latest/fixture.html#fixture)
to help set up your mocks and other AWS resources that you would need.
Here is an example:
```python
@pytest.fixture(scope='function')
def aws_credentials():
"""Mocked AWS Credentials for moto."""
os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
os.environ['AWS_SECURITY_TOKEN'] = 'testing'
os.environ['AWS_SESSION_TOKEN'] = 'testing'
@pytest.fixture(scope='function')
def s3(aws_credentials):
with mock_s3():
yield boto3.client('s3', region_name='us-east-1')
@pytest.fixture(scope='function')
def sts(aws_credentials):
with mock_sts():
yield boto3.client('sts', region_name='us-east-1')
@pytest.fixture(scope='function')
def cloudwatch(aws_credentials):
with mock_cloudwatch():
yield boto3.client('cloudwatch', region_name='us-east-1')
... etc.
```
In the code sample above, all of the AWS/mocked fixtures take in a parameter of `aws_credentials`,
which sets the proper fake environment variables. The fake environment variables are used so that `botocore` doesn't try to locate real
credentials on your system.
Next, once you need to do anything with the mocked AWS environment, do something like:
```python
def test_create_bucket(s3):
# s3 is a fixture defined above that yields a boto3 s3 client.
# Feel free to instantiate another boto3 S3 client -- Keep note of the region though.
s3.create_bucket(Bucket="somebucket")
result = s3.list_buckets()
assert len(result['Buckets']) == 1
assert result['Buckets'][0]['Name'] == 'somebucket'
```
### What about those pesky imports?
Recall earlier, it was mentioned that mocks should be established __BEFORE__ the clients are set up. One way
to avoid import issues is to make use of local Python imports -- i.e. import the module inside of the unit
test you want to run vs. importing at the top of the file.
Example:
```python
def test_something(s3):
from some.package.that.does.something.with.s3 import some_func # <-- Local import for unit test
# ^^ Importing here ensures that the mock has been established.
sume_func() # The mock has been established from the "s3" pytest fixture, so this function that uses
# a package-level S3 client will properly use the mock and not reach out to AWS.
```
### Other caveats
For Tox, Travis CI, and other build systems, you might need to also perform a `touch ~/.aws/credentials`
command before running the tests. As long as that file is present (empty preferably) and the environment
variables above are set, you should be good to go.
## Stand-alone Server Mode
Moto also has a stand-alone server mode. This allows you to utilize
@ -318,3 +453,11 @@ boto3.resource(
```console
$ pip install moto
```
## Releases
Releases are done from travisci. Fairly closely following this:
https://docs.travis-ci.com/user/deployment/pypi/
- Commits to `master` branch do a dev deploy to pypi.
- Commits to a tag do a real deploy to pypi.

View File

@ -17,66 +17,95 @@ with ``moto`` and its usage.
Currently implemented Services:
-------------------------------
+-----------------------+---------------------+-----------------------------------+
| Service Name | Decorator | Development Status |
+=======================+=====================+===================================+
| API Gateway | @mock_apigateway | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| Autoscaling | @mock_autoscaling | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| Cloudformation | @mock_cloudformation| core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| Cloudwatch | @mock_cloudwatch | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+
| Data Pipeline | @mock_datapipeline | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+
| - DynamoDB | - @mock_dynamodb | - core endpoints done |
| - DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes|
+-----------------------+---------------------+-----------------------------------+
| EC2 | @mock_ec2 | core endpoints done |
| - AMI | | - core endpoints done |
| - EBS | | - core endpoints done |
| - Instances | | - all endpoints done |
| - Security Groups | | - core endpoints done |
| - Tags | | - all endpoints done |
+-----------------------+---------------------+-----------------------------------+
| ECS | @mock_ecs | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+
| ELB | @mock_elb | core endpoints done |
| | @mock_elbv2 | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| EMR | @mock_emr | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| Glacier | @mock_glacier | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| IAM | @mock_iam | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| Lambda | @mock_lambda | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+
| Kinesis | @mock_kinesis | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| KMS | @mock_kms | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+
| RDS | @mock_rds | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| RDS2 | @mock_rds2 | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| Redshift | @mock_redshift | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| Route53 | @mock_route53 | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| S3 | @mock_s3 | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| SES | @mock_ses | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| SNS | @mock_sns | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| SQS | @mock_sqs | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| STS | @mock_sts | core endpoints done |
+-----------------------+---------------------+-----------------------------------+
| SWF | @mock_swf | basic endpoints done |
+-----------------------+---------------------+-----------------------------------+
+---------------------------+-----------------------+------------------------------------+
| Service Name | Decorator | Development Status |
+===========================+=======================+====================================+
| ACM | @mock_acm | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| API Gateway | @mock_apigateway | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Autoscaling | @mock_autoscaling | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Cloudformation | @mock_cloudformation | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Cloudwatch | @mock_cloudwatch | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| CloudwatchEvents | @mock_events | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Cognito Identity | @mock_cognitoidentity | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Cognito Identity Provider | @mock_cognitoidp | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Config | @mock_config | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Data Pipeline | @mock_datapipeline | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| DynamoDB | - @mock_dynamodb | - core endpoints done |
| DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes |
+---------------------------+-----------------------+------------------------------------+
| EC2 | @mock_ec2 | core endpoints done |
| - AMI | | - core endpoints done |
| - EBS | | - core endpoints done |
| - Instances | | - all endpoints done |
| - Security Groups | | - core endpoints done |
| - Tags | | - all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| ECR | @mock_ecr | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| ECS | @mock_ecs | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| ELB | @mock_elb | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| ELBv2 | @mock_elbv2 | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| EMR | @mock_emr | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Glacier | @mock_glacier | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| IAM | @mock_iam | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| IoT | @mock_iot | core endpoints done |
| | @mock_iotdata | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Kinesis | @mock_kinesis | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| KMS | @mock_kms | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Lambda | @mock_lambda | basic endpoints done, |
| | | requires docker |
+---------------------------+-----------------------+------------------------------------+
| Logs | @mock_logs | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Organizations | @mock_organizations | some core edpoints done |
+---------------------------+-----------------------+------------------------------------+
| Polly | @mock_polly | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| RDS | @mock_rds | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| RDS2 | @mock_rds2 | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Redshift | @mock_redshift | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| Route53 | @mock_route53 | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| S3 | @mock_s3 | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SecretsManager | @mock_secretsmanager | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SES | @mock_ses | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SNS | @mock_sns | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SQS | @mock_sqs | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SSM | @mock_ssm | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| STS | @mock_sts | core endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SWF | @mock_swf | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| X-Ray | @mock_xray | all endpoints done |
+---------------------------+-----------------------+------------------------------------+

View File

@ -3,7 +3,7 @@ import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto'
__version__ = '1.3.8'
__version__ = '1.3.14.dev'
from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa

View File

@ -105,7 +105,7 @@ class CertBundle(BaseModel):
self.arn = arn
@classmethod
def generate_cert(cls, domain_name, sans=None):
def generate_cert(cls, domain_name, region, sans=None):
if sans is None:
sans = set()
else:
@ -152,7 +152,7 @@ class CertBundle(BaseModel):
encryption_algorithm=serialization.NoEncryption()
)
return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION')
return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION', region=region)
def validate_pk(self):
try:
@ -325,7 +325,7 @@ class AWSCertificateManagerBackend(BaseBackend):
return bundle.arn
def get_certificates_list(self):
def get_certificates_list(self, statuses):
"""
Get list of certificates
@ -333,7 +333,9 @@ class AWSCertificateManagerBackend(BaseBackend):
:rtype: list of CertBundle
"""
for arn in self._certificates.keys():
yield self.get_certificate(arn)
cert = self.get_certificate(arn)
if not statuses or cert.status in statuses:
yield cert
def get_certificate(self, arn):
if arn not in self._certificates:
@ -355,7 +357,7 @@ class AWSCertificateManagerBackend(BaseBackend):
if arn is not None:
return arn
cert = CertBundle.generate_cert(domain_name, subject_alt_names)
cert = CertBundle.generate_cert(domain_name, region=self.region, sans=subject_alt_names)
if idempotency_token is not None:
self._set_idempotency_token_arn(idempotency_token, cert.arn)
self._certificates[cert.arn] = cert

View File

@ -132,8 +132,8 @@ class AWSCertificateManagerResponse(BaseResponse):
def list_certificates(self):
certs = []
for cert_bundle in self.acm_backend.get_certificates_list():
statuses = self._get_param('CertificateStatuses')
for cert_bundle in self.acm_backend.get_certificates_list(statuses):
certs.append({
'CertificateArn': cert_bundle.arn,
'DomainName': cert_bundle.common_name

View File

@ -309,6 +309,25 @@ class ApiKey(BaseModel, dict):
self['createdDate'] = self['lastUpdatedDate'] = int(time.time())
self['stageKeys'] = stageKeys
def update_operations(self, patch_operations):
for op in patch_operations:
if op['op'] == 'replace':
if '/name' in op['path']:
self['name'] = op['value']
elif '/customerId' in op['path']:
self['customerId'] = op['value']
elif '/description' in op['path']:
self['description'] = op['value']
elif '/enabled' in op['path']:
self['enabled'] = self._str2bool(op['value'])
else:
raise Exception(
'Patch operation "%s" not implemented' % op['op'])
return self
def _str2bool(self, v):
return v.lower() == "true"
class UsagePlan(BaseModel, dict):
@ -599,6 +618,10 @@ class APIGatewayBackend(BaseBackend):
def get_apikey(self, api_key_id):
return self.keys[api_key_id]
def update_apikey(self, api_key_id, patch_operations):
key = self.keys[api_key_id]
return key.update_operations(patch_operations)
def delete_apikey(self, api_key_id):
self.keys.pop(api_key_id)
return {}

View File

@ -245,6 +245,9 @@ class APIGatewayResponse(BaseResponse):
if self.method == 'GET':
apikey_response = self.backend.get_apikey(apikey)
elif self.method == 'PATCH':
patch_operations = self._get_param('patchOperations')
apikey_response = self.backend.update_apikey(apikey, patch_operations)
elif self.method == 'DELETE':
apikey_response = self.backend.delete_apikey(apikey)
return 200, {}, json.dumps(apikey_response)

View File

@ -1,9 +1,10 @@
from __future__ import unicode_literals
import six
import random
import string
def create_id():
size = 10
chars = list(range(10)) + ['A-Z']
chars = list(range(10)) + list(string.ascii_lowercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size))

View File

@ -13,3 +13,12 @@ class ResourceContentionError(RESTError):
super(ResourceContentionError, self).__init__(
"ResourceContentionError",
"You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).")
class InvalidInstanceError(AutoscalingClientError):
def __init__(self, instance_id):
super(InvalidInstanceError, self).__init__(
"ValidationError",
"Instance [{0}] is invalid."
.format(instance_id))

View File

@ -3,6 +3,8 @@ from __future__ import unicode_literals
import random
from boto.ec2.blockdevicemapping import BlockDeviceType, BlockDeviceMapping
from moto.ec2.exceptions import InvalidInstanceIdError
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.ec2 import ec2_backends
@ -10,7 +12,7 @@ from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends
from moto.elb.exceptions import LoadBalancerNotFoundError
from .exceptions import (
AutoscalingClientError, ResourceContentionError,
AutoscalingClientError, ResourceContentionError, InvalidInstanceError
)
# http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown
@ -73,6 +75,26 @@ class FakeLaunchConfiguration(BaseModel):
self.associate_public_ip_address = associate_public_ip_address
self.block_device_mapping_dict = block_device_mapping_dict
@classmethod
def create_from_instance(cls, name, instance, backend):
config = backend.create_launch_configuration(
name=name,
image_id=instance.image_id,
kernel_id='',
ramdisk_id='',
key_name=instance.key_name,
security_groups=instance.security_groups,
user_data=instance.user_data,
instance_type=instance.instance_type,
instance_monitoring=False,
instance_profile_name=None,
spot_price=None,
ebs_optimized=instance.ebs_optimized,
associate_public_ip_address=instance.associate_public_ip,
block_device_mappings=instance.block_device_mapping
)
return config
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
@ -279,6 +301,12 @@ class FakeAutoScalingGroup(BaseModel):
if min_size is not None:
self.min_size = min_size
if desired_capacity is None:
if min_size is not None and min_size > len(self.instance_states):
desired_capacity = min_size
if max_size is not None and max_size < len(self.instance_states):
desired_capacity = max_size
if launch_config_name:
self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name]
@ -414,7 +442,8 @@ class AutoScalingBackend(BaseBackend):
health_check_type, load_balancers,
target_group_arns, placement_group,
termination_policies, tags,
new_instances_protected_from_scale_in=False):
new_instances_protected_from_scale_in=False,
instance_id=None):
def make_int(value):
return int(value) if value is not None else value
@ -427,6 +456,13 @@ class AutoScalingBackend(BaseBackend):
health_check_period = 300
else:
health_check_period = make_int(health_check_period)
if launch_config_name is None and instance_id is not None:
try:
instance = self.ec2_backend.get_instance(instance_id)
launch_config_name = name
FakeLaunchConfiguration.create_from_instance(launch_config_name, instance, self)
except InvalidInstanceIdError:
raise InvalidInstanceError(instance_id)
group = FakeAutoScalingGroup(
name=name,
@ -684,6 +720,18 @@ class AutoScalingBackend(BaseBackend):
for instance in protected_instances:
instance.protected_from_scale_in = protected_from_scale_in
def notify_terminate_instances(self, instance_ids):
for autoscaling_group_name, autoscaling_group in self.autoscaling_groups.items():
original_instance_count = len(autoscaling_group.instance_states)
autoscaling_group.instance_states = list(filter(
lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states
))
difference = original_instance_count - len(autoscaling_group.instance_states)
if difference > 0:
autoscaling_group.replace_autoscaling_group_instances(difference, autoscaling_group.get_propagated_tags())
self.update_attached_elbs(autoscaling_group_name)
autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items():

View File

@ -48,7 +48,7 @@ class AutoScalingResponse(BaseResponse):
start = all_names.index(marker) + 1
else:
start = 0
max_records = self._get_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
max_records = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
launch_configurations_resp = all_launch_configurations[start:start + max_records]
next_token = None
if len(all_launch_configurations) > start + max_records:
@ -74,6 +74,7 @@ class AutoScalingResponse(BaseResponse):
desired_capacity=self._get_int_param('DesiredCapacity'),
max_size=self._get_int_param('MaxSize'),
min_size=self._get_int_param('MinSize'),
instance_id=self._get_param('InstanceId'),
launch_config_name=self._get_param('LaunchConfigurationName'),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'),
default_cooldown=self._get_int_param('DefaultCooldown'),

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
import base64
import time
from collections import defaultdict
import copy
import datetime
@ -31,6 +32,7 @@ from moto.logs.models import logs_backends
from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings
from .utils import make_function_arn, make_function_ver_arn
from moto.sqs import sqs_backends
logger = logging.getLogger(__name__)
@ -231,6 +233,10 @@ class LambdaFunction(BaseModel):
config.update({"VpcId": "vpc-123abc"})
return config
@property
def physical_resource_id(self):
return self.function_name
def __repr__(self):
return json.dumps(self.get_configuration())
@ -425,24 +431,59 @@ class LambdaFunction(BaseModel):
class EventSourceMapping(BaseModel):
def __init__(self, spec):
# required
self.function_name = spec['FunctionName']
self.function_arn = spec['FunctionArn']
self.event_source_arn = spec['EventSourceArn']
self.starting_position = spec['StartingPosition']
self.uuid = str(uuid.uuid4())
self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple())
# BatchSize service default/max mapping
batch_size_map = {
'kinesis': (100, 10000),
'dynamodb': (100, 1000),
'sqs': (10, 10),
}
source_type = self.event_source_arn.split(":")[2].lower()
batch_size_entry = batch_size_map.get(source_type)
if batch_size_entry:
# Use service default if not provided
batch_size = int(spec.get('BatchSize', batch_size_entry[0]))
if batch_size > batch_size_entry[1]:
raise ValueError("InvalidParameterValueException",
"BatchSize {} exceeds the max of {}".format(batch_size, batch_size_entry[1]))
else:
self.batch_size = batch_size
else:
raise ValueError("InvalidParameterValueException",
"Unsupported event source type")
# optional
self.batch_size = spec.get('BatchSize', 100)
self.starting_position = spec.get('StartingPosition', 'TRIM_HORIZON')
self.enabled = spec.get('Enabled', True)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp',
None)
def get_configuration(self):
return {
'UUID': self.uuid,
'BatchSize': self.batch_size,
'EventSourceArn': self.event_source_arn,
'FunctionArn': self.function_arn,
'LastModified': self.last_modified,
'LastProcessingResult': '',
'State': 'Enabled' if self.enabled else 'Disabled',
'StateTransitionReason': 'User initiated'
}
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
func = lambda_backends[region_name].get_function(properties['FunctionName'])
spec = {
'FunctionName': properties['FunctionName'],
'FunctionArn': func.function_arn,
'EventSourceArn': properties['EventSourceArn'],
'StartingPosition': properties['StartingPosition']
'StartingPosition': properties['StartingPosition'],
'BatchSize': properties.get('BatchSize', 100)
}
optional_properties = 'BatchSize Enabled StartingPositionTimestamp'.split()
for prop in optional_properties:
@ -462,8 +503,10 @@ class LambdaVersion(BaseModel):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
function_name = properties['FunctionName']
func = lambda_backends[region_name].publish_function(function_name)
spec = {
'Version': properties.get('Version')
'Version': func.version
}
return LambdaVersion(spec)
@ -511,6 +554,9 @@ class LambdaStorage(object):
def get_arn(self, arn):
return self._arns.get(arn, None)
def get_function_by_name_or_arn(self, input):
return self.get_function(input) or self.get_arn(input)
def put_function(self, fn):
"""
:param fn: Function
@ -592,6 +638,7 @@ class LambdaStorage(object):
class LambdaBackend(BaseBackend):
def __init__(self, region_name):
self._lambdas = LambdaStorage()
self._event_source_mappings = {}
self.region_name = region_name
def reset(self):
@ -613,6 +660,40 @@ class LambdaBackend(BaseBackend):
fn.version = ver.version
return fn
def create_event_source_mapping(self, spec):
required = [
'EventSourceArn',
'FunctionName',
]
for param in required:
if not spec.get(param):
raise RESTError('InvalidParameterValueException', 'Missing {}'.format(param))
# Validate function name
func = self._lambdas.get_function_by_name_or_arn(spec.pop('FunctionName', ''))
if not func:
raise RESTError('ResourceNotFoundException', 'Invalid FunctionName')
# Validate queue
for queue in sqs_backends[self.region_name].queues.values():
if queue.queue_arn == spec['EventSourceArn']:
if queue.lambda_event_source_mappings.get('func.function_arn'):
# TODO: Correct exception?
raise RESTError('ResourceConflictException', 'The resource already exists.')
if queue.fifo_queue:
raise RESTError('InvalidParameterValueException',
'{} is FIFO'.format(queue.queue_arn))
else:
spec.update({'FunctionArn': func.function_arn})
esm = EventSourceMapping(spec)
self._event_source_mappings[esm.uuid] = esm
# Set backend function on queue
queue.lambda_event_source_mappings[esm.function_arn] = esm
return esm
raise RESTError('ResourceNotFoundException', 'Invalid EventSourceArn')
def publish_function(self, function_name):
return self._lambdas.publish_function(function_name)
@ -622,6 +703,33 @@ class LambdaBackend(BaseBackend):
def list_versions_by_function(self, function_name):
return self._lambdas.list_versions_by_function(function_name)
def get_event_source_mapping(self, uuid):
return self._event_source_mappings.get(uuid)
def delete_event_source_mapping(self, uuid):
return self._event_source_mappings.pop(uuid)
def update_event_source_mapping(self, uuid, spec):
esm = self.get_event_source_mapping(uuid)
if esm:
if spec.get('FunctionName'):
func = self._lambdas.get_function_by_name_or_arn(spec.get('FunctionName'))
esm.function_arn = func.function_arn
if 'BatchSize' in spec:
esm.batch_size = spec['BatchSize']
if 'Enabled' in spec:
esm.enabled = spec['Enabled']
return esm
return False
def list_event_source_mappings(self, event_source_arn, function_name):
esms = list(self._event_source_mappings.values())
if event_source_arn:
esms = list(filter(lambda x: x.event_source_arn == event_source_arn, esms))
if function_name:
esms = list(filter(lambda x: x.function_name == function_name, esms))
return esms
def get_function_by_arn(self, function_arn):
return self._lambdas.get_arn(function_arn)
@ -631,7 +739,43 @@ class LambdaBackend(BaseBackend):
def list_functions(self):
return self._lambdas.all()
def send_message(self, function_name, message, subject=None, qualifier=None):
def send_sqs_batch(self, function_arn, messages, queue_arn):
success = True
for message in messages:
func = self.get_function_by_arn(function_arn)
result = self._send_sqs_message(func, message, queue_arn)
if not result:
success = False
return success
def _send_sqs_message(self, func, message, queue_arn):
event = {
"Records": [
{
"messageId": message.id,
"receiptHandle": message.receipt_handle,
"body": message.body,
"attributes": {
"ApproximateReceiveCount": "1",
"SentTimestamp": "1545082649183",
"SenderId": "AIDAIENQZJOLO23YVJ4VO",
"ApproximateFirstReceiveTimestamp": "1545082649185"
},
"messageAttributes": {},
"md5OfBody": "098f6bcd4621d373cade4e832627b4f6",
"eventSource": "aws:sqs",
"eventSourceARN": queue_arn,
"awsRegion": self.region_name
}
]
}
request_headers = {}
response_headers = {}
func.invoke(json.dumps(event), request_headers, response_headers)
return 'x-amz-function-error' not in response_headers
def send_sns_message(self, function_name, message, subject=None, qualifier=None):
event = {
"Records": [
{

View File

@ -39,6 +39,31 @@ class LambdaResponse(BaseResponse):
else:
raise ValueError("Cannot handle request")
def event_source_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
querystring = self.querystring
event_source_arn = querystring.get('EventSourceArn', [None])[0]
function_name = querystring.get('FunctionName', [None])[0]
return self._list_event_source_mappings(event_source_arn, function_name)
elif request.method == 'POST':
return self._create_event_source_mapping(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
def event_source_mapping(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
path = request.path if hasattr(request, 'path') else path_url(request.url)
uuid = path.split('/')[-1]
if request.method == 'GET':
return self._get_event_source_mapping(uuid)
elif request.method == 'PUT':
return self._update_event_source_mapping(uuid)
elif request.method == 'DELETE':
return self._delete_event_source_mapping(uuid)
else:
raise ValueError("Cannot handle request")
def function(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
@ -177,6 +202,45 @@ class LambdaResponse(BaseResponse):
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _create_event_source_mapping(self, request, full_url, headers):
try:
fn = self.lambda_backend.create_event_source_mapping(self.json_body)
except ValueError as e:
return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}})
else:
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _list_event_source_mappings(self, event_source_arn, function_name):
esms = self.lambda_backend.list_event_source_mappings(event_source_arn, function_name)
result = {
'EventSourceMappings': [esm.get_configuration() for esm in esms]
}
return 200, {}, json.dumps(result)
def _get_event_source_mapping(self, uuid):
result = self.lambda_backend.get_event_source_mapping(uuid)
if result:
return 200, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"
def _update_event_source_mapping(self, uuid):
result = self.lambda_backend.update_event_source_mapping(uuid, self.json_body)
if result:
return 202, {}, json.dumps(result.get_configuration())
else:
return 404, {}, "{}"
def _delete_event_source_mapping(self, uuid):
esm = self.lambda_backend.delete_event_source_mapping(uuid)
if esm:
json_result = esm.get_configuration()
json_result.update({'State': 'Deleting'})
return 202, {}, json.dumps(json_result)
else:
return 404, {}, "{}"
def _publish_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 2)[-2]

View File

@ -11,6 +11,8 @@ url_paths = {
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$': response.versions,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/?$': response.event_source_mappings,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$': response.event_source_mapping,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag,

View File

@ -514,10 +514,13 @@ class BatchBackend(BaseBackend):
return self._job_definitions.get(arn)
def get_job_definition_by_name(self, name):
for comp_env in self._job_definitions.values():
if comp_env.name == name:
return comp_env
return None
latest_revision = -1
latest_job = None
for job_def in self._job_definitions.values():
if job_def.name == name and job_def.revision > latest_revision:
latest_job = job_def
latest_revision = job_def.revision
return latest_job
def get_job_definition_by_name_revision(self, name, revision):
for job_def in self._job_definitions.values():
@ -534,10 +537,13 @@ class BatchBackend(BaseBackend):
:return: Job definition or None
:rtype: JobDefinition or None
"""
env = self.get_job_definition_by_arn(identifier)
if env is None:
env = self.get_job_definition_by_name(identifier)
return env
job_def = self.get_job_definition_by_arn(identifier)
if job_def is None:
if ':' in identifier:
job_def = self.get_job_definition_by_name_revision(*identifier.split(':', 1))
else:
job_def = self.get_job_definition_by_name(identifier)
return job_def
def get_job_definitions(self, identifier):
"""
@ -984,9 +990,7 @@ class BatchBackend(BaseBackend):
# TODO parameters, retries (which is a dict raw from request), job dependancies and container overrides are ignored for now
# Look for job definition
job_def = self.get_job_definition_by_arn(job_def_id)
if job_def is None and ':' in job_def_id:
job_def = self.get_job_definition_by_name_revision(*job_def_id.split(':', 1))
job_def = self.get_job_definition(job_def_id)
if job_def is None:
raise ClientException('Job definition {0} does not exist'.format(job_def_id))

View File

@ -246,7 +246,8 @@ def resource_name_property_from_type(resource_type):
def generate_resource_name(resource_type, stack_name, logical_id):
if resource_type == "AWS::ElasticLoadBalancingV2::TargetGroup":
if resource_type in ["AWS::ElasticLoadBalancingV2::TargetGroup",
"AWS::ElasticLoadBalancingV2::LoadBalancer"]:
# Target group names need to be less than 32 characters, so when cloudformation creates a name for you
# it makes sure to stay under that limit
name_prefix = '{0}-{1}'.format(stack_name, logical_id)

View File

@ -4,6 +4,7 @@ import six
import random
import yaml
import os
import string
from cfnlint import decode, core
@ -29,7 +30,7 @@ def generate_stackset_arn(stackset_id, region_name):
def random_suffix():
size = 12
chars = list(range(10)) + ['A-Z']
chars = list(range(10)) + list(string.ascii_uppercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size))

View File

@ -275,7 +275,7 @@ GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://m
<Label>{{ label }}</Label>
<Datapoints>
{% for datapoint in datapoints %}
<Datapoint>
<member>
{% if datapoint.sum is not none %}
<Sum>{{ datapoint.sum }}</Sum>
{% endif %}
@ -302,7 +302,7 @@ GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://m
<Timestamp>{{ datapoint.timestamp }}</Timestamp>
<Unit>{{ datapoint.unit }}</Unit>
</Datapoint>
</member>
{% endfor %}
</Datapoints>
</GetMetricStatisticsResult>

View File

@ -95,6 +95,15 @@ class CognitoIdentityBackend(BaseBackend):
})
return response
def get_open_id_token(self, identity_id):
response = json.dumps(
{
"IdentityId": identity_id,
"Token": get_random_identity_id(self.region)
}
)
return response
cognitoidentity_backends = {}
for region in boto.cognito.identity.regions():

View File

@ -35,3 +35,8 @@ class CognitoIdentityResponse(BaseResponse):
return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity(
self._get_param('IdentityId') or get_random_identity_id(self.region)
)
def get_open_id_token(self):
return cognitoidentity_backends[self.region].get_open_id_token(
self._get_param("IdentityId") or get_random_identity_id(self.region)
)

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import datetime
import functools
import hashlib
import itertools
import json
import os
@ -154,20 +155,37 @@ class CognitoIdpUserPool(BaseModel):
class CognitoIdpUserPoolDomain(BaseModel):
def __init__(self, user_pool_id, domain):
def __init__(self, user_pool_id, domain, custom_domain_config=None):
self.user_pool_id = user_pool_id
self.domain = domain
self.custom_domain_config = custom_domain_config or {}
def to_json(self):
return {
"UserPoolId": self.user_pool_id,
"AWSAccountId": str(uuid.uuid4()),
"CloudFrontDistribution": None,
"Domain": self.domain,
"S3Bucket": None,
"Status": "ACTIVE",
"Version": None,
}
def _distribution_name(self):
if self.custom_domain_config and \
'CertificateArn' in self.custom_domain_config:
hash = hashlib.md5(
self.custom_domain_config['CertificateArn'].encode('utf-8')
).hexdigest()
return "{hash}.cloudfront.net".format(hash=hash[:16])
return None
def to_json(self, extended=True):
distribution = self._distribution_name()
if extended:
return {
"UserPoolId": self.user_pool_id,
"AWSAccountId": str(uuid.uuid4()),
"CloudFrontDistribution": distribution,
"Domain": self.domain,
"S3Bucket": None,
"Status": "ACTIVE",
"Version": None,
}
elif distribution:
return {
"CloudFrontDomain": distribution,
}
return None
class CognitoIdpUserPoolClient(BaseModel):
@ -338,11 +356,13 @@ class CognitoIdpBackend(BaseBackend):
del self.user_pools[user_pool_id]
# User pool domain
def create_user_pool_domain(self, user_pool_id, domain):
def create_user_pool_domain(self, user_pool_id, domain, custom_domain_config=None):
if user_pool_id not in self.user_pools:
raise ResourceNotFoundError(user_pool_id)
user_pool_domain = CognitoIdpUserPoolDomain(user_pool_id, domain)
user_pool_domain = CognitoIdpUserPoolDomain(
user_pool_id, domain, custom_domain_config=custom_domain_config
)
self.user_pool_domains[domain] = user_pool_domain
return user_pool_domain
@ -358,6 +378,14 @@ class CognitoIdpBackend(BaseBackend):
del self.user_pool_domains[domain]
def update_user_pool_domain(self, domain, custom_domain_config):
if domain not in self.user_pool_domains:
raise ResourceNotFoundError(domain)
user_pool_domain = self.user_pool_domains[domain]
user_pool_domain.custom_domain_config = custom_domain_config
return user_pool_domain
# User pool client
def create_user_pool_client(self, user_pool_id, extended_config):
user_pool = self.user_pools.get(user_pool_id)

View File

@ -50,7 +50,13 @@ class CognitoIdpResponse(BaseResponse):
def create_user_pool_domain(self):
domain = self._get_param("Domain")
user_pool_id = self._get_param("UserPoolId")
cognitoidp_backends[self.region].create_user_pool_domain(user_pool_id, domain)
custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].create_user_pool_domain(
user_pool_id, domain, custom_domain_config
)
domain_description = user_pool_domain.to_json(extended=False)
if domain_description:
return json.dumps(domain_description)
return ""
def describe_user_pool_domain(self):
@ -69,6 +75,17 @@ class CognitoIdpResponse(BaseResponse):
cognitoidp_backends[self.region].delete_user_pool_domain(domain)
return ""
def update_user_pool_domain(self):
domain = self._get_param("Domain")
custom_domain_config = self._get_param("CustomDomainConfig")
user_pool_domain = cognitoidp_backends[self.region].update_user_pool_domain(
domain, custom_domain_config
)
domain_description = user_pool_domain.to_json(extended=False)
if domain_description:
return json.dumps(domain_description)
return ""
# User pool client
def create_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId")

View File

@ -52,6 +52,18 @@ class InvalidResourceTypeException(JsonRESTError):
super(InvalidResourceTypeException, self).__init__("ValidationException", message)
class NoSuchConfigurationAggregatorException(JsonRESTError):
code = 400
def __init__(self, number=1):
if number == 1:
message = 'The configuration aggregator does not exist. Check the configuration aggregator name and try again.'
else:
message = 'At least one of the configuration aggregators does not exist. Check the configuration aggregator' \
' names and try again.'
super(NoSuchConfigurationAggregatorException, self).__init__("NoSuchConfigurationAggregatorException", message)
class NoSuchConfigurationRecorderException(JsonRESTError):
code = 400
@ -78,6 +90,14 @@ class NoSuchBucketException(JsonRESTError):
super(NoSuchBucketException, self).__init__("NoSuchBucketException", message)
class InvalidNextTokenException(JsonRESTError):
code = 400
def __init__(self):
message = 'The nextToken provided is invalid'
super(InvalidNextTokenException, self).__init__("InvalidNextTokenException", message)
class InvalidS3KeyPrefixException(JsonRESTError):
code = 400
@ -147,3 +167,66 @@ class LastDeliveryChannelDeleteFailedException(JsonRESTError):
message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \
'because there is a running configuration recorder.'.format(name=name)
super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message)
class TooManyAccountSources(JsonRESTError):
code = 400
def __init__(self, length):
locations = ['com.amazonaws.xyz'] * length
message = 'Value \'[{locations}]\' at \'accountAggregationSources\' failed to satisfy constraint: ' \
'Member must have length less than or equal to 1'.format(locations=', '.join(locations))
super(TooManyAccountSources, self).__init__("ValidationException", message)
class DuplicateTags(JsonRESTError):
code = 400
def __init__(self):
super(DuplicateTags, self).__init__(
'InvalidInput', 'Duplicate tag keys found. Please note that Tag keys are case insensitive.')
class TagKeyTooBig(JsonRESTError):
code = 400
def __init__(self, tag, param='tags.X.member.key'):
super(TagKeyTooBig, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 128".format(tag, param))
class TagValueTooBig(JsonRESTError):
code = 400
def __init__(self, tag):
super(TagValueTooBig, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy "
"constraint: Member must have length less than or equal to 256".format(tag))
class InvalidParameterValueException(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidParameterValueException, self).__init__('InvalidParameterValueException', message)
class InvalidTagCharacters(JsonRESTError):
code = 400
def __init__(self, tag, param='tags.X.member.key'):
message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(tag, param)
message += 'constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+'
super(InvalidTagCharacters, self).__init__('ValidationException', message)
class TooManyTags(JsonRESTError):
code = 400
def __init__(self, tags, param='tags'):
super(TooManyTags, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 50.".format(tags, param))

View File

@ -1,6 +1,9 @@
import json
import re
import time
import pkg_resources
import random
import string
from datetime import datetime
@ -12,37 +15,125 @@ from moto.config.exceptions import InvalidResourceTypeException, InvalidDelivery
NoSuchConfigurationRecorderException, NoAvailableConfigurationRecorderException, \
InvalidDeliveryChannelNameException, NoSuchBucketException, InvalidS3KeyPrefixException, \
InvalidSNSTopicARNException, MaxNumberOfDeliveryChannelsExceededException, NoAvailableDeliveryChannelException, \
NoSuchDeliveryChannelException, LastDeliveryChannelDeleteFailedException
NoSuchDeliveryChannelException, LastDeliveryChannelDeleteFailedException, TagKeyTooBig, \
TooManyTags, TagValueTooBig, TooManyAccountSources, InvalidParameterValueException, InvalidNextTokenException, \
NoSuchConfigurationAggregatorException, InvalidTagCharacters, DuplicateTags
from moto.core import BaseBackend, BaseModel
DEFAULT_ACCOUNT_ID = 123456789012
POP_STRINGS = [
'capitalizeStart',
'CapitalizeStart',
'capitalizeArn',
'CapitalizeArn',
'capitalizeARN',
'CapitalizeARN'
]
DEFAULT_PAGE_SIZE = 100
def datetime2int(date):
return int(time.mktime(date.timetuple()))
def snake_to_camels(original):
def snake_to_camels(original, cap_start, cap_arn):
parts = original.split('_')
camel_cased = parts[0].lower() + ''.join(p.title() for p in parts[1:])
camel_cased = camel_cased.replace('Arn', 'ARN') # Config uses 'ARN' instead of 'Arn'
if cap_arn:
camel_cased = camel_cased.replace('Arn', 'ARN') # Some config services use 'ARN' instead of 'Arn'
if cap_start:
camel_cased = camel_cased[0].upper() + camel_cased[1::]
return camel_cased
def random_string():
"""Returns a random set of 8 lowercase letters for the Config Aggregator ARN"""
chars = []
for x in range(0, 8):
chars.append(random.choice(string.ascii_lowercase))
return "".join(chars)
def validate_tag_key(tag_key, exception_param='tags.X.member.key'):
"""Validates the tag key.
:param tag_key: The tag key to check against.
:param exception_param: The exception parameter to send over to help format the message. This is to reflect
the difference between the tag and untag APIs.
:return:
"""
# Validate that the key length is correct:
if len(tag_key) > 128:
raise TagKeyTooBig(tag_key, param=exception_param)
# Validate that the tag key fits the proper Regex:
# [\w\s_.:/=+\-@]+ SHOULD be the same as the Java regex on the AWS documentation: [\p{L}\p{Z}\p{N}_.:/=+\-@]+
match = re.findall(r'[\w\s_.:/=+\-@]+', tag_key)
# Kudos if you can come up with a better way of doing a global search :)
if not len(match) or len(match[0]) < len(tag_key):
raise InvalidTagCharacters(tag_key, param=exception_param)
def check_tag_duplicate(all_tags, tag_key):
"""Validates that a tag key is not a duplicate
:param all_tags: Dict to check if there is a duplicate tag.
:param tag_key: The tag key to check against.
:return:
"""
if all_tags.get(tag_key):
raise DuplicateTags()
def validate_tags(tags):
proper_tags = {}
if len(tags) > 50:
raise TooManyTags(tags)
for tag in tags:
# Validate the Key:
validate_tag_key(tag['Key'])
check_tag_duplicate(proper_tags, tag['Key'])
# Validate the Value:
if len(tag['Value']) > 256:
raise TagValueTooBig(tag['Value'])
proper_tags[tag['Key']] = tag['Value']
return proper_tags
class ConfigEmptyDictable(BaseModel):
"""Base class to make serialization easy. This assumes that the sub-class will NOT return 'None's in the JSON."""
def __init__(self, capitalize_start=False, capitalize_arn=True):
"""Assists with the serialization of the config object
:param capitalize_start: For some Config services, the first letter is lowercase -- for others it's capital
:param capitalize_arn: For some Config services, the API expects 'ARN' and for others, it expects 'Arn'
"""
self.capitalize_start = capitalize_start
self.capitalize_arn = capitalize_arn
def to_dict(self):
data = {}
for item, value in self.__dict__.items():
if value is not None:
if isinstance(value, ConfigEmptyDictable):
data[snake_to_camels(item)] = value.to_dict()
data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value.to_dict()
else:
data[snake_to_camels(item)] = value
data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value
# Cleanse the extra properties:
for prop in POP_STRINGS:
data.pop(prop, None)
return data
@ -50,8 +141,9 @@ class ConfigEmptyDictable(BaseModel):
class ConfigRecorderStatus(ConfigEmptyDictable):
def __init__(self, name):
self.name = name
super(ConfigRecorderStatus, self).__init__()
self.name = name
self.recording = False
self.last_start_time = None
self.last_stop_time = None
@ -75,12 +167,16 @@ class ConfigRecorderStatus(ConfigEmptyDictable):
class ConfigDeliverySnapshotProperties(ConfigEmptyDictable):
def __init__(self, delivery_frequency):
super(ConfigDeliverySnapshotProperties, self).__init__()
self.delivery_frequency = delivery_frequency
class ConfigDeliveryChannel(ConfigEmptyDictable):
def __init__(self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None):
super(ConfigDeliveryChannel, self).__init__()
self.name = name
self.s3_bucket_name = s3_bucket_name
self.s3_key_prefix = prefix
@ -91,6 +187,8 @@ class ConfigDeliveryChannel(ConfigEmptyDictable):
class RecordingGroup(ConfigEmptyDictable):
def __init__(self, all_supported=True, include_global_resource_types=False, resource_types=None):
super(RecordingGroup, self).__init__()
self.all_supported = all_supported
self.include_global_resource_types = include_global_resource_types
self.resource_types = resource_types
@ -99,6 +197,8 @@ class RecordingGroup(ConfigEmptyDictable):
class ConfigRecorder(ConfigEmptyDictable):
def __init__(self, role_arn, recording_group, name='default', status=None):
super(ConfigRecorder, self).__init__()
self.name = name
self.role_arn = role_arn
self.recording_group = recording_group
@ -109,18 +209,118 @@ class ConfigRecorder(ConfigEmptyDictable):
self.status = status
class AccountAggregatorSource(ConfigEmptyDictable):
def __init__(self, account_ids, aws_regions=None, all_aws_regions=None):
super(AccountAggregatorSource, self).__init__(capitalize_start=True)
# Can't have both the regions and all_regions flag present -- also can't have them both missing:
if aws_regions and all_aws_regions:
raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies '
'the use of all regions. You must choose one of these options.')
if not (aws_regions or all_aws_regions):
raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported '
'regions and try again.')
self.account_ids = account_ids
self.aws_regions = aws_regions
if not all_aws_regions:
all_aws_regions = False
self.all_aws_regions = all_aws_regions
class OrganizationAggregationSource(ConfigEmptyDictable):
def __init__(self, role_arn, aws_regions=None, all_aws_regions=None):
super(OrganizationAggregationSource, self).__init__(capitalize_start=True, capitalize_arn=False)
# Can't have both the regions and all_regions flag present -- also can't have them both missing:
if aws_regions and all_aws_regions:
raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies '
'the use of all regions. You must choose one of these options.')
if not (aws_regions or all_aws_regions):
raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported '
'regions and try again.')
self.role_arn = role_arn
self.aws_regions = aws_regions
if not all_aws_regions:
all_aws_regions = False
self.all_aws_regions = all_aws_regions
class ConfigAggregator(ConfigEmptyDictable):
def __init__(self, name, region, account_sources=None, org_source=None, tags=None):
super(ConfigAggregator, self).__init__(capitalize_start=True, capitalize_arn=False)
self.configuration_aggregator_name = name
self.configuration_aggregator_arn = 'arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}'.format(
region=region,
id=DEFAULT_ACCOUNT_ID,
random=random_string()
)
self.account_aggregation_sources = account_sources
self.organization_aggregation_source = org_source
self.creation_time = datetime2int(datetime.utcnow())
self.last_updated_time = datetime2int(datetime.utcnow())
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
self.tags = tags or {}
# Override the to_dict so that we can format the tags properly...
def to_dict(self):
result = super(ConfigAggregator, self).to_dict()
# Override the account aggregation sources if present:
if self.account_aggregation_sources:
result['AccountAggregationSources'] = [a.to_dict() for a in self.account_aggregation_sources]
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
# if self.tags:
# result['Tags'] = [{'Key': key, 'Value': value} for key, value in self.tags.items()]
return result
class ConfigAggregationAuthorization(ConfigEmptyDictable):
def __init__(self, current_region, authorized_account_id, authorized_aws_region, tags=None):
super(ConfigAggregationAuthorization, self).__init__(capitalize_start=True, capitalize_arn=False)
self.aggregation_authorization_arn = 'arn:aws:config:{region}:{id}:aggregation-authorization/' \
'{auth_account}/{auth_region}'.format(region=current_region,
id=DEFAULT_ACCOUNT_ID,
auth_account=authorized_account_id,
auth_region=authorized_aws_region)
self.authorized_account_id = authorized_account_id
self.authorized_aws_region = authorized_aws_region
self.creation_time = datetime2int(datetime.utcnow())
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
self.tags = tags or {}
class ConfigBackend(BaseBackend):
def __init__(self):
self.recorders = {}
self.delivery_channels = {}
self.config_aggregators = {}
self.aggregation_authorizations = {}
@staticmethod
def _validate_resource_types(resource_list):
# Load the service file:
resource_package = 'botocore'
resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json'))
conifg_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path))
config_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path))
# Verify that each entry exists in the supported list:
bad_list = []
@ -128,11 +328,11 @@ class ConfigBackend(BaseBackend):
# For PY2:
r_str = str(resource)
if r_str not in conifg_schema['shapes']['ResourceType']['enum']:
if r_str not in config_schema['shapes']['ResourceType']['enum']:
bad_list.append(r_str)
if bad_list:
raise InvalidResourceTypeException(bad_list, conifg_schema['shapes']['ResourceType']['enum'])
raise InvalidResourceTypeException(bad_list, config_schema['shapes']['ResourceType']['enum'])
@staticmethod
def _validate_delivery_snapshot_properties(properties):
@ -147,6 +347,158 @@ class ConfigBackend(BaseBackend):
raise InvalidDeliveryFrequency(properties.get('deliveryFrequency', None),
conifg_schema['shapes']['MaximumExecutionFrequency']['enum'])
def put_configuration_aggregator(self, config_aggregator, region):
# Validate the name:
if len(config_aggregator['ConfigurationAggregatorName']) > 256:
raise NameTooLongException(config_aggregator['ConfigurationAggregatorName'], 'configurationAggregatorName')
account_sources = None
org_source = None
# Tag validation:
tags = validate_tags(config_aggregator.get('Tags', []))
# Exception if both AccountAggregationSources and OrganizationAggregationSource are supplied:
if config_aggregator.get('AccountAggregationSources') and config_aggregator.get('OrganizationAggregationSource'):
raise InvalidParameterValueException('The configuration aggregator cannot be created because your request contains both the'
' AccountAggregationSource and the OrganizationAggregationSource. Include only '
'one aggregation source and try again.')
# If neither are supplied:
if not config_aggregator.get('AccountAggregationSources') and not config_aggregator.get('OrganizationAggregationSource'):
raise InvalidParameterValueException('The configuration aggregator cannot be created because your request is missing either '
'the AccountAggregationSource or the OrganizationAggregationSource. Include the '
'appropriate aggregation source and try again.')
if config_aggregator.get('AccountAggregationSources'):
# Currently, only 1 account aggregation source can be set:
if len(config_aggregator['AccountAggregationSources']) > 1:
raise TooManyAccountSources(len(config_aggregator['AccountAggregationSources']))
account_sources = []
for a in config_aggregator['AccountAggregationSources']:
account_sources.append(AccountAggregatorSource(a['AccountIds'], aws_regions=a.get('AwsRegions'),
all_aws_regions=a.get('AllAwsRegions')))
else:
org_source = OrganizationAggregationSource(config_aggregator['OrganizationAggregationSource']['RoleArn'],
aws_regions=config_aggregator['OrganizationAggregationSource'].get('AwsRegions'),
all_aws_regions=config_aggregator['OrganizationAggregationSource'].get(
'AllAwsRegions'))
# Grab the existing one if it exists and update it:
if not self.config_aggregators.get(config_aggregator['ConfigurationAggregatorName']):
aggregator = ConfigAggregator(config_aggregator['ConfigurationAggregatorName'], region, account_sources=account_sources,
org_source=org_source, tags=tags)
self.config_aggregators[config_aggregator['ConfigurationAggregatorName']] = aggregator
else:
aggregator = self.config_aggregators[config_aggregator['ConfigurationAggregatorName']]
aggregator.tags = tags
aggregator.account_aggregation_sources = account_sources
aggregator.organization_aggregation_source = org_source
aggregator.last_updated_time = datetime2int(datetime.utcnow())
return aggregator.to_dict()
def describe_configuration_aggregators(self, names, token, limit):
limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit
agg_list = []
result = {'ConfigurationAggregators': []}
if names:
for name in names:
if not self.config_aggregators.get(name):
raise NoSuchConfigurationAggregatorException(number=len(names))
agg_list.append(name)
else:
agg_list = list(self.config_aggregators.keys())
# Empty?
if not agg_list:
return result
# Sort by name:
sorted_aggregators = sorted(agg_list)
# Get the start:
if not token:
start = 0
else:
# Tokens for this moto feature are just the next names of the items in the list:
if not self.config_aggregators.get(token):
raise InvalidNextTokenException()
start = sorted_aggregators.index(token)
# Get the list of items to collect:
agg_list = sorted_aggregators[start:(start + limit)]
result['ConfigurationAggregators'] = [self.config_aggregators[agg].to_dict() for agg in agg_list]
if len(sorted_aggregators) > (start + limit):
result['NextToken'] = sorted_aggregators[start + limit]
return result
def delete_configuration_aggregator(self, config_aggregator):
if not self.config_aggregators.get(config_aggregator):
raise NoSuchConfigurationAggregatorException()
del self.config_aggregators[config_aggregator]
def put_aggregation_authorization(self, current_region, authorized_account, authorized_region, tags):
# Tag validation:
tags = validate_tags(tags or [])
# Does this already exist?
key = '{}/{}'.format(authorized_account, authorized_region)
agg_auth = self.aggregation_authorizations.get(key)
if not agg_auth:
agg_auth = ConfigAggregationAuthorization(current_region, authorized_account, authorized_region, tags=tags)
self.aggregation_authorizations['{}/{}'.format(authorized_account, authorized_region)] = agg_auth
else:
# Only update the tags:
agg_auth.tags = tags
return agg_auth.to_dict()
def describe_aggregation_authorizations(self, token, limit):
limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit
result = {'AggregationAuthorizations': []}
if not self.aggregation_authorizations:
return result
# Sort by name:
sorted_authorizations = sorted(self.aggregation_authorizations.keys())
# Get the start:
if not token:
start = 0
else:
# Tokens for this moto feature are just the next names of the items in the list:
if not self.aggregation_authorizations.get(token):
raise InvalidNextTokenException()
start = sorted_authorizations.index(token)
# Get the list of items to collect:
auth_list = sorted_authorizations[start:(start + limit)]
result['AggregationAuthorizations'] = [self.aggregation_authorizations[auth].to_dict() for auth in auth_list]
if len(sorted_authorizations) > (start + limit):
result['NextToken'] = sorted_authorizations[start + limit]
return result
def delete_aggregation_authorization(self, authorized_account, authorized_region):
# This will always return a 200 -- regardless if there is or isn't an existing
# aggregation authorization.
key = '{}/{}'.format(authorized_account, authorized_region)
self.aggregation_authorizations.pop(key, None)
def put_configuration_recorder(self, config_recorder):
# Validate the name:
if not config_recorder.get('name'):

View File

@ -13,6 +13,39 @@ class ConfigResponse(BaseResponse):
self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder'))
return ""
def put_configuration_aggregator(self):
aggregator = self.config_backend.put_configuration_aggregator(json.loads(self.body), self.region)
schema = {'ConfigurationAggregator': aggregator}
return json.dumps(schema)
def describe_configuration_aggregators(self):
aggregators = self.config_backend.describe_configuration_aggregators(self._get_param('ConfigurationAggregatorNames'),
self._get_param('NextToken'),
self._get_param('Limit'))
return json.dumps(aggregators)
def delete_configuration_aggregator(self):
self.config_backend.delete_configuration_aggregator(self._get_param('ConfigurationAggregatorName'))
return ""
def put_aggregation_authorization(self):
agg_auth = self.config_backend.put_aggregation_authorization(self.region,
self._get_param('AuthorizedAccountId'),
self._get_param('AuthorizedAwsRegion'),
self._get_param('Tags'))
schema = {'AggregationAuthorization': agg_auth}
return json.dumps(schema)
def describe_aggregation_authorizations(self):
authorizations = self.config_backend.describe_aggregation_authorizations(self._get_param('NextToken'), self._get_param('Limit'))
return json.dumps(authorizations)
def delete_aggregation_authorization(self):
self.config_backend.delete_aggregation_authorization(self._get_param('AuthorizedAccountId'), self._get_param('AuthorizedAwsRegion'))
return ""
def describe_configuration_recorders(self):
recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames'))
schema = {'ConfigurationRecorders': recorders}

View File

@ -1,4 +1,7 @@
from __future__ import unicode_literals
from .models import BaseModel, BaseBackend, moto_api_backend # flake8: noqa
from .responses import ActionAuthenticatorMixin
moto_api_backends = {"global": moto_api_backend}
set_initial_no_auth_action_count = ActionAuthenticatorMixin.set_initial_no_auth_action_count

View File

@ -65,3 +65,42 @@ class JsonRESTError(RESTError):
def get_body(self, *args, **kwargs):
return self.description
class SignatureDoesNotMatchError(RESTError):
code = 403
def __init__(self):
super(SignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch',
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.")
class InvalidClientTokenIdError(RESTError):
code = 403
def __init__(self):
super(InvalidClientTokenIdError, self).__init__(
'InvalidClientTokenId',
"The security token included in the request is invalid.")
class AccessDeniedError(RESTError):
code = 403
def __init__(self, user_arn, action):
super(AccessDeniedError, self).__init__(
'AccessDenied',
"User: {user_arn} is not authorized to perform: {operation}".format(
user_arn=user_arn,
operation=action
))
class AuthFailureError(RESTError):
code = 401
def __init__(self):
super(AuthFailureError, self).__init__(
'AuthFailure',
"AWS was not able to validate the provided access credentials")

View File

@ -12,6 +12,7 @@ from collections import defaultdict
from botocore.handlers import BUILTIN_HANDLERS
from botocore.awsrequest import AWSResponse
import mock
from moto import settings
import responses
from moto.packages.httpretty import HTTPretty
@ -22,11 +23,6 @@ from .utils import (
)
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
os.environ.setdefault("AWS_ACCESS_KEY_ID", "foobar_key")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "foobar_secret")
class BaseMockAWS(object):
nested_count = 0
@ -42,6 +38,10 @@ class BaseMockAWS(object):
self.backends_for_urls.update(self.backends)
self.backends_for_urls.update(default_backends)
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
FAKE_KEYS = {"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret"}
self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
if self.__class__.nested_count == 0:
self.reset()
@ -52,11 +52,14 @@ class BaseMockAWS(object):
def __enter__(self):
self.start()
return self
def __exit__(self, *args):
self.stop()
def start(self, reset=True):
self.env_variables_mocks.start()
self.__class__.nested_count += 1
if reset:
for backend in self.backends.values():
@ -65,6 +68,7 @@ class BaseMockAWS(object):
self.enable_patching()
def stop(self):
self.env_variables_mocks.stop()
self.__class__.nested_count -= 1
if self.__class__.nested_count < 0:
@ -465,10 +469,14 @@ class BaseModel(object):
class BaseBackend(object):
def reset(self):
def _reset_model_refs(self):
# Remove all references to the models stored
for service, models in model_data.items():
for model_name, model in models.items():
model.instances = []
def reset(self):
self._reset_model_refs()
self.__dict__ = {}
self.__init__()

View File

@ -1,13 +1,17 @@
from __future__ import unicode_literals
import functools
from collections import defaultdict
import datetime
import json
import logging
import re
import io
import requests
import pytz
from moto.core.access_control import IAMRequest, S3IAMRequest
from moto.core.exceptions import DryRunClientError
from jinja2 import Environment, DictLoader, TemplateNotFound
@ -22,7 +26,7 @@ from werkzeug.exceptions import HTTPException
import boto3
from moto.compat import OrderedDict
from moto.core.utils import camelcase_to_underscores, method_names_from_class
from moto import settings
log = logging.getLogger(__name__)
@ -103,7 +107,54 @@ class _TemplateEnvironmentMixin(object):
return self.environment.get_template(template_id)
class BaseResponse(_TemplateEnvironmentMixin):
class ActionAuthenticatorMixin(object):
request_count = 0
def _authenticate_and_authorize_action(self, iam_request_cls):
if ActionAuthenticatorMixin.request_count >= settings.INITIAL_NO_AUTH_ACTION_COUNT:
iam_request = iam_request_cls(method=self.method, path=self.path, data=self.data, headers=self.headers)
iam_request.check_signature()
iam_request.check_action_permitted()
else:
ActionAuthenticatorMixin.request_count += 1
def _authenticate_and_authorize_normal_action(self):
self._authenticate_and_authorize_action(IAMRequest)
def _authenticate_and_authorize_s3_action(self):
self._authenticate_and_authorize_action(S3IAMRequest)
@staticmethod
def set_initial_no_auth_action_count(initial_no_auth_action_count):
def decorator(function):
def wrapper(*args, **kwargs):
if settings.TEST_SERVER_MODE:
response = requests.post("http://localhost:5000/moto-api/reset-auth", data=str(initial_no_auth_action_count).encode())
original_initial_no_auth_action_count = response.json()['PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT']
else:
original_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT
original_request_count = ActionAuthenticatorMixin.request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count
ActionAuthenticatorMixin.request_count = 0
try:
result = function(*args, **kwargs)
finally:
if settings.TEST_SERVER_MODE:
requests.post("http://localhost:5000/moto-api/reset-auth", data=str(original_initial_no_auth_action_count).encode())
else:
ActionAuthenticatorMixin.request_count = original_request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = original_initial_no_auth_action_count
return result
functools.update_wrapper(wrapper, function)
wrapper.__wrapped__ = function
return wrapper
return decorator
class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
default_region = 'us-east-1'
# to extract region, use [^.]
@ -167,6 +218,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
self.uri = full_url
self.path = urlparse(full_url).path
self.querystring = querystring
self.data = querystring
self.method = request.method
self.region = self.get_region_from_url(request, full_url)
self.uri_match = None
@ -273,6 +325,13 @@ class BaseResponse(_TemplateEnvironmentMixin):
def call_action(self):
headers = self.response_headers
try:
self._authenticate_and_authorize_normal_action()
except HTTPException as http_error:
response = http_error.description, dict(status=http_error.code)
return self._send_response(headers, response)
action = camelcase_to_underscores(self._get_action())
method_names = method_names_from_class(self.__class__)
if action in method_names:
@ -285,16 +344,7 @@ class BaseResponse(_TemplateEnvironmentMixin):
if isinstance(response, six.string_types):
return 200, headers, response
else:
if len(response) == 2:
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
return status, headers, body
return self._send_response(headers, response)
if not action:
return 404, headers, ''
@ -302,6 +352,19 @@ class BaseResponse(_TemplateEnvironmentMixin):
raise NotImplementedError(
"The {0} action has not been implemented".format(action))
@staticmethod
def _send_response(headers, response):
if len(response) == 2:
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
return status, headers, body
def _get_param(self, param_name, if_none=None):
val = self.querystring.get(param_name)
if val is not None:
@ -569,6 +632,14 @@ class MotoAPIResponse(BaseResponse):
return 200, {}, json.dumps({"status": "ok"})
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"})
def reset_auth_response(self, request, full_url, headers):
if request.method == "POST":
previous_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT
settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode())
ActionAuthenticatorMixin.request_count = 0
return 200, {}, json.dumps({"status": "ok", "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(previous_initial_no_auth_action_count)})
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"})
def model_data(self, request, full_url, headers):
from moto.core.models import model_data

View File

@ -11,4 +11,5 @@ url_paths = {
'{0}/moto-api/$': response_instance.dashboard,
'{0}/moto-api/data.json': response_instance.model_data,
'{0}/moto-api/reset': response_instance.reset_response,
'{0}/moto-api/reset-auth': response_instance.reset_auth_response,
}

File diff suppressed because it is too large Load Diff

View File

@ -6,13 +6,16 @@ import decimal
import json
import re
import uuid
import six
import boto3
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time
from moto.core.exceptions import JsonRESTError
from .comparisons import get_comparison_func, get_filter_expression, Op
from .comparisons import get_comparison_func
from .comparisons import get_filter_expression
from .comparisons import get_expected
from .exceptions import InvalidIndexNameError
@ -68,10 +71,34 @@ class DynamoType(object):
except ValueError:
return float(self.value)
elif self.is_set():
return set(self.value)
sub_type = self.type[0]
return set([DynamoType({sub_type: v}).cast_value
for v in self.value])
elif self.is_list():
return [DynamoType(v).cast_value for v in self.value]
elif self.is_map():
return dict([
(k, DynamoType(v).cast_value)
for k, v in self.value.items()])
else:
return self.value
def child_attr(self, key):
"""
Get Map or List children by key. str for Map, int for List.
Returns DynamoType or None.
"""
if isinstance(key, six.string_types) and self.is_map() and key in self.value:
return DynamoType(self.value[key])
if isinstance(key, int) and self.is_list():
idx = key
if idx >= 0 and idx < len(self.value):
return DynamoType(self.value[idx])
return None
def to_json(self):
return {self.type: self.value}
@ -89,6 +116,12 @@ class DynamoType(object):
def is_set(self):
return self.type == 'SS' or self.type == 'NS' or self.type == 'BS'
def is_list(self):
return self.type == 'L'
def is_map(self):
return self.type == 'M'
def same_type(self, other):
return self.type == other.type
@ -265,7 +298,9 @@ class Item(BaseModel):
new_value = list(update_action['Value'].values())[0]
if action == 'PUT':
# TODO deal with other types
if isinstance(new_value, list) or isinstance(new_value, set):
if isinstance(new_value, list):
self.attrs[attribute_name] = DynamoType({"L": new_value})
elif isinstance(new_value, set):
self.attrs[attribute_name] = DynamoType({"SS": new_value})
elif isinstance(new_value, dict):
self.attrs[attribute_name] = DynamoType({"M": new_value})
@ -504,7 +539,9 @@ class Table(BaseModel):
keys.append(range_key)
return keys
def put_item(self, item_attrs, expected=None, overwrite=False):
def put_item(self, item_attrs, expected=None, condition_expression=None,
expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
if self.has_range_key:
range_value = DynamoType(item_attrs.get(self.range_key_attr))
@ -527,29 +564,15 @@ class Table(BaseModel):
self.range_key_type, item_attrs)
if not overwrite:
if current is None:
current_attr = {}
elif hasattr(current, 'attrs'):
current_attr = current.attrs
else:
current_attr = current
if not get_expected(expected).expr(current):
raise ValueError('The conditional request failed')
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values)
if not condition_op.expr(current):
raise ValueError('The conditional request failed')
for key, val in expected.items():
if 'Exists' in val and val['Exists'] is False \
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
if key in current_attr:
raise ValueError("The conditional request failed")
elif key not in current_attr:
raise ValueError("The conditional request failed")
elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value:
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
dynamo_types = [
DynamoType(ele) for ele in
val.get("AttributeValueList", [])
]
if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types):
raise ValueError('The conditional request failed')
if range_value:
self.items[hash_value][range_value] = item
else:
@ -724,7 +747,7 @@ class Table(BaseModel):
if idx_col_set.issubset(set(hash_set.attrs)):
yield hash_set
def scan(self, filters, limit, exclusive_start_key, filter_expression=None, index_name=None):
def scan(self, filters, limit, exclusive_start_key, filter_expression=None, index_name=None, projection_expression=None):
results = []
scanned_count = 0
all_indexes = self.all_indexes()
@ -763,6 +786,14 @@ class Table(BaseModel):
if passes_all_conditions:
results.append(item)
if projection_expression:
expressions = [x.strip() for x in projection_expression.split(',')]
results = copy.deepcopy(results)
for result in results:
for attr in list(result.attrs):
if attr not in expressions:
result.attrs.pop(attr)
results, last_evaluated_key = self._trim_results(results, limit,
exclusive_start_key, index_name)
return results, scanned_count, last_evaluated_key
@ -894,11 +925,15 @@ class DynamoDBBackend(BaseBackend):
table.global_indexes = list(gsis_by_name.values())
return table
def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
def put_item(self, table_name, item_attrs, expected=None,
condition_expression=None, expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
table = self.tables.get(table_name)
if not table:
return None
return table.put_item(item_attrs, expected, overwrite)
return table.put_item(item_attrs, expected, condition_expression,
expression_attribute_names,
expression_attribute_values, overwrite)
def get_table_keys_name(self, table_name, keys):
"""
@ -954,15 +989,12 @@ class DynamoDBBackend(BaseBackend):
range_values = [DynamoType(range_value)
for range_value in range_value_dicts]
if filter_expression is not None:
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs)
def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values, index_name):
def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values, index_name, projection_expression):
table = self.tables.get(table_name)
if not table:
return None, None, None
@ -972,15 +1004,14 @@ class DynamoDBBackend(BaseBackend):
dynamo_types = [DynamoType(value) for value in comparison_values]
scan_filters[key] = (comparison_operator, dynamo_types)
if filter_expression is not None:
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name)
projection_expression = ','.join([expr_names.get(attr, attr) for attr in projection_expression.replace(' ', '').split(',')])
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name, projection_expression)
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected=None):
expression_attribute_values, expected=None, condition_expression=None):
table = self.get_table(table_name)
if all([table.hash_key_attr in key, table.range_key_attr in key]):
@ -999,32 +1030,17 @@ class DynamoDBBackend(BaseBackend):
item = table.get_item(hash_value, range_value)
if item is None:
item_attr = {}
elif hasattr(item, 'attrs'):
item_attr = item.attrs
else:
item_attr = item
if not expected:
expected = {}
for key, val in expected.items():
if 'Exists' in val and val['Exists'] is False \
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
if key in item_attr:
raise ValueError("The conditional request failed")
elif key not in item_attr:
raise ValueError("The conditional request failed")
elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value:
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
dynamo_types = [
DynamoType(ele) for ele in
val.get("AttributeValueList", [])
]
if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types):
raise ValueError('The conditional request failed')
if not get_expected(expected).expr(item):
raise ValueError('The conditional request failed')
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values)
if not condition_op.expr(item):
raise ValueError('The conditional request failed')
# Update does not fail on new items, so create one
if item is None:

View File

@ -32,67 +32,6 @@ def get_empty_str_error():
))
def condition_expression_to_expected(condition_expression, expression_attribute_names, expression_attribute_values):
"""
Limited condition expression syntax parsing.
Supports Global Negation ex: NOT(inner expressions).
Supports simple AND conditions ex: cond_a AND cond_b and cond_c.
Atomic expressions supported are attribute_exists(key), attribute_not_exists(key) and #key = :value.
"""
expected = {}
if condition_expression and 'OR' not in condition_expression:
reverse_re = re.compile('^NOT\s*\((.*)\)$')
reverse_m = reverse_re.match(condition_expression.strip())
reverse = False
if reverse_m:
reverse = True
condition_expression = reverse_m.group(1)
cond_items = [c.strip() for c in condition_expression.split('AND')]
if cond_items:
exists_re = re.compile('^attribute_exists\s*\((.*)\)$')
not_exists_re = re.compile(
'^attribute_not_exists\s*\((.*)\)$')
equals_re = re.compile('^(#?\w+)\s*=\s*(\:?\w+)')
for cond in cond_items:
exists_m = exists_re.match(cond)
not_exists_m = not_exists_re.match(cond)
equals_m = equals_re.match(cond)
if exists_m:
attribute_name = expression_attribute_names_lookup(exists_m.group(1), expression_attribute_names)
expected[attribute_name] = {'Exists': True if not reverse else False}
elif not_exists_m:
attribute_name = expression_attribute_names_lookup(not_exists_m.group(1), expression_attribute_names)
expected[attribute_name] = {'Exists': False if not reverse else True}
elif equals_m:
attribute_name = expression_attribute_names_lookup(equals_m.group(1), expression_attribute_names)
attribute_value = expression_attribute_values_lookup(equals_m.group(2), expression_attribute_values)
expected[attribute_name] = {
'AttributeValueList': [attribute_value],
'ComparisonOperator': 'EQ' if not reverse else 'NEQ'}
return expected
def expression_attribute_names_lookup(attribute_name, expression_attribute_names):
if attribute_name.startswith('#') and attribute_name in expression_attribute_names:
return expression_attribute_names[attribute_name]
else:
return attribute_name
def expression_attribute_values_lookup(attribute_value, expression_attribute_values):
if isinstance(attribute_value, six.string_types) and \
attribute_value.startswith(':') and\
attribute_value in expression_attribute_values:
return expression_attribute_values[attribute_value]
else:
return attribute_value
class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers):
@ -166,7 +105,7 @@ class DynamoHandler(BaseResponse):
when BillingMode is PAY_PER_REQUEST')
throughput = None
else: # Provisioned (default billing mode)
throughput = body["ProvisionedThroughput"]
throughput = body.get("ProvisionedThroughput")
# getting the schema
key_schema = body['KeySchema']
# getting attribute definition
@ -288,18 +227,18 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected
# expression
if not expected:
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression,
expression_attribute_names,
expression_attribute_values)
if expected:
overwrite = False
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
if condition_expression:
overwrite = False
try:
result = self.dynamodb_backend.put_item(name, item, expected, overwrite)
result = self.dynamodb_backend.put_item(
name, item, expected, condition_expression,
expression_attribute_names, expression_attribute_values,
overwrite)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er, 'A condition specified in the operation could not be evaluated.')
@ -379,6 +318,9 @@ class DynamoHandler(BaseResponse):
for table_name, table_request in table_batches.items():
keys = table_request['Keys']
if self._contains_duplicates(keys):
er = 'com.amazon.coral.validate#ValidationException'
return self.error(er, 'Provided list of item keys contains duplicates')
attributes_to_get = table_request.get('AttributesToGet')
results["Responses"][table_name] = []
for key in keys:
@ -394,6 +336,15 @@ class DynamoHandler(BaseResponse):
})
return dynamo_json_dump(results)
def _contains_duplicates(self, keys):
unique_keys = []
for k in keys:
if k in unique_keys:
return True
else:
unique_keys.append(k)
return False
def query(self):
name = self.body['TableName']
# {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}}
@ -558,7 +509,7 @@ class DynamoHandler(BaseResponse):
filter_expression = self.body.get('FilterExpression')
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
projection_expression = self.body.get('ProjectionExpression', '')
exclusive_start_key = self.body.get('ExclusiveStartKey')
limit = self.body.get("Limit")
index_name = self.body.get('IndexName')
@ -570,7 +521,8 @@ class DynamoHandler(BaseResponse):
filter_expression,
expression_attribute_names,
expression_attribute_values,
index_name)
index_name,
projection_expression)
except InvalidIndexNameError as err:
er = 'com.amazonaws.dynamodb.v20111205#ValidationException'
return self.error(er, str(err))
@ -625,7 +577,7 @@ class DynamoHandler(BaseResponse):
name = self.body['TableName']
key = self.body['Key']
return_values = self.body.get('ReturnValues', 'NONE')
update_expression = self.body.get('UpdateExpression')
update_expression = self.body.get('UpdateExpression', '').strip()
attribute_updates = self.body.get('AttributeUpdates')
expression_attribute_names = self.body.get(
'ExpressionAttributeNames', {})
@ -652,24 +604,20 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected
# expression
if not expected:
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression,
expression_attribute_names,
expression_attribute_values)
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
# Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c`
if update_expression:
update_expression = re.sub(
'\s*([=\+-])\s*', '\\1', update_expression)
r'\s*([=\+-])\s*', '\\1', update_expression)
try:
item = self.dynamodb_backend.update_item(
name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected
expression_attribute_values, expected, condition_expression
)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'

View File

@ -332,6 +332,15 @@ class InvalidParameterValueErrorTagNull(EC2ClientError):
"Tag value cannot be null. Use empty string instead.")
class InvalidParameterValueErrorUnknownAttribute(EC2ClientError):
def __init__(self, parameter_value):
super(InvalidParameterValueErrorUnknownAttribute, self).__init__(
"InvalidParameterValue",
"Value ({0}) for parameter attribute is invalid. Unknown attribute."
.format(parameter_value))
class InvalidInternetGatewayIdError(EC2ClientError):
def __init__(self, internet_gateway_id):
@ -430,6 +439,16 @@ class OperationNotPermitted(EC2ClientError):
)
class InvalidAvailabilityZoneError(EC2ClientError):
def __init__(self, availability_zone_value, valid_availability_zones):
super(InvalidAvailabilityZoneError, self).__init__(
"InvalidParameterValue",
"Value ({0}) for parameter availabilityZone is invalid. "
"Subnets can currently only be created in the following availability zones: {1}.".format(availability_zone_value, valid_availability_zones)
)
class NetworkAclEntryAlreadyExistsError(EC2ClientError):
def __init__(self, rule_number):
@ -504,3 +523,11 @@ class OperationNotPermitted3(EC2ClientError):
pcx_id,
acceptor_region)
)
class InvalidLaunchTemplateNameError(EC2ClientError):
def __init__(self):
super(InvalidLaunchTemplateNameError, self).__init__(
"InvalidLaunchTemplateName.AlreadyExistsException",
"Launch template name already in use."
)

View File

@ -20,7 +20,6 @@ from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType
from boto.ec2.spotinstancerequest import SpotInstanceRequest as BotoSpotRequest
from boto.ec2.launchspecification import LaunchSpecification
from moto.compat import OrderedDict
from moto.core import BaseBackend
from moto.core.models import Model, BaseModel
@ -36,6 +35,7 @@ from .exceptions import (
InvalidAMIIdError,
InvalidAMIAttributeItemValueError,
InvalidAssociationIdError,
InvalidAvailabilityZoneError,
InvalidCIDRBlockParameterError,
InvalidCIDRSubnetError,
InvalidCustomerGatewayIdError,
@ -48,11 +48,13 @@ from .exceptions import (
InvalidKeyPairDuplicateError,
InvalidKeyPairFormatError,
InvalidKeyPairNameError,
InvalidLaunchTemplateNameError,
InvalidNetworkAclIdError,
InvalidNetworkAttachmentIdError,
InvalidNetworkInterfaceIdError,
InvalidParameterValueError,
InvalidParameterValueErrorTagNull,
InvalidParameterValueErrorUnknownAttribute,
InvalidPermissionNotFoundError,
InvalidPermissionDuplicateError,
InvalidRouteTableIdError,
@ -96,6 +98,7 @@ from .utils import (
random_internet_gateway_id,
random_ip,
random_ipv6_cidr,
random_launch_template_id,
random_nat_gateway_id,
random_key_pair,
random_private_ip,
@ -140,6 +143,8 @@ AMIS = json.load(
__name__, 'resources/amis.json'), 'r')
)
OWNER_ID = "111122223333"
def utc_date_and_time():
return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z')
@ -199,7 +204,7 @@ class TaggedEC2Resource(BaseModel):
class NetworkInterface(TaggedEC2Resource):
def __init__(self, ec2_backend, subnet, private_ip_address, device_index=0,
public_ip_auto_assign=True, group_ids=None):
public_ip_auto_assign=True, group_ids=None, description=None):
self.ec2_backend = ec2_backend
self.id = random_eni_id()
self.device_index = device_index
@ -207,6 +212,7 @@ class NetworkInterface(TaggedEC2Resource):
self.subnet = subnet
self.instance = None
self.attachment_id = None
self.description = description
self.public_ip = None
self.public_ip_auto_assign = public_ip_auto_assign
@ -244,11 +250,13 @@ class NetworkInterface(TaggedEC2Resource):
subnet = None
private_ip_address = properties.get('PrivateIpAddress', None)
description = properties.get('Description', None)
network_interface = ec2_backend.create_network_interface(
subnet,
private_ip_address,
group_ids=security_group_ids
group_ids=security_group_ids,
description=description
)
return network_interface
@ -296,6 +304,8 @@ class NetworkInterface(TaggedEC2Resource):
return [group.id for group in self._group_set]
elif filter_name == 'availability-zone':
return self.subnet.availability_zone
elif filter_name == 'description':
return self.description
else:
return super(NetworkInterface, self).get_filter_value(
filter_name, 'DescribeNetworkInterfaces')
@ -306,9 +316,9 @@ class NetworkInterfaceBackend(object):
self.enis = {}
super(NetworkInterfaceBackend, self).__init__()
def create_network_interface(self, subnet, private_ip_address, group_ids=None, **kwargs):
def create_network_interface(self, subnet, private_ip_address, group_ids=None, description=None, **kwargs):
eni = NetworkInterface(
self, subnet, private_ip_address, group_ids=group_ids, **kwargs)
self, subnet, private_ip_address, group_ids=group_ids, description=description, **kwargs)
self.enis[eni.id] = eni
return eni
@ -341,6 +351,12 @@ class NetworkInterfaceBackend(object):
if group.id in _filter_value:
enis.append(eni)
break
elif _filter == 'private-ip-address:':
enis = [eni for eni in enis if eni.private_ip_address in _filter_value]
elif _filter == 'subnet-id':
enis = [eni for eni in enis if eni.subnet.id in _filter_value]
elif _filter == 'description':
enis = [eni for eni in enis if eni.description in _filter_value]
else:
self.raise_not_implemented_error(
"The filter '{0}' for DescribeNetworkInterfaces".format(_filter))
@ -382,6 +398,10 @@ class NetworkInterfaceBackend(object):
class Instance(TaggedEC2Resource, BotoInstance):
VALID_ATTRIBUTES = {'instanceType', 'kernel', 'ramdisk', 'userData', 'disableApiTermination',
'instanceInitiatedShutdownBehavior', 'rootDeviceName', 'blockDeviceMapping',
'productCodes', 'sourceDestCheck', 'groupSet', 'ebsOptimized', 'sriovNetSupport'}
def __init__(self, ec2_backend, image_id, user_data, security_groups, **kwargs):
super(Instance, self).__init__()
self.ec2_backend = ec2_backend
@ -404,11 +424,13 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.launch_time = utc_date_and_time()
self.ami_launch_index = kwargs.get("ami_launch_index", 0)
self.disable_api_termination = kwargs.get("disable_api_termination", False)
self.instance_initiated_shutdown_behavior = kwargs.get("instance_initiated_shutdown_behavior", "stop")
self.sriov_net_support = "simple"
self._spot_fleet_id = kwargs.get("spot_fleet_id", None)
associate_public_ip = kwargs.get("associate_public_ip", False)
self.associate_public_ip = kwargs.get("associate_public_ip", False)
if in_ec2_classic:
# If we are in EC2-Classic, autoassign a public IP
associate_public_ip = True
self.associate_public_ip = True
amis = self.ec2_backend.describe_images(filters={'image-id': image_id})
ami = amis[0] if amis else None
@ -439,9 +461,9 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.vpc_id = subnet.vpc_id
self._placement.zone = subnet.availability_zone
if associate_public_ip is None:
if self.associate_public_ip is None:
# Mapping public ip hasnt been explicitly enabled or disabled
associate_public_ip = subnet.map_public_ip_on_launch == 'true'
self.associate_public_ip = subnet.map_public_ip_on_launch == 'true'
elif placement:
self._placement.zone = placement
else:
@ -453,7 +475,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.prep_nics(
kwargs.get("nics", {}),
private_ip=kwargs.get("private_ip"),
associate_public_ip=associate_public_ip
associate_public_ip=self.associate_public_ip
)
def __del__(self):
@ -787,14 +809,22 @@ class InstanceBackend(object):
setattr(instance, key, value)
return instance
def modify_instance_security_groups(self, instance_id, new_group_list):
def modify_instance_security_groups(self, instance_id, new_group_id_list):
instance = self.get_instance(instance_id)
new_group_list = []
for new_group_id in new_group_id_list:
new_group_list.append(self.get_security_group_from_id(new_group_id))
setattr(instance, 'security_groups', new_group_list)
return instance
def describe_instance_attribute(self, instance_id, key):
if key == 'group_set':
def describe_instance_attribute(self, instance_id, attribute):
if attribute not in Instance.VALID_ATTRIBUTES:
raise InvalidParameterValueErrorUnknownAttribute(attribute)
if attribute == 'groupSet':
key = 'security_groups'
else:
key = camelcase_to_underscores(attribute)
instance = self.get_instance(instance_id)
value = getattr(instance, key)
return instance, value
@ -1060,7 +1090,7 @@ class TagBackend(object):
class Ami(TaggedEC2Resource):
def __init__(self, ec2_backend, ami_id, instance=None, source_ami=None,
name=None, description=None, owner_id=111122223333,
name=None, description=None, owner_id=OWNER_ID,
public=False, virtualization_type=None, architecture=None,
state='available', creation_date=None, platform=None,
image_type='machine', image_location=None, hypervisor=None,
@ -1173,7 +1203,7 @@ class AmiBackend(object):
ami = Ami(self, ami_id, instance=instance, source_ami=None,
name=name, description=description,
owner_id=context.get_current_user() if context else '111122223333')
owner_id=context.get_current_user() if context else OWNER_ID)
self.amis[ami_id] = ami
return ami
@ -1288,17 +1318,107 @@ class Region(object):
class Zone(object):
def __init__(self, name, region_name):
def __init__(self, name, region_name, zone_id):
self.name = name
self.region_name = region_name
self.zone_id = zone_id
class RegionsAndZonesBackend(object):
regions = [Region(ri.name, ri.endpoint) for ri in boto.ec2.regions()]
zones = dict(
(region, [Zone(region + c, region) for c in 'abc'])
for region in [r.name for r in regions])
zones = {
'ap-south-1': [
Zone(region_name="ap-south-1", name="ap-south-1a", zone_id="aps1-az1"),
Zone(region_name="ap-south-1", name="ap-south-1b", zone_id="aps1-az3")
],
'eu-west-3': [
Zone(region_name="eu-west-3", name="eu-west-3a", zone_id="euw3-az1"),
Zone(region_name="eu-west-3", name="eu-west-3b", zone_id="euw3-az2"),
Zone(region_name="eu-west-3", name="eu-west-3c", zone_id="euw3-az3")
],
'eu-north-1': [
Zone(region_name="eu-north-1", name="eu-north-1a", zone_id="eun1-az1"),
Zone(region_name="eu-north-1", name="eu-north-1b", zone_id="eun1-az2"),
Zone(region_name="eu-north-1", name="eu-north-1c", zone_id="eun1-az3")
],
'eu-west-2': [
Zone(region_name="eu-west-2", name="eu-west-2a", zone_id="euw2-az2"),
Zone(region_name="eu-west-2", name="eu-west-2b", zone_id="euw2-az3"),
Zone(region_name="eu-west-2", name="eu-west-2c", zone_id="euw2-az1")
],
'eu-west-1': [
Zone(region_name="eu-west-1", name="eu-west-1a", zone_id="euw1-az3"),
Zone(region_name="eu-west-1", name="eu-west-1b", zone_id="euw1-az1"),
Zone(region_name="eu-west-1", name="eu-west-1c", zone_id="euw1-az2")
],
'ap-northeast-3': [
Zone(region_name="ap-northeast-3", name="ap-northeast-2a", zone_id="apne3-az1")
],
'ap-northeast-2': [
Zone(region_name="ap-northeast-2", name="ap-northeast-2a", zone_id="apne2-az1"),
Zone(region_name="ap-northeast-2", name="ap-northeast-2c", zone_id="apne2-az3")
],
'ap-northeast-1': [
Zone(region_name="ap-northeast-1", name="ap-northeast-1a", zone_id="apne1-az4"),
Zone(region_name="ap-northeast-1", name="ap-northeast-1c", zone_id="apne1-az1"),
Zone(region_name="ap-northeast-1", name="ap-northeast-1d", zone_id="apne1-az2")
],
'sa-east-1': [
Zone(region_name="sa-east-1", name="sa-east-1a", zone_id="sae1-az1"),
Zone(region_name="sa-east-1", name="sa-east-1c", zone_id="sae1-az3")
],
'ca-central-1': [
Zone(region_name="ca-central-1", name="ca-central-1a", zone_id="cac1-az1"),
Zone(region_name="ca-central-1", name="ca-central-1b", zone_id="cac1-az2")
],
'ap-southeast-1': [
Zone(region_name="ap-southeast-1", name="ap-southeast-1a", zone_id="apse1-az1"),
Zone(region_name="ap-southeast-1", name="ap-southeast-1b", zone_id="apse1-az2"),
Zone(region_name="ap-southeast-1", name="ap-southeast-1c", zone_id="apse1-az3")
],
'ap-southeast-2': [
Zone(region_name="ap-southeast-2", name="ap-southeast-2a", zone_id="apse2-az1"),
Zone(region_name="ap-southeast-2", name="ap-southeast-2b", zone_id="apse2-az3"),
Zone(region_name="ap-southeast-2", name="ap-southeast-2c", zone_id="apse2-az2")
],
'eu-central-1': [
Zone(region_name="eu-central-1", name="eu-central-1a", zone_id="euc1-az2"),
Zone(region_name="eu-central-1", name="eu-central-1b", zone_id="euc1-az3"),
Zone(region_name="eu-central-1", name="eu-central-1c", zone_id="euc1-az1")
],
'us-east-1': [
Zone(region_name="us-east-1", name="us-east-1a", zone_id="use1-az6"),
Zone(region_name="us-east-1", name="us-east-1b", zone_id="use1-az1"),
Zone(region_name="us-east-1", name="us-east-1c", zone_id="use1-az2"),
Zone(region_name="us-east-1", name="us-east-1d", zone_id="use1-az4"),
Zone(region_name="us-east-1", name="us-east-1e", zone_id="use1-az3"),
Zone(region_name="us-east-1", name="us-east-1f", zone_id="use1-az5")
],
'us-east-2': [
Zone(region_name="us-east-2", name="us-east-2a", zone_id="use2-az1"),
Zone(region_name="us-east-2", name="us-east-2b", zone_id="use2-az2"),
Zone(region_name="us-east-2", name="us-east-2c", zone_id="use2-az3")
],
'us-west-1': [
Zone(region_name="us-west-1", name="us-west-1a", zone_id="usw1-az3"),
Zone(region_name="us-west-1", name="us-west-1b", zone_id="usw1-az1")
],
'us-west-2': [
Zone(region_name="us-west-2", name="us-west-2a", zone_id="usw2-az2"),
Zone(region_name="us-west-2", name="us-west-2b", zone_id="usw2-az1"),
Zone(region_name="us-west-2", name="us-west-2c", zone_id="usw2-az3")
],
'cn-north-1': [
Zone(region_name="cn-north-1", name="cn-north-1a", zone_id="cnn1-az1"),
Zone(region_name="cn-north-1", name="cn-north-1b", zone_id="cnn1-az2")
],
'us-gov-west-1': [
Zone(region_name="us-gov-west-1", name="us-gov-west-1a", zone_id="usgw1-az1"),
Zone(region_name="us-gov-west-1", name="us-gov-west-1b", zone_id="usgw1-az2"),
Zone(region_name="us-gov-west-1", name="us-gov-west-1c", zone_id="usgw1-az3")
]
}
def describe_regions(self, region_names=[]):
if len(region_names) == 0:
@ -1351,7 +1471,7 @@ class SecurityGroup(TaggedEC2Resource):
self.egress_rules = [SecurityRule(-1, None, None, ['0.0.0.0/0'], [])]
self.enis = {}
self.vpc_id = vpc_id
self.owner_id = "123456789012"
self.owner_id = OWNER_ID
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -1872,7 +1992,7 @@ class Volume(TaggedEC2Resource):
class Snapshot(TaggedEC2Resource):
def __init__(self, ec2_backend, snapshot_id, volume, description, encrypted=False, owner_id='123456789012'):
def __init__(self, ec2_backend, snapshot_id, volume, description, encrypted=False, owner_id=OWNER_ID):
self.id = snapshot_id
self.volume = volume
self.description = description
@ -2374,7 +2494,7 @@ class VPCPeeringConnectionBackend(object):
class Subnet(TaggedEC2Resource):
def __init__(self, ec2_backend, subnet_id, vpc_id, cidr_block, availability_zone, default_for_az,
map_public_ip_on_launch):
map_public_ip_on_launch, owner_id=OWNER_ID, assign_ipv6_address_on_creation=False):
self.ec2_backend = ec2_backend
self.id = subnet_id
self.vpc_id = vpc_id
@ -2383,6 +2503,9 @@ class Subnet(TaggedEC2Resource):
self._availability_zone = availability_zone
self.default_for_az = default_for_az
self.map_public_ip_on_launch = map_public_ip_on_launch
self.owner_id = owner_id
self.assign_ipv6_address_on_creation = assign_ipv6_address_on_creation
self.ipv6_cidr_block_associations = []
# Theory is we assign ip's as we go (as 16,777,214 usable IPs in a /8)
self._subnet_ip_generator = self.cidr.hosts()
@ -2412,7 +2535,7 @@ class Subnet(TaggedEC2Resource):
@property
def availability_zone(self):
return self._availability_zone
return self._availability_zone.name
@property
def physical_resource_id(self):
@ -2509,7 +2632,7 @@ class SubnetBackend(object):
return subnets[subnet_id]
raise InvalidSubnetIdError(subnet_id)
def create_subnet(self, vpc_id, cidr_block, availability_zone):
def create_subnet(self, vpc_id, cidr_block, availability_zone, context=None):
subnet_id = random_subnet_id()
vpc = self.get_vpc(vpc_id) # Validate VPC exists and the supplied CIDR block is a subnet of the VPC's
vpc_cidr_block = ipaddress.IPv4Network(six.text_type(vpc.cidr_block), strict=False)
@ -2529,8 +2652,15 @@ class SubnetBackend(object):
# consider it the default
default_for_az = str(availability_zone not in self.subnets).lower()
map_public_ip_on_launch = default_for_az
subnet = Subnet(self, subnet_id, vpc_id, cidr_block, availability_zone,
default_for_az, map_public_ip_on_launch)
if availability_zone is None:
availability_zone = 'us-east-1a'
try:
availability_zone_data = next(zone for zones in RegionsAndZonesBackend.zones.values() for zone in zones if zone.name == availability_zone)
except StopIteration:
raise InvalidAvailabilityZoneError(availability_zone, ", ".join([zone.name for zones in RegionsAndZonesBackend.zones.values() for zone in zones]))
subnet = Subnet(self, subnet_id, vpc_id, cidr_block, availability_zone_data,
default_for_az, map_public_ip_on_launch,
owner_id=context.get_current_user() if context else OWNER_ID, assign_ipv6_address_on_creation=False)
# AWS associates a new subnet with the default Network ACL
self.associate_default_network_acl_with_subnet(subnet_id, vpc_id)
@ -2558,11 +2688,12 @@ class SubnetBackend(object):
return subnets.pop(subnet_id, None)
raise InvalidSubnetIdError(subnet_id)
def modify_subnet_attribute(self, subnet_id, map_public_ip):
def modify_subnet_attribute(self, subnet_id, attr_name, attr_value):
subnet = self.get_subnet(subnet_id)
if map_public_ip not in ('true', 'false'):
raise InvalidParameterValueError(map_public_ip)
subnet.map_public_ip_on_launch = map_public_ip
if attr_name in ('map_public_ip_on_launch', 'assign_ipv6_address_on_creation'):
setattr(subnet, attr_name, attr_value)
else:
raise InvalidParameterValueError(attr_name)
class SubnetRouteTableAssociation(object):
@ -3983,6 +4114,92 @@ class NatGatewayBackend(object):
return self.nat_gateways.pop(nat_gateway_id)
class LaunchTemplateVersion(object):
def __init__(self, template, number, data, description):
self.template = template
self.number = number
self.data = data
self.description = description
self.create_time = utc_date_and_time()
class LaunchTemplate(TaggedEC2Resource):
def __init__(self, backend, name, template_data, version_description):
self.ec2_backend = backend
self.name = name
self.id = random_launch_template_id()
self.create_time = utc_date_and_time()
self.versions = []
self.create_version(template_data, version_description)
self.default_version_number = 1
def create_version(self, data, description):
num = len(self.versions) + 1
version = LaunchTemplateVersion(self, num, data, description)
self.versions.append(version)
return version
def is_default(self, version):
return self.default_version == version.number
def get_version(self, num):
return self.versions[num - 1]
def default_version(self):
return self.versions[self.default_version_number - 1]
def latest_version(self):
return self.versions[-1]
@property
def latest_version_number(self):
return self.latest_version().number
def get_filter_value(self, filter_name):
if filter_name == 'launch-template-name':
return self.name
else:
return super(LaunchTemplate, self).get_filter_value(
filter_name, "DescribeLaunchTemplates")
class LaunchTemplateBackend(object):
def __init__(self):
self.launch_template_name_to_ids = {}
self.launch_templates = OrderedDict()
self.launch_template_insert_order = []
super(LaunchTemplateBackend, self).__init__()
def create_launch_template(self, name, description, template_data):
if name in self.launch_template_name_to_ids:
raise InvalidLaunchTemplateNameError()
template = LaunchTemplate(self, name, template_data, description)
self.launch_templates[template.id] = template
self.launch_template_name_to_ids[template.name] = template.id
self.launch_template_insert_order.append(template.id)
return template
def get_launch_template(self, template_id):
return self.launch_templates[template_id]
def get_launch_template_by_name(self, name):
return self.get_launch_template(self.launch_template_name_to_ids[name])
def get_launch_templates(self, template_names=None, template_ids=None, filters=None):
if template_names and not template_ids:
template_ids = []
for name in template_names:
template_ids.append(self.launch_template_name_to_ids[name])
if template_ids:
templates = [self.launch_templates[tid] for tid in template_ids]
else:
templates = list(self.launch_templates.values())
return generic_filter(filters, templates)
class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend,
RegionsAndZonesBackend, SecurityGroupBackend, AmiBackend,
VPCBackend, SubnetBackend, SubnetRouteTableAssociationBackend,
@ -3992,7 +4209,7 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend,
VPCGatewayAttachmentBackend, SpotFleetBackend,
SpotRequestBackend, ElasticAddressBackend, KeyPairBackend,
DHCPOptionsSetBackend, NetworkAclBackend, VpnGatewayBackend,
CustomerGatewayBackend, NatGatewayBackend):
CustomerGatewayBackend, NatGatewayBackend, LaunchTemplateBackend):
def __init__(self, region_name):
self.region_name = region_name
super(EC2Backend, self).__init__()
@ -4047,6 +4264,8 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend,
elif resource_prefix == EC2_RESOURCE_TO_PREFIX['internet-gateway']:
self.describe_internet_gateways(
internet_gateway_ids=[resource_id])
elif resource_prefix == EC2_RESOURCE_TO_PREFIX['launch-template']:
self.get_launch_template(resource_id)
elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-acl']:
self.get_all_network_acls()
elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-interface']:

View File

@ -14,6 +14,7 @@ from .instances import InstanceResponse
from .internet_gateways import InternetGateways
from .ip_addresses import IPAddresses
from .key_pairs import KeyPairs
from .launch_templates import LaunchTemplates
from .monitoring import Monitoring
from .network_acls import NetworkACLs
from .placement_groups import PlacementGroups
@ -49,6 +50,7 @@ class EC2Response(
InternetGateways,
IPAddresses,
KeyPairs,
LaunchTemplates,
Monitoring,
NetworkACLs,
PlacementGroups,

View File

@ -10,9 +10,10 @@ class ElasticNetworkInterfaces(BaseResponse):
private_ip_address = self._get_param('PrivateIpAddress')
groups = self._get_multi_param('SecurityGroupId')
subnet = self.ec2_backend.get_subnet(subnet_id)
description = self._get_param('Description')
if self.is_not_dryrun('CreateNetworkInterface'):
eni = self.ec2_backend.create_network_interface(
subnet, private_ip_address, groups)
subnet, private_ip_address, groups, description)
template = self.response_template(
CREATE_NETWORK_INTERFACE_RESPONSE)
return template.render(eni=eni)
@ -78,7 +79,11 @@ CREATE_NETWORK_INTERFACE_RESPONSE = """
<subnetId>{{ eni.subnet.id }}</subnetId>
<vpcId>{{ eni.subnet.vpc_id }}</vpcId>
<availabilityZone>us-west-2a</availabilityZone>
{% if eni.description %}
<description>{{ eni.description }}</description>
{% else %}
<description/>
{% endif %}
<ownerId>498654062920</ownerId>
<requesterManaged>false</requesterManaged>
<status>pending</status>
@ -121,7 +126,7 @@ DESCRIBE_NETWORK_INTERFACES_RESPONSE = """<DescribeNetworkInterfacesResponse xml
<subnetId>{{ eni.subnet.id }}</subnetId>
<vpcId>{{ eni.subnet.vpc_id }}</vpcId>
<availabilityZone>us-west-2a</availabilityZone>
<description>Primary network interface</description>
<description>{{ eni.description }}</description>
<ownerId>190610284047</ownerId>
<requesterManaged>false</requesterManaged>
{% if eni.attachment_id %}

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals
from boto.ec2.instancetype import InstanceType
from moto.autoscaling import autoscaling_backends
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring, \
@ -46,6 +48,7 @@ class InstanceResponse(BaseResponse):
associate_public_ip = self._get_param('AssociatePublicIpAddress')
key_name = self._get_param('KeyName')
ebs_optimized = self._get_param('EbsOptimized')
instance_initiated_shutdown_behavior = self._get_param("InstanceInitiatedShutdownBehavior")
tags = self._parse_tag_specification("TagSpecification")
region_name = self.region
@ -55,7 +58,7 @@ class InstanceResponse(BaseResponse):
instance_type=instance_type, placement=placement, region_name=region_name, subnet_id=subnet_id,
owner_id=owner_id, key_name=key_name, security_group_ids=security_group_ids,
nics=nics, private_ip=private_ip, associate_public_ip=associate_public_ip,
tags=tags, ebs_optimized=ebs_optimized)
tags=tags, ebs_optimized=ebs_optimized, instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior)
template = self.response_template(EC2_RUN_INSTANCES)
return template.render(reservation=new_reservation)
@ -64,6 +67,7 @@ class InstanceResponse(BaseResponse):
instance_ids = self._get_multi_param('InstanceId')
if self.is_not_dryrun('TerminateInstance'):
instances = self.ec2_backend.terminate_instances(instance_ids)
autoscaling_backends[self.region].notify_terminate_instances(instance_ids)
template = self.response_template(EC2_TERMINATE_INSTANCES)
return template.render(instances=instances)
@ -113,12 +117,11 @@ class InstanceResponse(BaseResponse):
# TODO this and modify below should raise IncorrectInstanceState if
# instance not in stopped state
attribute = self._get_param('Attribute')
key = camelcase_to_underscores(attribute)
instance_id = self._get_param('InstanceId')
instance, value = self.ec2_backend.describe_instance_attribute(
instance_id, key)
instance_id, attribute)
if key == "group_set":
if attribute == "groupSet":
template = self.response_template(
EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE)
else:
@ -597,7 +600,9 @@ EC2_DESCRIBE_INSTANCE_ATTRIBUTE = """<DescribeInstanceAttributeResponse xmlns="h
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<instanceId>{{ instance.id }}</instanceId>
<{{ attribute }}>
{% if value is not none %}
<value>{{ value }}</value>
{% endif %}
</{{ attribute }}>
</DescribeInstanceAttributeResponse>"""
@ -605,9 +610,9 @@ EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE = """<DescribeInstanceAttributeResponse
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<instanceId>{{ instance.id }}</instanceId>
<{{ attribute }}>
{% for sg_id in value %}
{% for sg in value %}
<item>
<groupId>{{ sg_id }}</groupId>
<groupId>{{ sg.id }}</groupId>
</item>
{% endfor %}
</{{ attribute }}>

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
import random
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring
@ -16,6 +17,7 @@ class Subnets(BaseResponse):
vpc_id,
cidr_block,
availability_zone,
context=self,
)
template = self.response_template(CREATE_SUBNET_RESPONSE)
return template.render(subnet=subnet)
@ -35,9 +37,14 @@ class Subnets(BaseResponse):
def modify_subnet_attribute(self):
subnet_id = self._get_param('SubnetId')
map_public_ip = self._get_param('MapPublicIpOnLaunch.Value')
self.ec2_backend.modify_subnet_attribute(subnet_id, map_public_ip)
return MODIFY_SUBNET_ATTRIBUTE_RESPONSE
for attribute in ('MapPublicIpOnLaunch', 'AssignIpv6AddressOnCreation'):
if self.querystring.get('%s.Value' % attribute):
attr_name = camelcase_to_underscores(attribute)
attr_value = self.querystring.get('%s.Value' % attribute)[0]
self.ec2_backend.modify_subnet_attribute(
subnet_id, attr_name, attr_value)
return MODIFY_SUBNET_ATTRIBUTE_RESPONSE
CREATE_SUBNET_RESPONSE = """
@ -49,17 +56,14 @@ CREATE_SUBNET_RESPONSE = """
<vpcId>{{ subnet.vpc_id }}</vpcId>
<cidrBlock>{{ subnet.cidr_block }}</cidrBlock>
<availableIpAddressCount>251</availableIpAddressCount>
<availabilityZone>{{ subnet.availability_zone }}</availabilityZone>
<tagSet>
{% for tag in subnet.get_tags() %}
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone>
<availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId>
<defaultForAz>{{ subnet.default_for_az }}</defaultForAz>
<mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch>
<ownerId>{{ subnet.owner_id }}</ownerId>
<assignIpv6AddressOnCreation>{{ subnet.assign_ipv6_address_on_creation }}</assignIpv6AddressOnCreation>
<ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet>
<subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn>
</subnet>
</CreateSubnetResponse>"""
@ -80,19 +84,26 @@ DESCRIBE_SUBNETS_RESPONSE = """
<vpcId>{{ subnet.vpc_id }}</vpcId>
<cidrBlock>{{ subnet.cidr_block }}</cidrBlock>
<availableIpAddressCount>251</availableIpAddressCount>
<availabilityZone>{{ subnet.availability_zone }}</availabilityZone>
<availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone>
<availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId>
<defaultForAz>{{ subnet.default_for_az }}</defaultForAz>
<mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch>
<tagSet>
{% for tag in subnet.get_tags() %}
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<ownerId>{{ subnet.owner_id }}</ownerId>
<assignIpv6AddressOnCreation>{{ subnet.assign_ipv6_address_on_creation }}</assignIpv6AddressOnCreation>
<ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet>
<subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn>
{% if subnet.get_tags() %}
<tagSet>
{% for tag in subnet.get_tags() %}
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
{% endif %}
</item>
{% endfor %}
</subnetSet>

View File

@ -20,6 +20,7 @@ EC2_RESOURCE_TO_PREFIX = {
'image': 'ami',
'instance': 'i',
'internet-gateway': 'igw',
'launch-template': 'lt',
'nat-gateway': 'nat',
'network-acl': 'acl',
'network-acl-subnet-assoc': 'aclassoc',
@ -161,6 +162,10 @@ def random_nat_gateway_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX['nat-gateway'], size=17)
def random_launch_template_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX['launch-template'], size=17)
def random_public_ip():
return '54.214.{0}.{1}'.format(random.choice(range(255)),
random.choice(range(255)))

View File

@ -1,7 +1,9 @@
from __future__ import unicode_literals
import hashlib
import re
from copy import copy
from datetime import datetime
from random import random
from botocore.exceptions import ParamValidationError
@ -105,7 +107,7 @@ class Image(BaseObject):
self.repository = repository
self.registry_id = registry_id
self.image_digest = digest
self.image_pushed_at = None
self.image_pushed_at = str(datetime.utcnow().isoformat())
def _create_digest(self):
image_contents = 'docker_image{0}'.format(int(random() * 10 ** 6))
@ -119,6 +121,12 @@ class Image(BaseObject):
def get_image_manifest(self):
return self.image_manifest
def remove_tag(self, tag):
if tag is not None and tag in self.image_tags:
self.image_tags.remove(tag)
if self.image_tags:
self.image_tag = self.image_tags[-1]
def update_tag(self, tag):
self.image_tag = tag
if tag not in self.image_tags and tag is not None:
@ -151,7 +159,7 @@ class Image(BaseObject):
response_object['repositoryName'] = self.repository
response_object['registryId'] = self.registry_id
response_object['imageSizeInBytes'] = self.image_size_in_bytes
response_object['imagePushedAt'] = '2017-05-09'
response_object['imagePushedAt'] = self.image_pushed_at
return {k: v for k, v in response_object.items() if v is not None and v != []}
@property
@ -165,6 +173,13 @@ class Image(BaseObject):
response_object['registryId'] = self.registry_id
return {k: v for k, v in response_object.items() if v is not None and v != [None]}
@property
def response_batch_delete_image(self):
response_object = {}
response_object['imageDigest'] = self.get_image_digest()
response_object['imageTag'] = self.image_tag
return {k: v for k, v in response_object.items() if v is not None and v != [None]}
class ECRBackend(BaseBackend):
@ -310,6 +325,106 @@ class ECRBackend(BaseBackend):
return response
def batch_delete_image(self, repository_name, registry_id=None, image_ids=None):
if repository_name in self.repositories:
repository = self.repositories[repository_name]
else:
raise RepositoryNotFoundException(
repository_name, registry_id or DEFAULT_REGISTRY_ID
)
if not image_ids:
raise ParamValidationError(
msg='Missing required parameter in input: "imageIds"'
)
response = {
"imageIds": [],
"failures": []
}
for image_id in image_ids:
image_found = False
# Is request missing both digest and tag?
if "imageDigest" not in image_id and "imageTag" not in image_id:
response["failures"].append(
{
"imageId": {},
"failureCode": "MissingDigestAndTag",
"failureReason": "Invalid request parameters: both tag and digest cannot be null",
}
)
continue
# If we have a digest, is it valid?
if "imageDigest" in image_id:
pattern = re.compile("^[0-9a-zA-Z_+\.-]+:[0-9a-fA-F]{64}")
if not pattern.match(image_id.get("imageDigest")):
response["failures"].append(
{
"imageId": {
"imageDigest": image_id.get("imageDigest", "null")
},
"failureCode": "InvalidImageDigest",
"failureReason": "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'",
}
)
continue
for num, image in enumerate(repository.images):
# Search by matching both digest and tag
if "imageDigest" in image_id and "imageTag" in image_id:
if (
image_id["imageDigest"] == image.get_image_digest() and
image_id["imageTag"] in image.image_tags
):
image_found = True
for image_tag in reversed(image.image_tags):
repository.images[num].image_tag = image_tag
response["imageIds"].append(
image.response_batch_delete_image
)
repository.images[num].remove_tag(image_tag)
del repository.images[num]
# Search by matching digest
elif "imageDigest" in image_id and image.get_image_digest() == image_id["imageDigest"]:
image_found = True
for image_tag in reversed(image.image_tags):
repository.images[num].image_tag = image_tag
response["imageIds"].append(image.response_batch_delete_image)
repository.images[num].remove_tag(image_tag)
del repository.images[num]
# Search by matching tag
elif "imageTag" in image_id and image_id["imageTag"] in image.image_tags:
image_found = True
repository.images[num].image_tag = image_id["imageTag"]
response["imageIds"].append(image.response_batch_delete_image)
if len(image.image_tags) > 1:
repository.images[num].remove_tag(image_id["imageTag"])
else:
repository.images.remove(image)
if not image_found:
failure_response = {
"imageId": {},
"failureCode": "ImageNotFound",
"failureReason": "Requested image not found",
}
if "imageDigest" in image_id:
failure_response["imageId"]["imageDigest"] = image_id.get("imageDigest", "null")
if "imageTag" in image_id:
failure_response["imageId"]["imageTag"] = image_id.get("imageTag", "null")
response["failures"].append(failure_response)
return response
ecr_backends = {}
for region, ec2_backend in ec2_backends.items():

View File

@ -84,9 +84,12 @@ class ECRResponse(BaseResponse):
'ECR.batch_check_layer_availability is not yet implemented')
def batch_delete_image(self):
if self.is_not_dryrun('BatchDeleteImage'):
raise NotImplementedError(
'ECR.batch_delete_image is not yet implemented')
repository_str = self._get_param('repositoryName')
registry_id = self._get_param('registryId')
image_ids = self._get_param('imageIds')
response = self.ecr_backend.batch_delete_image(repository_str, registry_id, image_ids)
return json.dumps(response)
def batch_get_image(self):
repository_str = self._get_param('repositoryName')

View File

@ -1,5 +1,5 @@
from __future__ import unicode_literals
from moto.core.exceptions import RESTError
from moto.core.exceptions import RESTError, JsonRESTError
class ServiceNotFoundException(RESTError):
@ -11,3 +11,13 @@ class ServiceNotFoundException(RESTError):
message="The service {0} does not exist".format(service_name),
template='error_json',
)
class TaskDefinitionNotFoundException(JsonRESTError):
code = 400
def __init__(self):
super(TaskDefinitionNotFoundException, self).__init__(
error_type="ClientException",
message="The specified task definition does not exist.",
)

View File

@ -1,4 +1,5 @@
from __future__ import unicode_literals
import re
import uuid
from datetime import datetime
from random import random, randint
@ -7,10 +8,14 @@ import boto3
import pytz
from moto.core.exceptions import JsonRESTError
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time
from moto.ec2 import ec2_backends
from copy import copy
from .exceptions import ServiceNotFoundException
from .exceptions import (
ServiceNotFoundException,
TaskDefinitionNotFoundException
)
class BaseObject(BaseModel):
@ -103,12 +108,13 @@ class Cluster(BaseObject):
class TaskDefinition(BaseObject):
def __init__(self, family, revision, container_definitions, volumes=None):
def __init__(self, family, revision, container_definitions, volumes=None, tags=None):
self.family = family
self.revision = revision
self.arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/{0}:{1}'.format(
family, revision)
self.container_definitions = container_definitions
self.tags = tags if tags is not None else []
if volumes is None:
self.volumes = []
else:
@ -119,6 +125,7 @@ class TaskDefinition(BaseObject):
response_object = self.gen_response_object()
response_object['taskDefinitionArn'] = response_object['arn']
del response_object['arn']
del response_object['tags']
return response_object
@property
@ -225,9 +232,9 @@ class Service(BaseObject):
for deployment in response_object['deployments']:
if isinstance(deployment['createdAt'], datetime):
deployment['createdAt'] = deployment['createdAt'].isoformat()
deployment['createdAt'] = unix_time(deployment['createdAt'].replace(tzinfo=None))
if isinstance(deployment['updatedAt'], datetime):
deployment['updatedAt'] = deployment['updatedAt'].isoformat()
deployment['updatedAt'] = unix_time(deployment['updatedAt'].replace(tzinfo=None))
return response_object
@ -422,11 +429,9 @@ class EC2ContainerServiceBackend(BaseBackend):
revision = int(revision)
else:
family = task_definition_name
revision = len(self.task_definitions.get(family, []))
revision = self._get_last_task_definition_revision_id(family)
if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]):
return self.task_definitions[family][revision - 1]
elif family in self.task_definitions and revision == -1:
if family in self.task_definitions and revision in self.task_definitions[family]:
return self.task_definitions[family][revision]
else:
raise Exception(
@ -466,15 +471,16 @@ class EC2ContainerServiceBackend(BaseBackend):
else:
raise Exception("{0} is not a cluster".format(cluster_name))
def register_task_definition(self, family, container_definitions, volumes):
def register_task_definition(self, family, container_definitions, volumes, tags=None):
if family in self.task_definitions:
revision = len(self.task_definitions[family]) + 1
last_id = self._get_last_task_definition_revision_id(family)
revision = (last_id or 0) + 1
else:
self.task_definitions[family] = []
self.task_definitions[family] = {}
revision = 1
task_definition = TaskDefinition(
family, revision, container_definitions, volumes)
self.task_definitions[family].append(task_definition)
family, revision, container_definitions, volumes, tags)
self.task_definitions[family][revision] = task_definition
return task_definition
@ -484,16 +490,18 @@ class EC2ContainerServiceBackend(BaseBackend):
"""
task_arns = []
for task_definition_list in self.task_definitions.values():
task_arns.extend(
[task_definition.arn for task_definition in task_definition_list])
task_arns.extend([
task_definition.arn
for task_definition in task_definition_list.values()
])
return task_arns
def deregister_task_definition(self, task_definition_str):
task_definition_name = task_definition_str.split('/')[-1]
family, revision = task_definition_name.split(':')
revision = int(revision)
if family in self.task_definitions and 0 < revision <= len(self.task_definitions[family]):
return self.task_definitions[family].pop(revision - 1)
if family in self.task_definitions and revision in self.task_definitions[family]:
return self.task_definitions[family].pop(revision)
else:
raise Exception(
"{0} is not a task_definition".format(task_definition_name))
@ -950,6 +958,29 @@ class EC2ContainerServiceBackend(BaseBackend):
yield task_fam
def list_tags_for_resource(self, resource_arn):
"""Currently only implemented for task definitions"""
match = re.match(
"^arn:aws:ecs:(?P<region>[^:]+):(?P<account_id>[^:]+):(?P<service>[^:]+)/(?P<id>.*)$",
resource_arn)
if not match:
raise JsonRESTError('InvalidParameterException', 'The ARN provided is invalid.')
service = match.group("service")
if service == "task-definition":
for task_definition in self.task_definitions.values():
for revision in task_definition.values():
if revision.arn == resource_arn:
return revision.tags
else:
raise TaskDefinitionNotFoundException()
raise NotImplementedError()
def _get_last_task_definition_revision_id(self, family):
definitions = self.task_definitions.get(family, {})
if definitions:
return max(definitions.keys())
available_regions = boto3.session.Session().get_available_regions("ecs")
ecs_backends = {region: EC2ContainerServiceBackend(region) for region in available_regions}

View File

@ -62,8 +62,9 @@ class EC2ContainerServiceResponse(BaseResponse):
family = self._get_param('family')
container_definitions = self._get_param('containerDefinitions')
volumes = self._get_param('volumes')
tags = self._get_param('tags')
task_definition = self.ecs_backend.register_task_definition(
family, container_definitions, volumes)
family, container_definitions, volumes, tags)
return json.dumps({
'taskDefinition': task_definition.response_object
})
@ -313,3 +314,8 @@ class EC2ContainerServiceResponse(BaseResponse):
results = self.ecs_backend.list_task_definition_families(family_prefix, status, max_results, next_token)
return json.dumps({'families': list(results)})
def list_tags_for_resource(self):
resource_arn = self._get_param('resourceArn')
tags = self.ecs_backend.list_tags_for_resource(resource_arn)
return json.dumps({'tags': tags})

View File

@ -2,9 +2,11 @@ from __future__ import unicode_literals
import datetime
import re
from jinja2 import Template
from moto.compat import OrderedDict
from moto.core.exceptions import RESTError
from moto.core import BaseBackend, BaseModel
from moto.core.utils import camelcase_to_underscores
from moto.ec2.models import ec2_backends
from moto.acm.models import acm_backends
from .utils import make_arn_for_target_group
@ -35,12 +37,13 @@ from .exceptions import (
class FakeHealthStatus(BaseModel):
def __init__(self, instance_id, port, health_port, status, reason=None):
def __init__(self, instance_id, port, health_port, status, reason=None, description=None):
self.instance_id = instance_id
self.port = port
self.health_port = health_port
self.status = status
self.reason = reason
self.description = description
class FakeTargetGroup(BaseModel):
@ -69,7 +72,7 @@ class FakeTargetGroup(BaseModel):
self.protocol = protocol
self.port = port
self.healthcheck_protocol = healthcheck_protocol or 'HTTP'
self.healthcheck_port = healthcheck_port or 'traffic-port'
self.healthcheck_port = healthcheck_port or str(self.port)
self.healthcheck_path = healthcheck_path or '/'
self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30
self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5
@ -112,10 +115,14 @@ class FakeTargetGroup(BaseModel):
raise TooManyTagsError()
self.tags[key] = value
def health_for(self, target):
def health_for(self, target, ec2_backend):
t = self.targets.get(target['id'])
if t is None:
raise InvalidTargetError()
if t['id'].startswith("i-"): # EC2 instance ID
instance = ec2_backend.get_instance_by_id(t['id'])
if instance.state == "stopped":
return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'unused', 'Target.InvalidState', 'Target is in the stopped state')
return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'healthy')
@classmethod
@ -208,13 +215,12 @@ class FakeListener(BaseModel):
action_type = action['Type']
if action_type == 'forward':
default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']})
elif action_type == 'redirect':
redirect_action = {'type': action_type, }
for redirect_config_key, redirect_config_value in action['RedirectConfig'].items():
elif action_type in ['redirect', 'authenticate-cognito']:
redirect_action = {'type': action_type}
key = 'RedirectConfig' if action_type == 'redirect' else 'AuthenticateCognitoConfig'
for redirect_config_key, redirect_config_value in action[key].items():
# need to match the output of _get_list_prefix
if redirect_config_key == 'StatusCode':
redirect_config_key = 'status_code'
redirect_action['redirect_config._' + redirect_config_key.lower()] = redirect_config_value
redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value
default_actions.append(redirect_action)
else:
raise InvalidActionTypeError(action_type, i + 1)
@ -226,6 +232,32 @@ class FakeListener(BaseModel):
return listener
class FakeAction(BaseModel):
def __init__(self, data):
self.data = data
self.type = data.get("type")
def to_xml(self):
template = Template("""<Type>{{ action.type }}</Type>
{% if action.type == "forward" %}
<TargetGroupArn>{{ action.data["target_group_arn"] }}</TargetGroupArn>
{% elif action.type == "redirect" %}
<RedirectConfig>
<Protocol>{{ action.data["redirect_config._protocol"] }}</Protocol>
<Port>{{ action.data["redirect_config._port"] }}</Port>
<StatusCode>{{ action.data["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% elif action.type == "authenticate-cognito" %}
<AuthenticateCognitoConfig>
<UserPoolArn>{{ action.data["authenticate_cognito_config._user_pool_arn"] }}</UserPoolArn>
<UserPoolClientId>{{ action.data["authenticate_cognito_config._user_pool_client_id"] }}</UserPoolClientId>
<UserPoolDomain>{{ action.data["authenticate_cognito_config._user_pool_domain"] }}</UserPoolDomain>
</AuthenticateCognitoConfig>
{% endif %}
""")
return template.render(action=self)
class FakeRule(BaseModel):
def __init__(self, listener_arn, conditions, priority, actions, is_default):
@ -397,6 +429,7 @@ class ELBv2Backend(BaseBackend):
return new_load_balancer
def create_rule(self, listener_arn, conditions, priority, actions):
actions = [FakeAction(action) for action in actions]
listeners = self.describe_listeners(None, [listener_arn])
if not listeners:
raise ListenerNotFoundError()
@ -424,20 +457,7 @@ class ELBv2Backend(BaseBackend):
if rule.priority == priority:
raise PriorityInUseError()
# validate Actions
target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
for i, action in enumerate(actions):
index = i + 1
action_type = action['type']
if action_type == 'forward':
action_target_group_arn = action['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
elif action_type == 'redirect':
# nothing to do
pass
else:
raise InvalidActionTypeError(action_type, index)
self._validate_actions(actions)
# TODO: check for error 'TooManyRegistrationsForTargetId'
# TODO: check for error 'TooManyRules'
@ -447,6 +467,21 @@ class ELBv2Backend(BaseBackend):
listener.register(rule)
return [rule]
def _validate_actions(self, actions):
# validate Actions
target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
for i, action in enumerate(actions):
index = i + 1
action_type = action.type
if action_type == 'forward':
action_target_group_arn = action.data['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
elif action_type in ['redirect', 'authenticate-cognito']:
pass
else:
raise InvalidActionTypeError(action_type, index)
def create_target_group(self, name, **kwargs):
if len(name) > 32:
raise InvalidTargetGroupNameError(
@ -490,26 +525,22 @@ class ELBv2Backend(BaseBackend):
return target_group
def create_listener(self, load_balancer_arn, protocol, port, ssl_policy, certificate, default_actions):
default_actions = [FakeAction(action) for action in default_actions]
balancer = self.load_balancers.get(load_balancer_arn)
if balancer is None:
raise LoadBalancerNotFoundError()
if port in balancer.listeners:
raise DuplicateListenerError()
self._validate_actions(default_actions)
arn = load_balancer_arn.replace(':loadbalancer/', ':listener/') + "/%s%s" % (port, id(self))
listener = FakeListener(load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions)
balancer.listeners[listener.arn] = listener
for i, action in enumerate(default_actions):
action_type = action['type']
if action_type == 'forward':
if action['target_group_arn'] in self.target_groups.keys():
target_group = self.target_groups[action['target_group_arn']]
target_group.load_balancer_arns.append(load_balancer_arn)
elif action_type == 'redirect':
# nothing to do
pass
else:
raise InvalidActionTypeError(action_type, i + 1)
for action in default_actions:
if action.type == 'forward':
target_group = self.target_groups[action.data['target_group_arn']]
target_group.load_balancer_arns.append(load_balancer_arn)
return listener
@ -643,6 +674,7 @@ class ELBv2Backend(BaseBackend):
raise ListenerNotFoundError()
def modify_rule(self, rule_arn, conditions, actions):
actions = [FakeAction(action) for action in actions]
# if conditions or actions is empty list, do not update the attributes
if not conditions and not actions:
raise InvalidModifyRuleArgumentsError()
@ -668,20 +700,7 @@ class ELBv2Backend(BaseBackend):
# TODO: check pattern of value for 'path-pattern'
# validate Actions
target_group_arns = [target_group.arn for target_group in self.target_groups.values()]
if actions:
for i, action in enumerate(actions):
index = i + 1
action_type = action['type']
if action_type == 'forward':
action_target_group_arn = action['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
elif action_type == 'redirect':
# nothing to do
pass
else:
raise InvalidActionTypeError(action_type, index)
self._validate_actions(actions)
# TODO: check for error 'TooManyRegistrationsForTargetId'
# TODO: check for error 'TooManyRules'
@ -712,7 +731,7 @@ class ELBv2Backend(BaseBackend):
if not targets:
targets = target_group.targets.values()
return [target_group.health_for(target) for target in targets]
return [target_group.health_for(target, self.ec2_backend) for target in targets]
def set_rule_priorities(self, rule_priorities):
# validate
@ -846,6 +865,7 @@ class ELBv2Backend(BaseBackend):
return target_group
def modify_listener(self, arn, port=None, protocol=None, ssl_policy=None, certificates=None, default_actions=None):
default_actions = [FakeAction(action) for action in default_actions]
for load_balancer in self.load_balancers.values():
if arn in load_balancer.listeners:
break
@ -912,7 +932,7 @@ class ELBv2Backend(BaseBackend):
for listener in load_balancer.listeners.values():
for rule in listener.rules:
for action in rule.actions:
if action.get('target_group_arn') == target_group_arn:
if action.data.get('target_group_arn') == target_group_arn:
return True
return False

View File

@ -775,16 +775,7 @@ CREATE_LISTENER_TEMPLATE = """<CreateListenerResponse xmlns="http://elasticloadb
<DefaultActions>
{% for action in listener.default_actions %}
<member>
<Type>{{ action.type }}</Type>
{% if action["type"] == "forward" %}
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
{% elif action["type"] == "redirect" %}
<RedirectConfig>
<Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
<Port>{{ action["redirect_config._port"] }}</Port>
<StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% endif %}
{{ action.to_xml() }}
</member>
{% endfor %}
</DefaultActions>
@ -888,16 +879,7 @@ DESCRIBE_RULES_TEMPLATE = """<DescribeRulesResponse xmlns="http://elasticloadbal
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
{% if action["type"] == "forward" %}
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
{% elif action["type"] == "redirect" %}
<RedirectConfig>
<Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
<Port>{{ action["redirect_config._port"] }}</Port>
<StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% endif %}
{{ action.to_xml() }}
</member>
{% endfor %}
</Actions>
@ -989,16 +971,7 @@ DESCRIBE_LISTENERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http://el
<DefaultActions>
{% for action in listener.default_actions %}
<member>
<Type>{{ action.type }}</Type>
{% if action["type"] == "forward" %}
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>m
{% elif action["type"] == "redirect" %}
<RedirectConfig>
<Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
<Port>{{ action["redirect_config._port"] }}</Port>
<StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% endif %}
{{ action.to_xml() }}
</member>
{% endfor %}
</DefaultActions>
@ -1048,8 +1021,7 @@ MODIFY_RULE_TEMPLATE = """<ModifyRuleResponse xmlns="http://elasticloadbalancing
<Actions>
{% for action in rule.actions %}
<member>
<Type>{{ action["type"] }}</Type>
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
{{ action.to_xml() }}
</member>
{% endfor %}
</Actions>
@ -1208,6 +1180,12 @@ DESCRIBE_TARGET_HEALTH_TEMPLATE = """<DescribeTargetHealthResponse xmlns="http:/
<HealthCheckPort>{{ target_health.health_port }}</HealthCheckPort>
<TargetHealth>
<State>{{ target_health.status }}</State>
{% if target_health.reason %}
<Reason>{{ target_health.reason }}</Reason>
{% endif %}
{% if target_health.description %}
<Description>{{ target_health.description }}</Description>
{% endif %}
</TargetHealth>
<Target>
<Port>{{ target_health.port }}</Port>
@ -1426,16 +1404,7 @@ MODIFY_LISTENER_TEMPLATE = """<ModifyListenerResponse xmlns="http://elasticloadb
<DefaultActions>
{% for action in listener.default_actions %}
<member>
<Type>{{ action.type }}</Type>
{% if action["type"] == "forward" %}
<TargetGroupArn>{{ action["target_group_arn"] }}</TargetGroupArn>
{% elif action["type"] == "redirect" %}
<RedirectConfig>
<Protocol>{{ action["redirect_config._protocol"] }}</Protocol>
<Port>{{ action["redirect_config._port"] }}</Port>
<StatusCode>{{ action["redirect_config._status_code"] }}</StatusCode>
</RedirectConfig>
{% endif %}
{{ action.to_xml() }}
</member>
{% endfor %}
</DefaultActions>

View File

@ -138,6 +138,12 @@ class FakeTable(BaseModel):
raise PartitionAlreadyExistsException()
self.partitions[key] = partition
def delete_partition(self, values):
try:
del self.partitions[str(values)]
except KeyError:
raise PartitionNotFoundException()
class FakePartition(BaseModel):
def __init__(self, database_name, table_name, partiton_input):

View File

@ -4,6 +4,11 @@ import json
from moto.core.responses import BaseResponse
from .models import glue_backend
from .exceptions import (
PartitionAlreadyExistsException,
PartitionNotFoundException,
TableNotFoundException
)
class GlueResponse(BaseResponse):
@ -90,6 +95,28 @@ class GlueResponse(BaseResponse):
resp = self.glue_backend.delete_table(database_name, table_name)
return json.dumps(resp)
def batch_delete_table(self):
database_name = self.parameters.get('DatabaseName')
errors = []
for table_name in self.parameters.get('TablesToDelete'):
try:
self.glue_backend.delete_table(database_name, table_name)
except TableNotFoundException:
errors.append({
"TableName": table_name,
"ErrorDetail": {
"ErrorCode": "EntityNotFoundException",
"ErrorMessage": "Table not found"
}
})
out = {}
if errors:
out["Errors"] = errors
return json.dumps(out)
def get_partitions(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
@ -114,6 +141,23 @@ class GlueResponse(BaseResponse):
return json.dumps({'Partition': p.as_dict()})
def batch_get_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
partitions_to_get = self.parameters.get('PartitionsToGet')
table = self.glue_backend.get_table(database_name, table_name)
partitions = []
for values in partitions_to_get:
try:
p = table.get_partition(values=values["Values"])
partitions.append(p.as_dict())
except PartitionNotFoundException:
continue
return json.dumps({'Partitions': partitions})
def create_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
@ -124,6 +168,30 @@ class GlueResponse(BaseResponse):
return ""
def batch_create_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
table = self.glue_backend.get_table(database_name, table_name)
errors_output = []
for part_input in self.parameters.get('PartitionInputList'):
try:
table.create_partition(part_input)
except PartitionAlreadyExistsException:
errors_output.append({
'PartitionValues': part_input['Values'],
'ErrorDetail': {
'ErrorCode': 'AlreadyExistsException',
'ErrorMessage': 'Partition already exists.'
}
})
out = {}
if errors_output:
out["Errors"] = errors_output
return json.dumps(out)
def update_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
@ -134,3 +202,38 @@ class GlueResponse(BaseResponse):
table.update_partition(part_to_update, part_input)
return ""
def delete_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
part_to_delete = self.parameters.get('PartitionValues')
table = self.glue_backend.get_table(database_name, table_name)
table.delete_partition(part_to_delete)
return ""
def batch_delete_partition(self):
database_name = self.parameters.get('DatabaseName')
table_name = self.parameters.get('TableName')
table = self.glue_backend.get_table(database_name, table_name)
errors_output = []
for part_input in self.parameters.get('PartitionsToDelete'):
values = part_input.get('Values')
try:
table.delete_partition(values)
except PartitionNotFoundException:
errors_output.append({
'PartitionValues': values,
'ErrorDetail': {
'ErrorCode': 'EntityNotFoundException',
'ErrorMessage': 'Partition not found',
}
})
out = {}
if errors_output:
out['Errors'] = errors_output
return json.dumps(out)

File diff suppressed because it is too large Load Diff

View File

@ -26,6 +26,14 @@ class IAMReportNotPresentException(RESTError):
"ReportNotPresent", message)
class IAMLimitExceededException(RESTError):
code = 400
def __init__(self, message):
super(IAMLimitExceededException, self).__init__(
"LimitExceeded", message)
class MalformedCertificate(RESTError):
code = 400
@ -34,6 +42,14 @@ class MalformedCertificate(RESTError):
'MalformedCertificate', 'Certificate {cert} is malformed'.format(cert=cert))
class MalformedPolicyDocument(RESTError):
code = 400
def __init__(self, message=""):
super(MalformedPolicyDocument, self).__init__(
'MalformedPolicyDocument', message)
class DuplicateTags(RESTError):
code = 400

View File

@ -8,14 +8,14 @@ import re
from cryptography import x509
from cryptography.hazmat.backends import default_backend
import pytz
from moto.core.exceptions import RESTError
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from moto.core.utils import iso_8601_datetime_without_milliseconds, iso_8601_datetime_with_milliseconds
from moto.iam.policy_validation import IAMPolicyDocumentValidator
from .aws_managed_policies import aws_managed_policies_data
from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException, MalformedCertificate, \
DuplicateTags, TagKeyTooBig, InvalidTagCharacters, TooManyTags, TagValueTooBig
from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException, IAMLimitExceededException, \
MalformedCertificate, DuplicateTags, TagKeyTooBig, InvalidTagCharacters, TooManyTags, TagValueTooBig
from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id
ACCOUNT_ID = 123456789012
@ -28,11 +28,15 @@ class MFADevice(object):
serial_number,
authentication_code_1,
authentication_code_2):
self.enable_date = datetime.now(pytz.utc)
self.enable_date = datetime.utcnow()
self.serial_number = serial_number
self.authentication_code_1 = authentication_code_1
self.authentication_code_2 = authentication_code_2
@property
def enabled_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.enable_date)
class Policy(BaseModel):
is_attachable = False
@ -42,7 +46,9 @@ class Policy(BaseModel):
default_version_id=None,
description=None,
document=None,
path=None):
path=None,
create_date=None,
update_date=None):
self.name = name
self.attachment_count = 0
@ -56,10 +62,25 @@ class Policy(BaseModel):
else:
self.default_version_id = 'v1'
self.next_version_num = 2
self.versions = [PolicyVersion(self.arn, document, True)]
self.versions = [PolicyVersion(self.arn, document, True, self.default_version_id, update_date)]
self.create_datetime = datetime.now(pytz.utc)
self.update_datetime = datetime.now(pytz.utc)
self.create_date = create_date if create_date is not None else datetime.utcnow()
self.update_date = update_date if update_date is not None else datetime.utcnow()
def update_default_version(self, new_default_version_id):
for version in self.versions:
if version.version_id == self.default_version_id:
version.is_default = False
break
self.default_version_id = new_default_version_id
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
@property
def updated_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.update_date)
class SAMLProvider(BaseModel):
@ -77,13 +98,19 @@ class PolicyVersion(object):
def __init__(self,
policy_arn,
document,
is_default=False):
is_default=False,
version_id='v1',
create_date=None):
self.policy_arn = policy_arn
self.document = document or {}
self.is_default = is_default
self.version_id = 'v1'
self.version_id = version_id
self.create_datetime = datetime.now(pytz.utc)
self.create_date = create_date if create_date is not None else datetime.utcnow()
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
class ManagedPolicy(Policy):
@ -112,7 +139,9 @@ class AWSManagedPolicy(ManagedPolicy):
return cls(name,
default_version_id=data.get('DefaultVersionId'),
path=data.get('Path'),
document=data.get('Document'))
document=json.dumps(data.get('Document')),
create_date=datetime.strptime(data.get('CreateDate'), "%Y-%m-%dT%H:%M:%S+00:00"),
update_date=datetime.strptime(data.get('UpdateDate'), "%Y-%m-%dT%H:%M:%S+00:00"))
@property
def arn(self):
@ -132,18 +161,22 @@ class InlinePolicy(Policy):
class Role(BaseModel):
def __init__(self, role_id, name, assume_role_policy_document, path, permissions_boundary):
def __init__(self, role_id, name, assume_role_policy_document, path, permissions_boundary, description, tags):
self.id = role_id
self.name = name
self.assume_role_policy_document = assume_role_policy_document
self.path = path or '/'
self.policies = {}
self.managed_policies = {}
self.create_date = datetime.now(pytz.utc)
self.tags = {}
self.description = ""
self.create_date = datetime.utcnow()
self.tags = tags
self.description = description
self.permissions_boundary = permissions_boundary
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
@ -152,7 +185,9 @@ class Role(BaseModel):
role_name=resource_name,
assume_role_policy_document=properties['AssumeRolePolicyDocument'],
path=properties.get('Path', '/'),
permissions_boundary=properties.get('PermissionsBoundary', '')
permissions_boundary=properties.get('PermissionsBoundary', ''),
description=properties.get('Description', ''),
tags=properties.get('Tags', {})
)
policies = properties.get('Policies', [])
@ -198,7 +233,11 @@ class InstanceProfile(BaseModel):
self.name = name
self.path = path or '/'
self.roles = roles if roles else []
self.create_date = datetime.now(pytz.utc)
self.create_date = datetime.utcnow()
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -250,25 +289,31 @@ class SigningCertificate(BaseModel):
self.id = id
self.user_name = user_name
self.body = body
self.upload_date = datetime.strftime(datetime.utcnow(), "%Y-%m-%d-%H-%M-%S")
self.upload_date = datetime.utcnow()
self.status = 'Active'
@property
def uploaded_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.upload_date)
class AccessKey(BaseModel):
def __init__(self, user_name):
self.user_name = user_name
self.access_key_id = random_access_key()
self.secret_access_key = random_alphanumeric(32)
self.access_key_id = "AKIA" + random_access_key()
self.secret_access_key = random_alphanumeric(40)
self.status = 'Active'
self.create_date = datetime.strftime(
datetime.utcnow(),
"%Y-%m-%dT%H:%M:%SZ"
)
self.last_used = datetime.strftime(
datetime.utcnow(),
"%Y-%m-%dT%H:%M:%SZ"
)
self.create_date = datetime.utcnow()
self.last_used = datetime.utcnow()
@property
def created_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.create_date)
@property
def last_used_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.last_used)
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
@ -283,15 +328,16 @@ class Group(BaseModel):
self.name = name
self.id = random_resource_id()
self.path = path
self.created = datetime.strftime(
datetime.utcnow(),
"%Y-%m-%d-%H-%M-%S"
)
self.create_date = datetime.utcnow()
self.users = []
self.managed_policies = {}
self.policies = {}
@property
def created_iso_8601(self):
return iso_8601_datetime_with_milliseconds(self.create_date)
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'Arn':
@ -306,10 +352,6 @@ class Group(BaseModel):
else:
return "arn:aws:iam::{0}:group/{1}/{2}".format(ACCOUNT_ID, self.path, self.name)
@property
def create_date(self):
return self.created
def get_policy(self, policy_name):
try:
policy_json = self.policies[policy_name]
@ -335,7 +377,7 @@ class User(BaseModel):
self.name = name
self.id = random_resource_id()
self.path = path if path else "/"
self.created = datetime.utcnow()
self.create_date = datetime.utcnow()
self.mfa_devices = {}
self.policies = {}
self.managed_policies = {}
@ -350,7 +392,7 @@ class User(BaseModel):
@property
def created_iso_8601(self):
return iso_8601_datetime_without_milliseconds(self.created)
return iso_8601_datetime_with_milliseconds(self.create_date)
def get_policy(self, policy_name):
policy_json = None
@ -421,7 +463,7 @@ class User(BaseModel):
def to_csv(self):
date_format = '%Y-%m-%dT%H:%M:%S+00:00'
date_created = self.created
date_created = self.create_date
# aagrawal,arn:aws:iam::509284790694:user/aagrawal,2014-09-01T22:28:48+00:00,true,2014-11-12T23:36:49+00:00,2014-09-03T18:59:00+00:00,N/A,false,true,2014-09-01T22:28:48+00:00,false,N/A,false,N/A,false,N/A
if not self.password:
password_enabled = 'false'
@ -478,7 +520,7 @@ class IAMBackend(BaseBackend):
super(IAMBackend, self).__init__()
def _init_managed_policies(self):
return dict((p.name, p) for p in aws_managed_policies)
return dict((p.arn, p) for p in aws_managed_policies)
def attach_role_policy(self, policy_arn, role_name):
arns = dict((p.arn, p) for p in self.managed_policies.values())
@ -536,6 +578,9 @@ class IAMBackend(BaseBackend):
policy.detach_from(self.get_user(user_name))
def create_policy(self, description, path, policy_document, policy_name):
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_document)
iam_policy_document_validator.validate()
policy = ManagedPolicy(
policy_name,
description=description,
@ -592,12 +637,13 @@ class IAMBackend(BaseBackend):
return policies, marker
def create_role(self, role_name, assume_role_policy_document, path, permissions_boundary):
def create_role(self, role_name, assume_role_policy_document, path, permissions_boundary, description, tags):
role_id = random_resource_id()
if permissions_boundary and not self.policy_arn_regex.match(permissions_boundary):
raise RESTError('InvalidParameterValue', 'Value ({}) for parameter PermissionsBoundary is invalid.'.format(permissions_boundary))
role = Role(role_id, role_name, assume_role_policy_document, path, permissions_boundary)
clean_tags = self._tag_verification(tags)
role = Role(role_id, role_name, assume_role_policy_document, path, permissions_boundary, description, clean_tags)
self.roles[role_id] = role
return role
@ -628,6 +674,9 @@ class IAMBackend(BaseBackend):
def put_role_policy(self, role_name, policy_name, policy_json):
role = self.get_role(role_name)
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_json)
iam_policy_document_validator.validate()
role.put_policy(policy_name, policy_json)
def delete_role_policy(self, role_name, policy_name):
@ -639,15 +688,32 @@ class IAMBackend(BaseBackend):
for p, d in role.policies.items():
if p == policy_name:
return p, d
raise IAMNotFoundException("Policy Document {0} not attached to role {1}".format(policy_name, role_name))
def list_role_policies(self, role_name):
role = self.get_role(role_name)
return role.policies.keys()
def _tag_verification(self, tags):
if len(tags) > 50:
raise TooManyTags(tags)
tag_keys = {}
for tag in tags:
# Need to index by the lowercase tag key since the keys are case insensitive, but their case is retained.
ref_key = tag['Key'].lower()
self._check_tag_duplicate(tag_keys, ref_key)
self._validate_tag_key(tag['Key'])
if len(tag['Value']) > 256:
raise TagValueTooBig(tag['Value'])
tag_keys[ref_key] = tag
return tag_keys
def _validate_tag_key(self, tag_key, exception_param='tags.X.member.key'):
"""Validates the tag key.
:param all_tags: Dict to check if there is a duplicate tag.
:param tag_key: The tag key to check against.
:param exception_param: The exception parameter to send over to help format the message. This is to reflect
the difference between the tag and untag APIs.
@ -694,23 +760,9 @@ class IAMBackend(BaseBackend):
return tags, marker
def tag_role(self, role_name, tags):
if len(tags) > 50:
raise TooManyTags(tags)
clean_tags = self._tag_verification(tags)
role = self.get_role(role_name)
tag_keys = {}
for tag in tags:
# Need to index by the lowercase tag key since the keys are case insensitive, but their case is retained.
ref_key = tag['Key'].lower()
self._check_tag_duplicate(tag_keys, ref_key)
self._validate_tag_key(tag['Key'])
if len(tag['Value']) > 256:
raise TagValueTooBig(tag['Value'])
tag_keys[ref_key] = tag
role.tags.update(tag_keys)
role.tags.update(clean_tags)
def untag_role(self, role_name, tag_keys):
if len(tag_keys) > 50:
@ -725,15 +777,21 @@ class IAMBackend(BaseBackend):
role.tags.pop(ref_key, None)
def create_policy_version(self, policy_arn, policy_document, set_as_default):
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_document)
iam_policy_document_validator.validate()
policy = self.get_policy(policy_arn)
if not policy:
raise IAMNotFoundException("Policy not found")
if len(policy.versions) >= 5:
raise IAMLimitExceededException("A managed policy can have up to 5 versions. Before you create a new version, you must delete an existing version.")
set_as_default = (set_as_default == "true") # convert it to python bool
version = PolicyVersion(policy_arn, policy_document, set_as_default)
policy.versions.append(version)
version.version_id = 'v{0}'.format(policy.next_version_num)
policy.next_version_num += 1
if set_as_default:
policy.default_version_id = version.version_id
policy.update_default_version(version.version_id)
return version
def get_policy_version(self, policy_arn, version_id):
@ -756,8 +814,8 @@ class IAMBackend(BaseBackend):
if not policy:
raise IAMNotFoundException("Policy not found")
if version_id == policy.default_version_id:
raise IAMConflictException(
"Cannot delete the default version of a policy")
raise IAMConflictException(code="DeleteConflict",
message="Cannot delete the default version of a policy.")
for i, v in enumerate(policy.versions):
if v.version_id == version_id:
del policy.versions[i]
@ -869,6 +927,9 @@ class IAMBackend(BaseBackend):
def put_group_policy(self, group_name, policy_name, policy_json):
group = self.get_group(group_name)
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_json)
iam_policy_document_validator.validate()
group.put_policy(policy_name, policy_json)
def list_group_policies(self, group_name, marker=None, max_items=None):
@ -1029,6 +1090,9 @@ class IAMBackend(BaseBackend):
def put_user_policy(self, user_name, policy_name, policy_json):
user = self.get_user(user_name)
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_json)
iam_policy_document_validator.validate()
user.put_policy(policy_name, policy_json)
def delete_user_policy(self, user_name, policy_name):
@ -1050,7 +1114,7 @@ class IAMBackend(BaseBackend):
if key.access_key_id == access_key_id:
return {
'user_name': key.user_name,
'last_used': key.last_used
'last_used': key.last_used_iso_8601,
}
else:
raise IAMNotFoundException(
@ -1189,5 +1253,13 @@ class IAMBackend(BaseBackend):
return saml_provider
raise IAMNotFoundException("SamlProvider {0} not found".format(saml_provider_arn))
def get_user_from_access_key_id(self, access_key_id):
for user_name, user in self.users.items():
access_keys = self.get_all_access_keys(user_name)
for access_key in access_keys:
if access_key.access_key_id == access_key_id:
return user
return None
iam_backend = IAMBackend()

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from .models import iam_backend, User
@ -177,9 +178,11 @@ class IamResponse(BaseResponse):
'AssumeRolePolicyDocument')
permissions_boundary = self._get_param(
'PermissionsBoundary')
description = self._get_param('Description')
tags = self._get_multi_param('Tags.member')
role = iam_backend.create_role(
role_name, assume_role_policy_document, path, permissions_boundary)
role_name, assume_role_policy_document, path, permissions_boundary, description, tags)
template = self.response_template(CREATE_ROLE_TEMPLATE)
return template.render(role=role)
@ -425,11 +428,13 @@ class IamResponse(BaseResponse):
def get_user(self):
user_name = self._get_param('UserName')
if user_name:
user = iam_backend.get_user(user_name)
if not user_name:
access_key_id = self.get_current_user()
user = iam_backend.get_user_from_access_key_id(access_key_id)
if user is None:
user = User("default_user")
else:
user = User(name='default_user')
# If no user is specific, IAM returns the current user
user = iam_backend.get_user(user_name)
template = self.response_template(USER_TEMPLATE)
return template.render(action='Get', user=user)
@ -457,7 +462,6 @@ class IamResponse(BaseResponse):
def create_login_profile(self):
user_name = self._get_param('UserName')
password = self._get_param('Password')
password = self._get_param('Password')
user = iam_backend.create_login_profile(user_name, password)
template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE)
@ -818,12 +822,12 @@ CREATE_POLICY_TEMPLATE = """<CreatePolicyResponse>
<Policy>
<Arn>{{ policy.arn }}</Arn>
<AttachmentCount>{{ policy.attachment_count }}</AttachmentCount>
<CreateDate>{{ policy.create_datetime.isoformat() }}</CreateDate>
<CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<DefaultVersionId>{{ policy.default_version_id }}</DefaultVersionId>
<Path>{{ policy.path }}</Path>
<PolicyId>{{ policy.id }}</PolicyId>
<PolicyName>{{ policy.name }}</PolicyName>
<UpdateDate>{{ policy.update_datetime.isoformat() }}</UpdateDate>
<UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
</Policy>
</CreatePolicyResult>
<ResponseMetadata>
@ -841,8 +845,8 @@ GET_POLICY_TEMPLATE = """<GetPolicyResponse>
<Path>{{ policy.path }}</Path>
<Arn>{{ policy.arn }}</Arn>
<AttachmentCount>{{ policy.attachment_count }}</AttachmentCount>
<CreateDate>{{ policy.create_datetime.isoformat() }}</CreateDate>
<UpdateDate>{{ policy.update_datetime.isoformat() }}</UpdateDate>
<CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
</Policy>
</GetPolicyResult>
<ResponseMetadata>
@ -929,12 +933,12 @@ LIST_POLICIES_TEMPLATE = """<ListPoliciesResponse>
<member>
<Arn>{{ policy.arn }}</Arn>
<AttachmentCount>{{ policy.attachment_count }}</AttachmentCount>
<CreateDate>{{ policy.create_datetime.isoformat() }}</CreateDate>
<CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<DefaultVersionId>{{ policy.default_version_id }}</DefaultVersionId>
<Path>{{ policy.path }}</Path>
<PolicyId>{{ policy.id }}</PolicyId>
<PolicyName>{{ policy.name }}</PolicyName>
<UpdateDate>{{ policy.update_datetime.isoformat() }}</UpdateDate>
<UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
</member>
{% endfor %}
</Policies>
@ -958,7 +962,7 @@ CREATE_INSTANCE_PROFILE_TEMPLATE = """<CreateInstanceProfileResponse xmlns="http
<InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.create_date }}</CreateDate>
<CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
</InstanceProfile>
</CreateInstanceProfileResult>
<ResponseMetadata>
@ -977,7 +981,7 @@ GET_INSTANCE_PROFILE_TEMPLATE = """<GetInstanceProfileResponse xmlns="https://ia
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
</member>
{% endfor %}
@ -985,7 +989,7 @@ GET_INSTANCE_PROFILE_TEMPLATE = """<GetInstanceProfileResponse xmlns="https://ia
<InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.create_date }}</CreateDate>
<CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
</InstanceProfile>
</GetInstanceProfileResult>
<ResponseMetadata>
@ -1000,7 +1004,8 @@ CREATE_ROLE_TEMPLATE = """<CreateRoleResponse xmlns="https://iam.amazonaws.com/d
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate>
<Description>{{role.description}}</Description>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
{% if role.permissions_boundary %}
<PermissionsBoundary>
@ -1008,6 +1013,16 @@ CREATE_ROLE_TEMPLATE = """<CreateRoleResponse xmlns="https://iam.amazonaws.com/d
<PermissionsBoundaryArn>{{ role.permissions_boundary }}</PermissionsBoundaryArn>
</PermissionsBoundary>
{% endif %}
{% if role.tags %}
<Tags>
{% for tag in role.get_tags() %}
<member>
<Key>{{ tag['Key'] }}</Key>
<Value>{{ tag['Value'] }}</Value>
</member>
{% endfor %}
</Tags>
{% endif %}
</Role>
</CreateRoleResult>
<ResponseMetadata>
@ -1041,7 +1056,8 @@ UPDATE_ROLE_DESCRIPTION_TEMPLATE = """<UpdateRoleDescriptionResponse xmlns="http
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date.isoformat() }}</CreateDate>
<Description>{{role.description}}</Description>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
{% if role.tags %}
<Tags>
@ -1067,7 +1083,8 @@ GET_ROLE_TEMPLATE = """<GetRoleResponse xmlns="https://iam.amazonaws.com/doc/201
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate>
<Description>{{role.description}}</Description>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
{% if role.tags %}
<Tags>
@ -1108,7 +1125,7 @@ LIST_ROLES_TEMPLATE = """<ListRolesResponse xmlns="https://iam.amazonaws.com/doc
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
{% if role.permissions_boundary %}
<PermissionsBoundary>
@ -1144,8 +1161,8 @@ CREATE_POLICY_VERSION_TEMPLATE = """<CreatePolicyVersionResponse xmlns="https://
<PolicyVersion>
<Document>{{ policy_version.document }}</Document>
<VersionId>{{ policy_version.version_id }}</VersionId>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion>
<CreateDate>{{ policy_version.create_datetime }}</CreateDate>
<IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</PolicyVersion>
</CreatePolicyVersionResult>
<ResponseMetadata>
@ -1158,8 +1175,8 @@ GET_POLICY_VERSION_TEMPLATE = """<GetPolicyVersionResponse xmlns="https://iam.am
<PolicyVersion>
<Document>{{ policy_version.document }}</Document>
<VersionId>{{ policy_version.version_id }}</VersionId>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion>
<CreateDate>{{ policy_version.create_datetime }}</CreateDate>
<IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</PolicyVersion>
</GetPolicyVersionResult>
<ResponseMetadata>
@ -1175,8 +1192,8 @@ LIST_POLICY_VERSIONS_TEMPLATE = """<ListPolicyVersionsResponse xmlns="https://ia
<member>
<Document>{{ policy_version.document }}</Document>
<VersionId>{{ policy_version.version_id }}</VersionId>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion>
<CreateDate>{{ policy_version.create_datetime }}</CreateDate>
<IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</member>
{% endfor %}
</Versions>
@ -1200,7 +1217,7 @@ LIST_INSTANCE_PROFILES_TEMPLATE = """<ListInstanceProfilesResponse xmlns="https:
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
</member>
{% endfor %}
@ -1208,7 +1225,7 @@ LIST_INSTANCE_PROFILES_TEMPLATE = """<ListInstanceProfilesResponse xmlns="https:
<InstanceProfileName>{{ instance.name }}</InstanceProfileName>
<Path>{{ instance.path }}</Path>
<Arn>{{ instance.arn }}</Arn>
<CreateDate>{{ instance.create_date }}</CreateDate>
<CreateDate>{{ instance.created_iso_8601 }}</CreateDate>
</member>
{% endfor %}
</InstanceProfiles>
@ -1287,7 +1304,7 @@ CREATE_GROUP_TEMPLATE = """<CreateGroupResponse>
<GroupName>{{ group.name }}</GroupName>
<GroupId>{{ group.id }}</GroupId>
<Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.create_date }}</CreateDate>
<CreateDate>{{ group.created_iso_8601 }}</CreateDate>
</Group>
</CreateGroupResult>
<ResponseMetadata>
@ -1302,7 +1319,7 @@ GET_GROUP_TEMPLATE = """<GetGroupResponse>
<GroupName>{{ group.name }}</GroupName>
<GroupId>{{ group.id }}</GroupId>
<Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.create_date }}</CreateDate>
<CreateDate>{{ group.created_iso_8601 }}</CreateDate>
</Group>
<Users>
{% for user in group.users %}
@ -1349,6 +1366,7 @@ LIST_GROUPS_FOR_USER_TEMPLATE = """<ListGroupsForUserResponse>
<GroupName>{{ group.name }}</GroupName>
<GroupId>{{ group.id }}</GroupId>
<Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.created_iso_8601 }}</CreateDate>
</member>
{% endfor %}
</Groups>
@ -1493,6 +1511,7 @@ CREATE_ACCESS_KEY_TEMPLATE = """<CreateAccessKeyResponse>
<AccessKeyId>{{ key.access_key_id }}</AccessKeyId>
<Status>{{ key.status }}</Status>
<SecretAccessKey>{{ key.secret_access_key }}</SecretAccessKey>
<CreateDate>{{ key.created_iso_8601 }}</CreateDate>
</AccessKey>
</CreateAccessKeyResult>
<ResponseMetadata>
@ -1509,7 +1528,7 @@ LIST_ACCESS_KEYS_TEMPLATE = """<ListAccessKeysResponse>
<UserName>{{ user_name }}</UserName>
<AccessKeyId>{{ key.access_key_id }}</AccessKeyId>
<Status>{{ key.status }}</Status>
<CreateDate>{{ key.create_date }}</CreateDate>
<CreateDate>{{ key.created_iso_8601 }}</CreateDate>
</member>
{% endfor %}
</AccessKeyMetadata>
@ -1577,7 +1596,7 @@ LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE = """<ListInstanceProfilesForRoleRespon
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
</member>
{% endfor %}
@ -1585,7 +1604,7 @@ LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE = """<ListInstanceProfilesForRoleRespon
<InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.create_date }}</CreateDate>
<CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
</member>
{% endfor %}
</InstanceProfiles>
@ -1651,6 +1670,7 @@ LIST_GROUPS_FOR_USER_TEMPLATE = """<ListGroupsForUserResponse>
<GroupName>{{ group.name }}</GroupName>
<GroupId>{{ group.id }}</GroupId>
<Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.created_iso_8601 }}</CreateDate>
</member>
{% endfor %}
</Groups>
@ -1704,7 +1724,7 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
<GroupName>{{ group.name }}</GroupName>
<Path>{{ group.path }}</Path>
<Arn>{{ group.arn }}</Arn>
<CreateDate>{{ group.create_date }}</CreateDate>
<CreateDate>{{ group.created_iso_8601 }}</CreateDate>
<GroupPolicyList>
{% for policy in group.policies %}
<member>
@ -1754,15 +1774,22 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate>
<Description>{{role.description}}</Description>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
{% if role.permissions_boundary %}
<PermissionsBoundary>
<PermissionsBoundaryType>PermissionsBoundaryPolicy</PermissionsBoundaryType>
<PermissionsBoundaryArn>{{ role.permissions_boundary }}</PermissionsBoundaryArn>
</PermissionsBoundary>
{% endif %}
</member>
{% endfor %}
</Roles>
<InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.create_date }}</CreateDate>
<CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
</member>
{% endfor %}
</InstanceProfileList>
@ -1770,7 +1797,7 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.create_date }}</CreateDate>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
</member>
{% endfor %}
@ -1786,17 +1813,17 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
{% for policy_version in policy.versions %}
<member>
<Document>{{ policy_version.document }}</Document>
<IsDefaultVersion>{{ policy_version.is_default }}</IsDefaultVersion>
<IsDefaultVersion>{{ policy_version.is_default | lower }}</IsDefaultVersion>
<VersionId>{{ policy_version.version_id }}</VersionId>
<CreateDate>{{ policy_version.create_datetime }}</CreateDate>
<CreateDate>{{ policy_version.created_iso_8601 }}</CreateDate>
</member>
{% endfor %}
</PolicyVersionList>
<Arn>{{ policy.arn }}</Arn>
<AttachmentCount>1</AttachmentCount>
<CreateDate>{{ policy.create_datetime }}</CreateDate>
<CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<IsAttachable>true</IsAttachable>
<UpdateDate>{{ policy.update_datetime }}</UpdateDate>
<UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
</member>
{% endfor %}
</Policies>

View File

@ -7,7 +7,7 @@ import six
def random_alphanumeric(length):
return ''.join(six.text_type(
random.choice(
string.ascii_letters + string.digits
string.ascii_letters + string.digits + "+" + "/"
)) for _ in range(length)
)

View File

@ -123,17 +123,12 @@ class Stream(BaseModel):
self.tags = {}
self.status = "ACTIVE"
if six.PY3:
izip_longest = itertools.zip_longest
else:
izip_longest = itertools.izip_longest
step = 2**128 // shard_count
hash_ranges = itertools.chain(map(lambda i: (i, i * step, (i + 1) * step),
range(shard_count - 1)),
[(shard_count - 1, (shard_count - 1) * step, 2**128)])
for index, start, end in hash_ranges:
for index, start, end in izip_longest(range(shard_count),
range(0, 2**128, 2 **
128 // shard_count),
range(2**128 // shard_count, 2 **
128, 2**128 // shard_count),
fillvalue=2**128):
shard = Shard(index, start, end)
self.shards[shard.shard_id] = shard

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import os
import boto.kms
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds, unix_time
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import generate_key_id
from collections import defaultdict
from datetime import datetime, timedelta
@ -11,7 +11,7 @@ from datetime import datetime, timedelta
class Key(BaseModel):
def __init__(self, policy, key_usage, description, region):
def __init__(self, policy, key_usage, description, tags, region):
self.id = generate_key_id()
self.policy = policy
self.key_usage = key_usage
@ -22,7 +22,7 @@ class Key(BaseModel):
self.account_id = "0123456789012"
self.key_rotation_status = False
self.deletion_date = None
self.tags = {}
self.tags = tags or {}
@property
def physical_resource_id(self):
@ -37,7 +37,7 @@ class Key(BaseModel):
"KeyMetadata": {
"AWSAccountId": self.account_id,
"Arn": self.arn,
"CreationDate": "%d" % unix_time(),
"CreationDate": iso_8601_datetime_without_milliseconds(datetime.now()),
"Description": self.description,
"Enabled": self.enabled,
"KeyId": self.id,
@ -61,6 +61,7 @@ class Key(BaseModel):
policy=properties['KeyPolicy'],
key_usage='ENCRYPT_DECRYPT',
description=properties['Description'],
tags=properties.get('Tags'),
region=region_name,
)
key.key_rotation_status = properties['EnableKeyRotation']
@ -80,8 +81,8 @@ class KmsBackend(BaseBackend):
self.keys = {}
self.key_to_aliases = defaultdict(set)
def create_key(self, policy, key_usage, description, region):
key = Key(policy, key_usage, description, region)
def create_key(self, policy, key_usage, description, tags, region):
key = Key(policy, key_usage, description, tags, region)
self.keys[key.id] = key
return key

View File

@ -31,9 +31,10 @@ class KmsResponse(BaseResponse):
policy = self.parameters.get('Policy')
key_usage = self.parameters.get('KeyUsage')
description = self.parameters.get('Description')
tags = self.parameters.get('Tags')
key = self.kms_backend.create_key(
policy, key_usage, description, self.region)
policy, key_usage, description, tags, self.region)
return json.dumps(key.to_dict())
def update_key_description(self):
@ -237,7 +238,7 @@ class KmsResponse(BaseResponse):
value = self.parameters.get("CiphertextBlob")
try:
return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8")})
return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8"), 'KeyId': 'key_id'})
except UnicodeDecodeError:
# Generate data key will produce random bytes which when decrypted is still returned as base64
return json.dumps({"Plaintext": value})

View File

@ -98,17 +98,29 @@ class LogStream:
return True
def get_paging_token_from_index(index, back=False):
if index is not None:
return "b/{:056d}".format(index) if back else "f/{:056d}".format(index)
return 0
def get_index_from_paging_token(token):
if token is not None:
return int(token[2:])
return 0
events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head)
back_token = next_token
if next_token is None:
next_token = 0
next_index = get_index_from_paging_token(next_token)
back_index = next_index
events_page = [event.to_response_dict() for event in events[next_token: next_token + limit]]
next_token += limit
if next_token >= len(self.events):
next_token = None
events_page = [event.to_response_dict() for event in events[next_index: next_index + limit]]
if next_index + limit < len(self.events):
next_index += limit
return events_page, back_token, next_token
back_index -= limit
if back_index <= 0:
back_index = 0
return events_page, get_paging_token_from_index(back_index, True), get_paging_token_from_index(next_index)
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
def filter_func(event):

View File

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import datetime
import re
import json
from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError
@ -151,7 +152,6 @@ class FakeRoot(FakeOrganizationalUnit):
class FakeServiceControlPolicy(BaseModel):
def __init__(self, organization, **kwargs):
self.type = 'POLICY'
self.content = kwargs.get('Content')
self.description = kwargs.get('Description')
self.name = kwargs.get('Name')
@ -197,7 +197,38 @@ class OrganizationsBackend(BaseBackend):
def create_organization(self, **kwargs):
self.org = FakeOrganization(kwargs['FeatureSet'])
self.ou.append(FakeRoot(self.org))
root_ou = FakeRoot(self.org)
self.ou.append(root_ou)
master_account = FakeAccount(
self.org,
AccountName='master',
Email=self.org.master_account_email,
)
master_account.id = self.org.master_account_id
self.accounts.append(master_account)
default_policy = FakeServiceControlPolicy(
self.org,
Name='FullAWSAccess',
Description='Allows access to every operation',
Type='SERVICE_CONTROL_POLICY',
Content=json.dumps(
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "*",
"Resource": "*"
}
]
}
)
)
default_policy.id = utils.DEFAULT_POLICY_ID
default_policy.aws_managed = True
self.policies.append(default_policy)
self.attach_policy(PolicyId=default_policy.id, TargetId=root_ou.id)
self.attach_policy(PolicyId=default_policy.id, TargetId=master_account.id)
return self.org.describe()
def describe_organization(self):
@ -216,6 +247,7 @@ class OrganizationsBackend(BaseBackend):
def create_organizational_unit(self, **kwargs):
new_ou = FakeOrganizationalUnit(self.org, **kwargs)
self.ou.append(new_ou)
self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_ou.id)
return new_ou.describe()
def get_organizational_unit_by_id(self, ou_id):
@ -258,6 +290,7 @@ class OrganizationsBackend(BaseBackend):
def create_account(self, **kwargs):
new_account = FakeAccount(self.org, **kwargs)
self.accounts.append(new_account)
self.attach_policy(PolicyId=utils.DEFAULT_POLICY_ID, TargetId=new_account.id)
return new_account.create_account_status
def get_account_by_id(self, account_id):
@ -358,8 +391,7 @@ class OrganizationsBackend(BaseBackend):
def attach_policy(self, **kwargs):
policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None)
if (re.compile(utils.ROOT_ID_REGEX).match(kwargs['TargetId']) or
re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId'])):
if (re.compile(utils.ROOT_ID_REGEX).match(kwargs['TargetId']) or re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId'])):
ou = next((ou for ou in self.ou if ou.id == kwargs['TargetId']), None)
if ou is not None:
if ou not in ou.attached_policies:

View File

@ -4,7 +4,8 @@ import random
import string
MASTER_ACCOUNT_ID = '123456789012'
MASTER_ACCOUNT_EMAIL = 'fakeorg@moto-example.com'
MASTER_ACCOUNT_EMAIL = 'master@example.com'
DEFAULT_POLICY_ID = 'p-FullAWSAccess'
ORGANIZATION_ARN_FORMAT = 'arn:aws:organizations::{0}:organization/{1}'
MASTER_ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{0}'
ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{2}'
@ -26,7 +27,7 @@ ROOT_ID_REGEX = r'r-[a-z0-9]{%s}' % ROOT_ID_SIZE
OU_ID_REGEX = r'ou-[a-z0-9]{%s}-[a-z0-9]{%s}' % (ROOT_ID_SIZE, OU_ID_SUFFIX_SIZE)
ACCOUNT_ID_REGEX = r'[0-9]{%s}' % ACCOUNT_ID_SIZE
CREATE_ACCOUNT_STATUS_ID_REGEX = r'car-[a-z0-9]{%s}' % CREATE_ACCOUNT_STATUS_ID_SIZE
SCP_ID_REGEX = r'p-[a-z0-9]{%s}' % SCP_ID_SIZE
SCP_ID_REGEX = r'%s|p-[a-z0-9]{%s}' % (DEFAULT_POLICY_ID, SCP_ID_SIZE)
def make_random_org_id():

View File

@ -268,10 +268,26 @@ class fakesock(object):
_sent_data = []
def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM,
protocol=0):
self.truesock = (old_socket(family, type, protocol)
if httpretty.allow_net_connect
else None)
proto=0, fileno=None, _sock=None):
"""
Matches both the Python 2 API:
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None):
https://github.com/python/cpython/blob/2.7/Lib/socket.py
and the Python 3 API:
def __init__(self, family=-1, type=-1, proto=-1, fileno=None):
https://github.com/python/cpython/blob/3.5/Lib/socket.py
"""
if httpretty.allow_net_connect:
if PY3:
self.truesock = old_socket(family, type, proto, fileno)
else:
# If Python 2, if parameters are passed as arguments, instead of kwargs,
# the 4th argument `_sock` will be interpreted as the `fileno`.
# Check if _sock is none, and if so, pass fileno.
self.truesock = old_socket(family, type, proto, fileno or _sock)
else:
self.truesock = None
self._closed = True
self.fd = FakeSockFile()
self.fd.socket = self

View File

@ -95,7 +95,7 @@ class RDSResponse(BaseResponse):
start = all_ids.index(marker) + 1
else:
start = 0
page_size = self._get_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
page_size = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
instances_resp = all_instances[start:start + page_size]
next_marker = None
if len(all_instances) > start + page_size:

View File

@ -60,6 +60,15 @@ class DBParameterGroupNotFoundError(RDSClientError):
'DB Parameter Group {0} not found.'.format(db_parameter_group_name))
class OptionGroupNotFoundFaultError(RDSClientError):
def __init__(self, option_group_name):
super(OptionGroupNotFoundFaultError, self).__init__(
'OptionGroupNotFoundFault',
'Specified OptionGroupName: {0} not found.'.format(option_group_name)
)
class InvalidDBClusterStateFaultError(RDSClientError):
def __init__(self, database_identifier):

View File

@ -20,6 +20,7 @@ from .exceptions import (RDSClientError,
DBSecurityGroupNotFoundError,
DBSubnetGroupNotFoundError,
DBParameterGroupNotFoundError,
OptionGroupNotFoundFaultError,
InvalidDBClusterStateFaultError,
InvalidDBInstanceStateError,
SnapshotQuotaExceededError,
@ -70,6 +71,7 @@ class Database(BaseModel):
self.port = Database.default_port(self.engine)
self.db_instance_identifier = kwargs.get('db_instance_identifier')
self.db_name = kwargs.get("db_name")
self.instance_create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
self.publicly_accessible = kwargs.get("publicly_accessible")
if self.publicly_accessible is None:
self.publicly_accessible = True
@ -99,6 +101,8 @@ class Database(BaseModel):
'preferred_backup_window', '13:14-13:44')
self.license_model = kwargs.get('license_model', 'general-public-license')
self.option_group_name = kwargs.get('option_group_name', None)
if self.option_group_name and self.option_group_name not in rds2_backends[self.region].option_groups:
raise OptionGroupNotFoundFaultError(self.option_group_name)
self.default_option_groups = {"MySQL": "default.mysql5.6",
"mysql": "default.mysql5.6",
"postgres": "default.postgres9.3"
@ -145,9 +149,17 @@ class Database(BaseModel):
<DBInstanceStatus>{{ database.status }}</DBInstanceStatus>
{% if database.db_name %}<DBName>{{ database.db_name }}</DBName>{% endif %}
<MultiAZ>{{ database.multi_az }}</MultiAZ>
<VpcSecurityGroups/>
<VpcSecurityGroups>
{% for vpc_security_group_id in database.vpc_security_group_ids %}
<VpcSecurityGroupMembership>
<Status>active</Status>
<VpcSecurityGroupId>{{ vpc_security_group_id }}</VpcSecurityGroupId>
</VpcSecurityGroupMembership>
{% endfor %}
</VpcSecurityGroups>
<DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier>
<DbiResourceId>{{ database.dbi_resource_id }}</DbiResourceId>
<InstanceCreateTime>{{ database.instance_create_time }}</InstanceCreateTime>
<PreferredBackupWindow>03:50-04:20</PreferredBackupWindow>
<PreferredMaintenanceWindow>wed:06:38-wed:07:08</PreferredMaintenanceWindow>
<ReadReplicaDBInstanceIdentifiers>
@ -173,6 +185,10 @@ class Database(BaseModel):
<LicenseModel>{{ database.license_model }}</LicenseModel>
<EngineVersion>{{ database.engine_version }}</EngineVersion>
<OptionGroupMemberships>
<OptionGroupMembership>
<OptionGroupName>{{ database.option_group_name }}</OptionGroupName>
<Status>in-sync</Status>
</OptionGroupMembership>
</OptionGroupMemberships>
<DBParameterGroups>
{% for db_parameter_group in database.db_parameter_groups() %}
@ -314,6 +330,7 @@ class Database(BaseModel):
"storage_encrypted": properties.get("StorageEncrypted"),
"storage_type": properties.get("StorageType"),
"tags": properties.get("Tags"),
"vpc_security_group_ids": properties.get('VpcSecurityGroupIds', []),
}
rds2_backend = rds2_backends[region_name]
@ -373,7 +390,7 @@ class Database(BaseModel):
"Address": "{{ database.address }}",
"Port": "{{ database.port }}"
},
"InstanceCreateTime": null,
"InstanceCreateTime": "{{ database.instance_create_time }}",
"Iops": null,
"ReadReplicaDBInstanceIdentifiers": [{%- for replica in database.replicas -%}
{%- if not loop.first -%},{%- endif -%}
@ -388,10 +405,12 @@ class Database(BaseModel):
"SecondaryAvailabilityZone": null,
"StatusInfos": null,
"VpcSecurityGroups": [
{% for vpc_security_group_id in database.vpc_security_group_ids %}
{
"Status": "active",
"VpcSecurityGroupId": "sg-123456"
"VpcSecurityGroupId": "{{ vpc_security_group_id }}"
}
{% endfor %}
],
"DBInstanceArn": "{{ database.db_instance_arn }}"
}""")
@ -873,13 +892,16 @@ class RDS2Backend(BaseBackend):
def create_option_group(self, option_group_kwargs):
option_group_id = option_group_kwargs['name']
valid_option_group_engines = {'mysql': ['5.6'],
'oracle-se1': ['11.2'],
'oracle-se': ['11.2'],
'oracle-ee': ['11.2'],
valid_option_group_engines = {'mariadb': ['10.0', '10.1', '10.2', '10.3'],
'mysql': ['5.5', '5.6', '5.7', '8.0'],
'oracle-se2': ['11.2', '12.1', '12.2'],
'oracle-se1': ['11.2', '12.1', '12.2'],
'oracle-se': ['11.2', '12.1', '12.2'],
'oracle-ee': ['11.2', '12.1', '12.2'],
'sqlserver-se': ['10.50', '11.00'],
'sqlserver-ee': ['10.50', '11.00']
}
'sqlserver-ee': ['10.50', '11.00'],
'sqlserver-ex': ['10.50', '11.00'],
'sqlserver-web': ['10.50', '11.00']}
if option_group_kwargs['name'] in self.option_groups:
raise RDSClientError('OptionGroupAlreadyExistsFault',
'An option group named {0} already exists.'.format(option_group_kwargs['name']))
@ -905,8 +927,7 @@ class RDS2Backend(BaseBackend):
if option_group_name in self.option_groups:
return self.option_groups.pop(option_group_name)
else:
raise RDSClientError(
'OptionGroupNotFoundFault', 'Specified OptionGroupName: {0} not found.'.format(option_group_name))
raise OptionGroupNotFoundFaultError(option_group_name)
def describe_option_groups(self, option_group_kwargs):
option_group_list = []
@ -935,8 +956,7 @@ class RDS2Backend(BaseBackend):
else:
option_group_list.append(option_group)
if not len(option_group_list):
raise RDSClientError('OptionGroupNotFoundFault',
'Specified OptionGroupName: {0} not found.'.format(option_group_kwargs['name']))
raise OptionGroupNotFoundFaultError(option_group_kwargs['name'])
return option_group_list[marker:max_records + marker]
@staticmethod
@ -965,8 +985,7 @@ class RDS2Backend(BaseBackend):
def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None):
if option_group_name not in self.option_groups:
raise RDSClientError('OptionGroupNotFoundFault',
'Specified OptionGroupName: {0} not found.'.format(option_group_name))
raise OptionGroupNotFoundFaultError(option_group_name)
if not options_to_include and not options_to_remove:
raise RDSClientError('InvalidParameterValue',
'At least one option must be added, modified, or removed.')

View File

@ -34,7 +34,7 @@ class RDS2Response(BaseResponse):
"master_user_password": self._get_param('MasterUserPassword'),
"master_username": self._get_param('MasterUsername'),
"multi_az": self._get_bool_param("MultiAZ"),
# OptionGroupName
"option_group_name": self._get_param("OptionGroupName"),
"port": self._get_param('Port'),
# PreferredBackupWindow
# PreferredMaintenanceWindow
@ -43,7 +43,7 @@ class RDS2Response(BaseResponse):
"security_groups": self._get_multi_param('DBSecurityGroups.DBSecurityGroupName'),
"storage_encrypted": self._get_param("StorageEncrypted"),
"storage_type": self._get_param("StorageType", 'standard'),
# VpcSecurityGroupIds.member.N
"vpc_security_group_ids": self._get_multi_param("VpcSecurityGroupIds.VpcSecurityGroupId"),
"tags": list(),
}
args['tags'] = self.unpack_complex_list_params(
@ -280,7 +280,7 @@ class RDS2Response(BaseResponse):
def describe_option_groups(self):
kwargs = self._get_option_group_kwargs()
kwargs['max_records'] = self._get_param('MaxRecords')
kwargs['max_records'] = self._get_int_param('MaxRecords')
kwargs['marker'] = self._get_param('Marker')
option_groups = self.backend.describe_option_groups(kwargs)
template = self.response_template(DESCRIBE_OPTION_GROUP_TEMPLATE)
@ -329,7 +329,7 @@ class RDS2Response(BaseResponse):
def describe_db_parameter_groups(self):
kwargs = self._get_db_parameter_group_kwargs()
kwargs['max_records'] = self._get_param('MaxRecords')
kwargs['max_records'] = self._get_int_param('MaxRecords')
kwargs['marker'] = self._get_param('Marker')
db_parameter_groups = self.backend.describe_db_parameter_groups(kwargs)
template = self.response_template(

View File

@ -78,7 +78,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
super(Cluster, self).__init__(region_name, tags)
self.redshift_backend = redshift_backend
self.cluster_identifier = cluster_identifier
self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow())
self.status = 'available'
self.node_type = node_type
self.master_username = master_username

View File

@ -10,6 +10,7 @@ from moto.ec2 import ec2_backends
from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends
from moto.kinesis import kinesis_backends
from moto.kms import kms_backends
from moto.rds2 import rds2_backends
from moto.glacier import glacier_backends
from moto.redshift import redshift_backends
@ -71,6 +72,13 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
"""
return kinesis_backends[self.region_name]
@property
def kms_backend(self):
"""
:rtype: moto.kms.models.KmsBackend
"""
return kms_backends[self.region_name]
@property
def rds_backend(self):
"""
@ -221,9 +229,6 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
if not resource_type_filters or 'elasticloadbalancer' in resource_type_filters or 'elasticloadbalancer:loadbalancer' in resource_type_filters:
for elb in self.elbv2_backend.load_balancers.values():
tags = get_elbv2_tags(elb.arn)
# if 'elasticloadbalancer:loadbalancer' in resource_type_filters:
# from IPython import embed
# embed()
if not tag_filter(tags): # Skip if no tags, or invalid filter
continue
@ -235,6 +240,21 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
# Kinesis
# KMS
def get_kms_tags(kms_key_id):
result = []
for tag in self.kms_backend.list_resource_tags(kms_key_id):
result.append({'Key': tag['TagKey'], 'Value': tag['TagValue']})
return result
if not resource_type_filters or 'kms' in resource_type_filters:
for kms_key in self.kms_backend.list_keys():
tags = get_kms_tags(kms_key.id)
if not tag_filter(tags): # Skip if no tags, or invalid filter
continue
yield {'ResourceARN': '{0}'.format(kms_key.arn), 'Tags': tags}
# RDS Instance
# RDS Reserved Database Instance
# RDS Option Group
@ -370,7 +390,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
def get_resources(self, pagination_token=None,
resources_per_page=50, tags_per_page=100,
tag_filters=None, resource_type_filters=None):
# Simple range checning
# Simple range checking
if 100 >= tags_per_page >= 500:
raise RESTError('InvalidParameterException', 'TagsPerPage must be between 100 and 500')
if 1 >= resources_per_page >= 50:

View File

@ -85,6 +85,7 @@ class RecordSet(BaseModel):
self.health_check = kwargs.get('HealthCheckId')
self.hosted_zone_name = kwargs.get('HostedZoneName')
self.hosted_zone_id = kwargs.get('HostedZoneId')
self.alias_target = kwargs.get('AliasTarget')
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
@ -119,7 +120,7 @@ class RecordSet(BaseModel):
properties["HostedZoneId"])
try:
hosted_zone.delete_rrset_by_name(resource_name)
hosted_zone.delete_rrset({'Name': resource_name})
except KeyError:
pass
@ -143,6 +144,13 @@ class RecordSet(BaseModel):
{% if record_set.ttl %}
<TTL>{{ record_set.ttl }}</TTL>
{% endif %}
{% if record_set.alias_target %}
<AliasTarget>
<HostedZoneId>{{ record_set.alias_target['HostedZoneId'] }}</HostedZoneId>
<DNSName>{{ record_set.alias_target['DNSName'] }}</DNSName>
<EvaluateTargetHealth>{{ record_set.alias_target['EvaluateTargetHealth'] }}</EvaluateTargetHealth>
</AliasTarget>
{% else %}
<ResourceRecords>
{% for record in record_set.records %}
<ResourceRecord>
@ -150,6 +158,7 @@ class RecordSet(BaseModel):
</ResourceRecord>
{% endfor %}
</ResourceRecords>
{% endif %}
{% if record_set.health_check %}
<HealthCheckId>{{ record_set.health_check }}</HealthCheckId>
{% endif %}
@ -162,7 +171,13 @@ class RecordSet(BaseModel):
self.hosted_zone_name)
if not hosted_zone:
hosted_zone = route53_backend.get_hosted_zone(self.hosted_zone_id)
hosted_zone.delete_rrset_by_name(self.name)
hosted_zone.delete_rrset({'Name': self.name, 'Type': self.type_})
def reverse_domain_name(domain_name):
if domain_name.endswith('.'): # normalize without trailing dot
domain_name = domain_name[:-1]
return '.'.join(reversed(domain_name.split('.')))
class FakeZone(BaseModel):
@ -183,16 +198,20 @@ class FakeZone(BaseModel):
def upsert_rrset(self, record_set):
new_rrset = RecordSet(record_set)
for i, rrset in enumerate(self.rrsets):
if rrset.name == new_rrset.name and rrset.type_ == new_rrset.type_:
if rrset.name == new_rrset.name and rrset.type_ == new_rrset.type_ and rrset.set_identifier == new_rrset.set_identifier:
self.rrsets[i] = new_rrset
break
else:
self.rrsets.append(new_rrset)
return new_rrset
def delete_rrset_by_name(self, name):
def delete_rrset(self, rrset):
self.rrsets = [
record_set for record_set in self.rrsets if record_set.name != name]
record_set
for record_set in self.rrsets
if record_set.name != rrset['Name'] or
(rrset.get('Type') is not None and record_set.type_ != rrset['Type'])
]
def delete_rrset_by_id(self, set_identifier):
self.rrsets = [
@ -200,12 +219,15 @@ class FakeZone(BaseModel):
def get_record_sets(self, start_type, start_name):
record_sets = list(self.rrsets) # Copy the list
if start_name:
record_sets = [
record_set
for record_set in record_sets
if reverse_domain_name(record_set.name) >= reverse_domain_name(start_name)
]
if start_type:
record_sets = [
record_set for record_set in record_sets if record_set.type_ >= start_type]
if start_name:
record_sets = [
record_set for record_set in record_sets if record_set.name >= start_name]
return record_sets

View File

@ -134,10 +134,7 @@ class Route53(BaseResponse):
# Depending on how many records there are, this may
# or may not be a list
resource_records = [resource_records]
record_values = [x['Value'] for x in resource_records]
elif 'AliasTarget' in record_set:
record_values = [record_set['AliasTarget']['DNSName']]
record_set['ResourceRecords'] = record_values
record_set['ResourceRecords'] = [x['Value'] for x in resource_records]
if action == 'CREATE':
the_zone.add_rrset(record_set)
else:
@ -147,7 +144,7 @@ class Route53(BaseResponse):
the_zone.delete_rrset_by_id(
record_set["SetIdentifier"])
else:
the_zone.delete_rrset_by_name(record_set["Name"])
the_zone.delete_rrset(record_set)
return 200, headers, CHANGE_RRSET_RESPONSE

View File

@ -60,6 +60,17 @@ class MissingKey(S3ClientError):
)
class ObjectNotInActiveTierError(S3ClientError):
code = 403
def __init__(self, key_name):
super(ObjectNotInActiveTierError, self).__init__(
"ObjectNotInActiveTierError",
"The source object of the COPY operation is not in the active tier and is only stored in Amazon Glacier.",
Key=key_name,
)
class InvalidPartOrder(S3ClientError):
code = 400
@ -199,3 +210,67 @@ class DuplicateTagKeys(S3ClientError):
"InvalidTag",
"Cannot provide multiple Tags with the same key",
*args, **kwargs)
class S3AccessDeniedError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(S3AccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs)
class BucketAccessDeniedError(BucketError):
code = 403
def __init__(self, *args, **kwargs):
super(BucketAccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs)
class S3InvalidTokenError(S3ClientError):
code = 400
def __init__(self, *args, **kwargs):
super(S3InvalidTokenError, self).__init__('InvalidToken', 'The provided token is malformed or otherwise invalid.', *args, **kwargs)
class BucketInvalidTokenError(BucketError):
code = 400
def __init__(self, *args, **kwargs):
super(BucketInvalidTokenError, self).__init__('InvalidToken', 'The provided token is malformed or otherwise invalid.', *args, **kwargs)
class S3InvalidAccessKeyIdError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(S3InvalidAccessKeyIdError, self).__init__(
'InvalidAccessKeyId',
"The AWS Access Key Id you provided does not exist in our records.", *args, **kwargs)
class BucketInvalidAccessKeyIdError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(BucketInvalidAccessKeyIdError, self).__init__(
'InvalidAccessKeyId',
"The AWS Access Key Id you provided does not exist in our records.", *args, **kwargs)
class S3SignatureDoesNotMatchError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(S3SignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch',
"The request signature we calculated does not match the signature you provided. Check your key and signing method.", *args, **kwargs)
class BucketSignatureDoesNotMatchError(S3ClientError):
code = 403
def __init__(self, *args, **kwargs):
super(BucketSignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch',
"The request signature we calculated does not match the signature you provided. Check your key and signing method.", *args, **kwargs)

View File

@ -28,7 +28,8 @@ MAX_BUCKET_NAME_LENGTH = 63
MIN_BUCKET_NAME_LENGTH = 3
UPLOAD_ID_BYTES = 43
UPLOAD_PART_MIN_SIZE = 5242880
STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA"]
STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA",
"INTELLIGENT_TIERING", "GLACIER", "DEEP_ARCHIVE"]
DEFAULT_KEY_BUFFER_SIZE = 16 * 1024 * 1024
DEFAULT_TEXT_ENCODING = sys.getdefaultencoding()
@ -52,8 +53,17 @@ class FakeDeleteMarker(BaseModel):
class FakeKey(BaseModel):
def __init__(self, name, value, storage="STANDARD", etag=None, is_versioned=False, version_id=0,
max_buffer_size=DEFAULT_KEY_BUFFER_SIZE):
def __init__(
self,
name,
value,
storage="STANDARD",
etag=None,
is_versioned=False,
version_id=0,
max_buffer_size=DEFAULT_KEY_BUFFER_SIZE,
multipart=None
):
self.name = name
self.last_modified = datetime.datetime.utcnow()
self.acl = get_canned_acl('private')
@ -65,6 +75,7 @@ class FakeKey(BaseModel):
self._version_id = version_id
self._is_versioned = is_versioned
self._tagging = FakeTagging()
self.multipart = multipart
self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size)
self._max_buffer_size = max_buffer_size
@ -754,7 +765,7 @@ class S3Backend(BaseBackend):
prefix=''):
bucket = self.get_bucket(bucket_name)
if any((delimiter, encoding_type, key_marker, version_id_marker)):
if any((delimiter, key_marker, version_id_marker)):
raise NotImplementedError(
"Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker")
@ -782,7 +793,15 @@ class S3Backend(BaseBackend):
bucket = self.get_bucket(bucket_name)
return bucket.website_configuration
def set_key(self, bucket_name, key_name, value, storage=None, etag=None):
def set_key(
self,
bucket_name,
key_name,
value,
storage=None,
etag=None,
multipart=None,
):
key_name = clean_key_name(key_name)
if storage is not None and storage not in STORAGE_CLASS:
raise InvalidStorageClass(storage=storage)
@ -795,7 +814,9 @@ class S3Backend(BaseBackend):
storage=storage,
etag=etag,
is_versioned=bucket.is_versioned,
version_id=str(uuid.uuid4()) if bucket.is_versioned else None)
version_id=str(uuid.uuid4()) if bucket.is_versioned else None,
multipart=multipart,
)
keys = [
key for key in bucket.keys.getlist(key_name, [])
@ -812,7 +833,7 @@ class S3Backend(BaseBackend):
key.append_to_value(value)
return key
def get_key(self, bucket_name, key_name, version_id=None):
def get_key(self, bucket_name, key_name, version_id=None, part_number=None):
key_name = clean_key_name(key_name)
bucket = self.get_bucket(bucket_name)
key = None
@ -827,6 +848,9 @@ class S3Backend(BaseBackend):
key = key_version
break
if part_number and key.multipart:
key = key.multipart.parts[part_number]
if isinstance(key, FakeKey):
return key
else:
@ -890,7 +914,12 @@ class S3Backend(BaseBackend):
return
del bucket.multiparts[multipart_id]
key = self.set_key(bucket_name, multipart.key_name, value, etag=etag)
key = self.set_key(
bucket_name,
multipart.key_name,
value, etag=etag,
multipart=multipart
)
key.set_metadata(multipart.metadata)
return key

View File

@ -3,20 +3,21 @@ from __future__ import unicode_literals
import re
import six
from moto.core.utils import str_to_rfc_1123_datetime
from six.moves.urllib.parse import parse_qs, urlparse, unquote
import xmltodict
from moto.packages.httpretty.core import HTTPrettyRequest
from moto.core.responses import _TemplateEnvironmentMixin
from moto.core.responses import _TemplateEnvironmentMixin, ActionAuthenticatorMixin
from moto.core.utils import path_url
from moto.s3bucket_path.utils import bucket_name_from_url as bucketpath_bucket_name_from_url, \
parse_key_name as bucketpath_parse_key_name, is_delete_keys as bucketpath_is_delete_keys
from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, MissingKey, InvalidPartOrder, MalformedXML, \
MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent
MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent, ObjectNotInActiveTierError
from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \
FakeTag
from .utils import bucket_name_from_url, clean_key_name, metadata_from_headers, parse_region_from_url
@ -25,6 +26,72 @@ from xml.dom import minidom
DEFAULT_REGION_NAME = 'us-east-1'
ACTION_MAP = {
"BUCKET": {
"GET": {
"uploads": "ListBucketMultipartUploads",
"location": "GetBucketLocation",
"lifecycle": "GetLifecycleConfiguration",
"versioning": "GetBucketVersioning",
"policy": "GetBucketPolicy",
"website": "GetBucketWebsite",
"acl": "GetBucketAcl",
"tagging": "GetBucketTagging",
"logging": "GetBucketLogging",
"cors": "GetBucketCORS",
"notification": "GetBucketNotification",
"accelerate": "GetAccelerateConfiguration",
"versions": "ListBucketVersions",
"DEFAULT": "ListBucket"
},
"PUT": {
"lifecycle": "PutLifecycleConfiguration",
"versioning": "PutBucketVersioning",
"policy": "PutBucketPolicy",
"website": "PutBucketWebsite",
"acl": "PutBucketAcl",
"tagging": "PutBucketTagging",
"logging": "PutBucketLogging",
"cors": "PutBucketCORS",
"notification": "PutBucketNotification",
"accelerate": "PutAccelerateConfiguration",
"DEFAULT": "CreateBucket"
},
"DELETE": {
"lifecycle": "PutLifecycleConfiguration",
"policy": "DeleteBucketPolicy",
"tagging": "PutBucketTagging",
"cors": "PutBucketCORS",
"DEFAULT": "DeleteBucket"
}
},
"KEY": {
"GET": {
"uploadId": "ListMultipartUploadParts",
"acl": "GetObjectAcl",
"tagging": "GetObjectTagging",
"versionId": "GetObjectVersion",
"DEFAULT": "GetObject"
},
"PUT": {
"acl": "PutObjectAcl",
"tagging": "PutObjectTagging",
"DEFAULT": "PutObject"
},
"DELETE": {
"uploadId": "AbortMultipartUpload",
"versionId": "DeleteObjectVersion",
"DEFAULT": " DeleteObject"
},
"POST": {
"uploads": "PutObject",
"restore": "RestoreObject",
"uploadId": "PutObject"
}
}
}
def parse_key_name(pth):
return pth.lstrip("/")
@ -37,17 +104,24 @@ def is_delete_keys(request, path, bucket_name):
)
class ResponseObject(_TemplateEnvironmentMixin):
class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def __init__(self, backend):
super(ResponseObject, self).__init__()
self.backend = backend
self.method = ""
self.path = ""
self.data = {}
self.headers = {}
@property
def should_autoescape(self):
return True
def all_buckets(self):
self.data["Action"] = "ListAllMyBuckets"
self._authenticate_and_authorize_s3_action()
# No bucket specified. Listing all buckets
all_buckets = self.backend.get_all_buckets()
template = self.response_template(S3_ALL_BUCKETS)
@ -112,11 +186,20 @@ class ResponseObject(_TemplateEnvironmentMixin):
return self.bucket_response(request, full_url, headers)
def bucket_response(self, request, full_url, headers):
self.method = request.method
self.path = self._get_path(request)
self.headers = request.headers
if 'host' not in self.headers:
self.headers['host'] = urlparse(full_url).netloc
try:
response = self._bucket_response(request, full_url, headers)
except S3ClientError as s3error:
response = s3error.code, {}, s3error.description
return self._send_response(response)
@staticmethod
def _send_response(response):
if isinstance(response, six.string_types):
return 200, {}, response.encode("utf-8")
else:
@ -127,8 +210,7 @@ class ResponseObject(_TemplateEnvironmentMixin):
return status_code, headers, response_content
def _bucket_response(self, request, full_url, headers):
parsed_url = urlparse(full_url)
querystring = parse_qs(parsed_url.query, keep_blank_values=True)
querystring = self._get_querystring(full_url)
method = request.method
region_name = parse_region_from_url(full_url)
@ -137,6 +219,8 @@ class ResponseObject(_TemplateEnvironmentMixin):
# If no bucket specified, list all buckets
return self.all_buckets()
self.data["BucketName"] = bucket_name
if hasattr(request, 'body'):
# Boto
body = request.body
@ -150,20 +234,26 @@ class ResponseObject(_TemplateEnvironmentMixin):
body = u'{0}'.format(body).encode('utf-8')
if method == 'HEAD':
return self._bucket_response_head(bucket_name, headers)
return self._bucket_response_head(bucket_name)
elif method == 'GET':
return self._bucket_response_get(bucket_name, querystring, headers)
return self._bucket_response_get(bucket_name, querystring)
elif method == 'PUT':
return self._bucket_response_put(request, body, region_name, bucket_name, querystring, headers)
return self._bucket_response_put(request, body, region_name, bucket_name, querystring)
elif method == 'DELETE':
return self._bucket_response_delete(body, bucket_name, querystring, headers)
return self._bucket_response_delete(body, bucket_name, querystring)
elif method == 'POST':
return self._bucket_response_post(request, body, bucket_name, headers)
return self._bucket_response_post(request, body, bucket_name)
else:
raise NotImplementedError(
"Method {0} has not been impelemented in the S3 backend yet".format(method))
def _bucket_response_head(self, bucket_name, headers):
@staticmethod
def _get_querystring(full_url):
parsed_url = urlparse(full_url)
querystring = parse_qs(parsed_url.query, keep_blank_values=True)
return querystring
def _bucket_response_head(self, bucket_name):
try:
self.backend.get_bucket(bucket_name)
except MissingBucket:
@ -174,7 +264,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 404, {}, ""
return 200, {}, ""
def _bucket_response_get(self, bucket_name, querystring, headers):
def _bucket_response_get(self, bucket_name, querystring):
self._set_action("BUCKET", "GET", querystring)
self._authenticate_and_authorize_s3_action()
if 'uploads' in querystring:
for unsup in ('delimiter', 'max-uploads'):
if unsup in querystring:
@ -333,6 +426,15 @@ class ResponseObject(_TemplateEnvironmentMixin):
max_keys=max_keys
)
def _set_action(self, action_resource_type, method, querystring):
action_set = False
for action_in_querystring, action in ACTION_MAP[action_resource_type][method].items():
if action_in_querystring in querystring:
self.data["Action"] = action
action_set = True
if not action_set:
self.data["Action"] = ACTION_MAP[action_resource_type][method]["DEFAULT"]
def _handle_list_objects_v2(self, bucket_name, querystring):
template = self.response_template(S3_BUCKET_GET_RESPONSE_V2)
bucket = self.backend.get_bucket(bucket_name)
@ -361,10 +463,13 @@ class ResponseObject(_TemplateEnvironmentMixin):
else:
result_folders, is_truncated, next_continuation_token = self._truncate_result(result_folders, max_keys)
key_count = len(result_keys) + len(result_folders)
return template.render(
bucket=bucket,
prefix=prefix or '',
delimiter=delimiter,
key_count=key_count,
result_keys=result_keys,
result_folders=result_folders,
fetch_owner=fetch_owner,
@ -393,9 +498,13 @@ class ResponseObject(_TemplateEnvironmentMixin):
next_continuation_token = None
return result_keys, is_truncated, next_continuation_token
def _bucket_response_put(self, request, body, region_name, bucket_name, querystring, headers):
def _bucket_response_put(self, request, body, region_name, bucket_name, querystring):
if not request.headers.get('Content-Length'):
return 411, {}, "Content-Length required"
self._set_action("BUCKET", "PUT", querystring)
self._authenticate_and_authorize_s3_action()
if 'versioning' in querystring:
ver = re.search('<Status>([A-Za-z]+)</Status>', body.decode())
if ver:
@ -494,7 +603,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
template = self.response_template(S3_BUCKET_CREATE_RESPONSE)
return 200, {}, template.render(bucket=new_bucket)
def _bucket_response_delete(self, body, bucket_name, querystring, headers):
def _bucket_response_delete(self, body, bucket_name, querystring):
self._set_action("BUCKET", "DELETE", querystring)
self._authenticate_and_authorize_s3_action()
if 'policy' in querystring:
self.backend.delete_bucket_policy(bucket_name, body)
return 204, {}, ""
@ -521,17 +633,20 @@ class ResponseObject(_TemplateEnvironmentMixin):
S3_DELETE_BUCKET_WITH_ITEMS_ERROR)
return 409, {}, template.render(bucket=removed_bucket)
def _bucket_response_post(self, request, body, bucket_name, headers):
def _bucket_response_post(self, request, body, bucket_name):
if not request.headers.get('Content-Length'):
return 411, {}, "Content-Length required"
if isinstance(request, HTTPrettyRequest):
path = request.path
else:
path = request.full_path if hasattr(request, 'full_path') else path_url(request.url)
path = self._get_path(request)
if self.is_delete_keys(request, path, bucket_name):
return self._bucket_response_delete_keys(request, body, bucket_name, headers)
self.data["Action"] = "DeleteObject"
self._authenticate_and_authorize_s3_action()
return self._bucket_response_delete_keys(request, body, bucket_name)
self.data["Action"] = "PutObject"
self._authenticate_and_authorize_s3_action()
# POST to bucket-url should create file from form
if hasattr(request, 'form'):
@ -560,12 +675,22 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 200, {}, ""
def _bucket_response_delete_keys(self, request, body, bucket_name, headers):
@staticmethod
def _get_path(request):
if isinstance(request, HTTPrettyRequest):
path = request.path
else:
path = request.full_path if hasattr(request, 'full_path') else path_url(request.url)
return path
def _bucket_response_delete_keys(self, request, body, bucket_name):
template = self.response_template(S3_DELETE_KEYS_RESPONSE)
keys = minidom.parseString(body).getElementsByTagName('Key')
deleted_names = []
error_names = []
if len(keys) == 0:
raise MalformedXML()
for k in keys:
key_name = k.firstChild.nodeValue
@ -604,6 +729,11 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 206, response_headers, response_content[begin:end + 1]
def key_response(self, request, full_url, headers):
self.method = request.method
self.path = self._get_path(request)
self.headers = request.headers
if 'host' not in self.headers:
self.headers['host'] = urlparse(full_url).netloc
response_headers = {}
try:
response = self._key_response(request, full_url, headers)
@ -657,20 +787,23 @@ class ResponseObject(_TemplateEnvironmentMixin):
body = b''
if method == 'GET':
return self._key_response_get(bucket_name, query, key_name, headers)
return self._key_response_get(bucket_name, query, key_name, headers=request.headers)
elif method == 'PUT':
return self._key_response_put(request, body, bucket_name, query, key_name, headers)
elif method == 'HEAD':
return self._key_response_head(bucket_name, query, key_name, headers=request.headers)
elif method == 'DELETE':
return self._key_response_delete(bucket_name, query, key_name, headers)
return self._key_response_delete(bucket_name, query, key_name)
elif method == 'POST':
return self._key_response_post(request, body, bucket_name, query, key_name, headers)
return self._key_response_post(request, body, bucket_name, query, key_name)
else:
raise NotImplementedError(
"Method {0} has not been implemented in the S3 backend yet".format(method))
def _key_response_get(self, bucket_name, query, key_name, headers):
self._set_action("KEY", "GET", query)
self._authenticate_and_authorize_s3_action()
response_headers = {}
if query.get('uploadId'):
upload_id = query['uploadId'][0]
@ -684,10 +817,15 @@ class ResponseObject(_TemplateEnvironmentMixin):
parts=parts
)
version_id = query.get('versionId', [None])[0]
if_modified_since = headers.get('If-Modified-Since', None)
key = self.backend.get_key(
bucket_name, key_name, version_id=version_id)
if key is None:
raise MissingKey(key_name)
if if_modified_since:
if_modified_since = str_to_rfc_1123_datetime(if_modified_since)
if if_modified_since and key.last_modified < if_modified_since:
return 304, response_headers, 'Not Modified'
if 'acl' in query:
template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return 200, response_headers, template.render(obj=key)
@ -700,6 +838,9 @@ class ResponseObject(_TemplateEnvironmentMixin):
return 200, response_headers, key.value
def _key_response_put(self, request, body, bucket_name, query, key_name, headers):
self._set_action("KEY", "PUT", query)
self._authenticate_and_authorize_s3_action()
response_headers = {}
if query.get('uploadId') and query.get('partNumber'):
upload_id = query['uploadId'][0]
@ -764,7 +905,11 @@ class ResponseObject(_TemplateEnvironmentMixin):
src_version_id = parse_qs(src_key_parsed.query).get(
'versionId', [None])[0]
if self.backend.get_key(src_bucket, src_key, version_id=src_version_id):
key = self.backend.get_key(src_bucket, src_key, version_id=src_version_id)
if key is not None:
if key.storage_class in ["GLACIER", "DEEP_ARCHIVE"]:
raise ObjectNotInActiveTierError(key)
self.backend.copy_key(src_bucket, src_key, bucket_name, key_name,
storage=storage_class, acl=acl, src_version_id=src_version_id)
else:
@ -804,13 +949,20 @@ class ResponseObject(_TemplateEnvironmentMixin):
def _key_response_head(self, bucket_name, query, key_name, headers):
response_headers = {}
version_id = query.get('versionId', [None])[0]
part_number = query.get('partNumber', [None])[0]
if part_number:
part_number = int(part_number)
if_modified_since = headers.get('If-Modified-Since', None)
if if_modified_since:
if_modified_since = str_to_rfc_1123_datetime(if_modified_since)
key = self.backend.get_key(
bucket_name, key_name, version_id=version_id)
bucket_name,
key_name,
version_id=version_id,
part_number=part_number
)
if key:
response_headers.update(key.metadata)
response_headers.update(key.response_dict)
@ -1066,7 +1218,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
config = parsed_xml['AccelerateConfiguration']
return config['Status']
def _key_response_delete(self, bucket_name, query, key_name, headers):
def _key_response_delete(self, bucket_name, query, key_name):
self._set_action("KEY", "DELETE", query)
self._authenticate_and_authorize_s3_action()
if query.get('uploadId'):
upload_id = query['uploadId'][0]
self.backend.cancel_multipart(bucket_name, upload_id)
@ -1086,7 +1241,10 @@ class ResponseObject(_TemplateEnvironmentMixin):
raise InvalidPartOrder()
yield (pn, p.getElementsByTagName('ETag')[0].firstChild.wholeText)
def _key_response_post(self, request, body, bucket_name, query, key_name, headers):
def _key_response_post(self, request, body, bucket_name, query, key_name):
self._set_action("KEY", "POST", query)
self._authenticate_and_authorize_s3_action()
if body == b'' and 'uploads' in query:
metadata = metadata_from_headers(request.headers)
multipart = self.backend.initiate_multipart(
@ -1175,7 +1333,7 @@ S3_BUCKET_GET_RESPONSE_V2 = """<?xml version="1.0" encoding="UTF-8"?>
<Name>{{ bucket.name }}</Name>
<Prefix>{{ prefix }}</Prefix>
<MaxKeys>{{ max_keys }}</MaxKeys>
<KeyCount>{{ result_keys | length }}</KeyCount>
<KeyCount>{{ key_count }}</KeyCount>
{% if delimiter %}
<Delimiter>{{ delimiter }}</Delimiter>
{% endif %}

View File

@ -7,15 +7,6 @@ url_bases = [
r"https?://(?P<bucket_name>[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com"
]
def ambiguous_response1(*args, **kwargs):
return S3ResponseInstance.ambiguous_response(*args, **kwargs)
def ambiguous_response2(*args, **kwargs):
return S3ResponseInstance.ambiguous_response(*args, **kwargs)
url_paths = {
# subdomain bucket
'{0}/$': S3ResponseInstance.bucket_response,

View File

@ -70,24 +70,31 @@ class SecretsManagerBackend(BaseBackend):
secret_version = secret['versions'][version_id]
response = json.dumps({
response_data = {
"ARN": secret_arn(self.region, secret['secret_id']),
"Name": secret['name'],
"VersionId": secret_version['version_id'],
"SecretString": secret_version['secret_string'],
"VersionStages": secret_version['version_stages'],
"CreatedDate": secret_version['createdate'],
})
}
if 'secret_string' in secret_version:
response_data["SecretString"] = secret_version['secret_string']
if 'secret_binary' in secret_version:
response_data["SecretBinary"] = secret_version['secret_binary']
response = json.dumps(response_data)
return response
def create_secret(self, name, secret_string, tags, **kwargs):
def create_secret(self, name, secret_string=None, secret_binary=None, tags=[], **kwargs):
# error if secret exists
if name in self.secrets.keys():
raise ResourceExistsException('A resource with the ID you requested already exists.')
version_id = self._add_secret(name, secret_string, tags=tags)
version_id = self._add_secret(name, secret_string=secret_string, secret_binary=secret_binary, tags=tags)
response = json.dumps({
"ARN": secret_arn(self.region, name),
@ -97,7 +104,7 @@ class SecretsManagerBackend(BaseBackend):
return response
def _add_secret(self, secret_id, secret_string, tags=[], version_id=None, version_stages=None):
def _add_secret(self, secret_id, secret_string=None, secret_binary=None, tags=[], version_id=None, version_stages=None):
if version_stages is None:
version_stages = ['AWSCURRENT']
@ -106,12 +113,17 @@ class SecretsManagerBackend(BaseBackend):
version_id = str(uuid.uuid4())
secret_version = {
'secret_string': secret_string,
'createdate': int(time.time()),
'version_id': version_id,
'version_stages': version_stages,
}
if secret_string is not None:
secret_version['secret_string'] = secret_string
if secret_binary is not None:
secret_version['secret_binary'] = secret_binary
if secret_id in self.secrets:
# remove all old AWSPREVIOUS stages
for secret_verion_to_look_at in self.secrets[secret_id]['versions'].values():

View File

@ -21,10 +21,12 @@ class SecretsManagerResponse(BaseResponse):
def create_secret(self):
name = self._get_param('Name')
secret_string = self._get_param('SecretString')
secret_binary = self._get_param('SecretBinary')
tags = self._get_param('Tags', if_none=[])
return secretsmanager_backends[self.region].create_secret(
name=name,
secret_string=secret_string,
secret_binary=secret_binary,
tags=tags
)

View File

@ -21,6 +21,16 @@ from moto.core.utils import convert_flask_to_httpretty_response
HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"]
DEFAULT_SERVICE_REGION = ('s3', 'us-east-1')
# Map of unsigned calls to service-region as per AWS API docs
# https://docs.aws.amazon.com/cognito/latest/developerguide/resource-permissions.html#amazon-cognito-signed-versus-unsigned-apis
UNSIGNED_REQUESTS = {
'AWSCognitoIdentityService': ('cognito-identity', 'us-east-1'),
'AWSCognitoIdentityProviderService': ('cognito-idp', 'us-east-1'),
}
class DomainDispatcherApplication(object):
"""
Dispatch requests to different applications based on the "Host:" header
@ -48,7 +58,45 @@ class DomainDispatcherApplication(object):
if re.match(url_base, 'http://%s' % host):
return backend_name
raise RuntimeError('Invalid host: "%s"' % host)
def infer_service_region_host(self, environ):
auth = environ.get('HTTP_AUTHORIZATION')
if auth:
# Signed request
# Parse auth header to find service assuming a SigV4 request
# https://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html
# ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request']
try:
credential_scope = auth.split(",")[0].split()[1]
_, _, region, service, _ = credential_scope.split("/")
except ValueError:
# Signature format does not match, this is exceptional and we can't
# infer a service-region. A reduced set of services still use
# the deprecated SigV2, ergo prefer S3 as most likely default.
# https://docs.aws.amazon.com/general/latest/gr/signature-version-2.html
service, region = DEFAULT_SERVICE_REGION
else:
# Unsigned request
target = environ.get('HTTP_X_AMZ_TARGET')
if target:
service, _ = target.split('.', 1)
service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION)
else:
# S3 is the last resort when the target is also unknown
service, region = DEFAULT_SERVICE_REGION
if service == 'dynamodb':
if environ['HTTP_X_AMZ_TARGET'].startswith('DynamoDBStreams'):
host = 'dynamodbstreams'
else:
dynamo_api_version = environ['HTTP_X_AMZ_TARGET'].split("_")[1].split(".")[0]
# If Newer API version, use dynamodb2
if dynamo_api_version > "20111205":
host = "dynamodb2"
else:
host = "{service}.{region}.amazonaws.com".format(
service=service, region=region)
return host
def get_application(self, environ):
path_info = environ.get('PATH_INFO', '')
@ -65,34 +113,14 @@ class DomainDispatcherApplication(object):
host = "instance_metadata"
else:
host = environ['HTTP_HOST'].split(':')[0]
if host in {'localhost', 'motoserver'} or host.startswith("192.168."):
# Fall back to parsing auth header to find service
# ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request']
try:
_, _, region, service, _ = environ['HTTP_AUTHORIZATION'].split(",")[0].split()[
1].split("/")
except (KeyError, ValueError):
# Some cognito-idp endpoints (e.g. change password) do not receive an auth header.
if environ.get('HTTP_X_AMZ_TARGET', '').startswith('AWSCognitoIdentityProviderService'):
service = 'cognito-idp'
else:
service = 's3'
region = 'us-east-1'
if service == 'dynamodb':
if environ['HTTP_X_AMZ_TARGET'].startswith('DynamoDBStreams'):
host = 'dynamodbstreams'
else:
dynamo_api_version = environ['HTTP_X_AMZ_TARGET'].split("_")[1].split(".")[0]
# If Newer API version, use dynamodb2
if dynamo_api_version > "20111205":
host = "dynamodb2"
else:
host = "{service}.{region}.amazonaws.com".format(
service=service, region=region)
with self.lock:
backend = self.get_backend_for_host(host)
if not backend:
# No regular backend found; try parsing other headers
host = self.infer_service_region_host(environ)
backend = self.get_backend_for_host(host)
app = self.app_instances.get(backend, None)
if app is None:
app = self.create_app(backend)

View File

@ -4,13 +4,41 @@ import email
from email.utils import parseaddr
from moto.core import BaseBackend, BaseModel
from moto.sns.models import sns_backends
from .exceptions import MessageRejectedError
from .utils import get_random_message_id
from .feedback import COMMON_MAIL, BOUNCE, COMPLAINT, DELIVERY
RECIPIENT_LIMIT = 50
class SESFeedback(BaseModel):
BOUNCE = "Bounce"
COMPLAINT = "Complaint"
DELIVERY = "Delivery"
SUCCESS_ADDR = "success"
BOUNCE_ADDR = "bounce"
COMPLAINT_ADDR = "complaint"
FEEDBACK_SUCCESS_MSG = {"test": "success"}
FEEDBACK_BOUNCE_MSG = {"test": "bounce"}
FEEDBACK_COMPLAINT_MSG = {"test": "complaint"}
@staticmethod
def generate_message(msg_type):
msg = dict(COMMON_MAIL)
if msg_type == SESFeedback.BOUNCE:
msg["bounce"] = BOUNCE
elif msg_type == SESFeedback.COMPLAINT:
msg["complaint"] = COMPLAINT
elif msg_type == SESFeedback.DELIVERY:
msg["delivery"] = DELIVERY
return msg
class Message(BaseModel):
def __init__(self, message_id, source, subject, body, destinations):
@ -48,6 +76,7 @@ class SESBackend(BaseBackend):
self.domains = []
self.sent_messages = []
self.sent_message_count = 0
self.sns_topics = {}
def _is_verified_address(self, source):
_, address = parseaddr(source)
@ -77,7 +106,7 @@ class SESBackend(BaseBackend):
else:
self.domains.remove(identity)
def send_email(self, source, subject, body, destinations):
def send_email(self, source, subject, body, destinations, region):
recipient_count = sum(map(len, destinations.values()))
if recipient_count > RECIPIENT_LIMIT:
raise MessageRejectedError('Too many recipients.')
@ -86,13 +115,46 @@ class SESBackend(BaseBackend):
"Email address not verified %s" % source
)
self.__process_sns_feedback__(source, destinations, region)
message_id = get_random_message_id()
message = Message(message_id, source, subject, body, destinations)
self.sent_messages.append(message)
self.sent_message_count += recipient_count
return message
def send_raw_email(self, source, destinations, raw_data):
def __type_of_message__(self, destinations):
"""Checks the destination for any special address that could indicate delivery, complaint or bounce
like in SES simualtor"""
alladdress = destinations.get("ToAddresses", []) + destinations.get("CcAddresses", []) + destinations.get("BccAddresses", [])
for addr in alladdress:
if SESFeedback.SUCCESS_ADDR in addr:
return SESFeedback.DELIVERY
elif SESFeedback.COMPLAINT_ADDR in addr:
return SESFeedback.COMPLAINT
elif SESFeedback.BOUNCE_ADDR in addr:
return SESFeedback.BOUNCE
return None
def __generate_feedback__(self, msg_type):
"""Generates the SNS message for the feedback"""
return SESFeedback.generate_message(msg_type)
def __process_sns_feedback__(self, source, destinations, region):
domain = str(source)
if "@" in domain:
domain = domain.split("@")[1]
if domain in self.sns_topics:
msg_type = self.__type_of_message__(destinations)
if msg_type is not None:
sns_topic = self.sns_topics[domain].get(msg_type, None)
if sns_topic is not None:
message = self.__generate_feedback__(msg_type)
if message:
sns_backends[region].publish(sns_topic, message)
def send_raw_email(self, source, destinations, raw_data, region):
if source is not None:
_, source_email_address = parseaddr(source)
if source_email_address not in self.addresses:
@ -122,6 +184,8 @@ class SESBackend(BaseBackend):
if recipient_count > RECIPIENT_LIMIT:
raise MessageRejectedError('Too many recipients.')
self.__process_sns_feedback__(source, destinations, region)
self.sent_message_count += recipient_count
message_id = get_random_message_id()
message = RawMessage(message_id, source, destinations, raw_data)
@ -131,5 +195,16 @@ class SESBackend(BaseBackend):
def get_send_quota(self):
return SESQuota(self.sent_message_count)
def set_identity_notification_topic(self, identity, notification_type, sns_topic):
identity_sns_topics = self.sns_topics.get(identity, {})
if sns_topic is None:
del identity_sns_topics[notification_type]
else:
identity_sns_topics[notification_type] = sns_topic
self.sns_topics[identity] = identity_sns_topics
return {}
ses_backend = SESBackend()

View File

@ -70,7 +70,7 @@ class EmailResponse(BaseResponse):
break
destinations[dest_type].append(address[0])
message = ses_backend.send_email(source, subject, body, destinations)
message = ses_backend.send_email(source, subject, body, destinations, self.region)
template = self.response_template(SEND_EMAIL_RESPONSE)
return template.render(message=message)
@ -92,7 +92,7 @@ class EmailResponse(BaseResponse):
break
destinations.append(address[0])
message = ses_backend.send_raw_email(source, destinations, raw_data)
message = ses_backend.send_raw_email(source, destinations, raw_data, self.region)
template = self.response_template(SEND_RAW_EMAIL_RESPONSE)
return template.render(message=message)
@ -101,6 +101,18 @@ class EmailResponse(BaseResponse):
template = self.response_template(GET_SEND_QUOTA_RESPONSE)
return template.render(quota=quota)
def set_identity_notification_topic(self):
identity = self.querystring.get("Identity")[0]
not_type = self.querystring.get("NotificationType")[0]
sns_topic = self.querystring.get("SnsTopic")
if sns_topic:
sns_topic = sns_topic[0]
ses_backend.set_identity_notification_topic(identity, not_type, sns_topic)
template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE)
return template.render()
VERIFY_EMAIL_IDENTITY = """<VerifyEmailIdentityResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<VerifyEmailIdentityResult/>
@ -200,3 +212,10 @@ GET_SEND_QUOTA_RESPONSE = """<GetSendQuotaResponse xmlns="http://ses.amazonaws.c
<RequestId>273021c6-c866-11e0-b926-699e21c3af9e</RequestId>
</ResponseMetadata>
</GetSendQuotaResponse>"""
SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE = """<SetIdentityNotificationTopicResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<SetIdentityNotificationTopicResult/>
<ResponseMetadata>
<RequestId>47e0ef1a-9bf2-11e1-9279-0100e8cf109a</RequestId>
</ResponseMetadata>
</SetIdentityNotificationTopicResponse>"""

View File

@ -1,3 +1,4 @@
import os
TEST_SERVER_MODE = os.environ.get('TEST_SERVER_MODE', '0').lower() == 'true'
INITIAL_NO_AUTH_ACTION_COUNT = float(os.environ.get('INITIAL_NO_AUTH_ACTION_COUNT', float('inf')))

View File

@ -12,7 +12,7 @@ from boto3 import Session
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.core.utils import iso_8601_datetime_with_milliseconds, camelcase_to_underscores
from moto.sqs import sqs_backends
from moto.awslambda import lambda_backends
@ -119,7 +119,7 @@ class Subscription(BaseModel):
else:
assert False
lambda_backends[region].send_message(function_name, message, subject=subject, qualifier=qualifier)
lambda_backends[region].send_sns_message(function_name, message, subject=subject, qualifier=qualifier)
def _matches_filter_policy(self, message_attributes):
# TODO: support Anything-but matching, prefix matching and
@ -243,11 +243,14 @@ class SNSBackend(BaseBackend):
def update_sms_attributes(self, attrs):
self.sms_attributes.update(attrs)
def create_topic(self, name):
def create_topic(self, name, attributes=None):
fails_constraints = not re.match(r'^[a-zA-Z0-9_-]{1,256}$', name)
if fails_constraints:
raise InvalidParameterValue("Topic names must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long.")
candidate_topic = Topic(name, self)
if attributes:
for attribute in attributes:
setattr(candidate_topic, camelcase_to_underscores(attribute), attributes[attribute])
if candidate_topic.arn in self.topics:
return self.topics[candidate_topic.arn]
else:

View File

@ -75,7 +75,8 @@ class SNSResponse(BaseResponse):
def create_topic(self):
name = self._get_param('Name')
topic = self.backend.create_topic(name)
attributes = self._get_attributes()
topic = self.backend.create_topic(name, attributes)
if self.request_json:
return json.dumps({

View File

@ -189,6 +189,8 @@ class Queue(BaseModel):
self.name)
self.dead_letter_queue = None
self.lambda_event_source_mappings = {}
# default settings for a non fifo queue
defaults = {
'ContentBasedDeduplication': 'false',
@ -360,6 +362,33 @@ class Queue(BaseModel):
def add_message(self, message):
self._messages.append(message)
from moto.awslambda import lambda_backends
for arn, esm in self.lambda_event_source_mappings.items():
backend = sqs_backends[self.region]
"""
Lambda polls the queue and invokes your function synchronously with an event
that contains queue messages. Lambda reads messages in batches and invokes
your function once for each batch. When your function successfully processes
a batch, Lambda deletes its messages from the queue.
"""
messages = backend.receive_messages(
self.name,
esm.batch_size,
self.receive_message_wait_time_seconds,
self.visibility_timeout,
)
result = lambda_backends[self.region].send_sqs_batch(
arn,
messages,
self.queue_arn,
)
if result:
[backend.delete_message(self.name, m.receipt_handle) for m in messages]
else:
[backend.change_message_visibility(self.name, m.receipt_handle, 0) for m in messages]
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
@ -379,6 +408,7 @@ class SQSBackend(BaseBackend):
def reset(self):
region_name = self.region_name
self._reset_model_refs()
self.__dict__ = {}
self.__init__(region_name)

View File

@ -2,6 +2,8 @@ from __future__ import unicode_literals
import datetime
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.iam.models import ACCOUNT_ID
from moto.sts.utils import random_access_key_id, random_secret_access_key, random_session_token, random_assumed_role_id
class Token(BaseModel):
@ -21,19 +23,38 @@ class AssumedRole(BaseModel):
def __init__(self, role_session_name, role_arn, policy, duration, external_id):
self.session_name = role_session_name
self.arn = role_arn
self.role_arn = role_arn
self.policy = policy
now = datetime.datetime.utcnow()
self.expiration = now + datetime.timedelta(seconds=duration)
self.external_id = external_id
self.access_key_id = "ASIA" + random_access_key_id()
self.secret_access_key = random_secret_access_key()
self.session_token = random_session_token()
self.assumed_role_id = "AROA" + random_assumed_role_id()
@property
def expiration_ISO8601(self):
return iso_8601_datetime_with_milliseconds(self.expiration)
@property
def user_id(self):
return self.assumed_role_id + ":" + self.session_name
@property
def arn(self):
return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format(
account_id=ACCOUNT_ID,
role_name=self.role_arn.split("/")[-1],
session_name=self.session_name
)
class STSBackend(BaseBackend):
def __init__(self):
self.assumed_roles = []
def get_session_token(self, duration):
token = Token(duration=duration)
return token
@ -44,7 +65,17 @@ class STSBackend(BaseBackend):
def assume_role(self, **kwargs):
role = AssumedRole(**kwargs)
self.assumed_roles.append(role)
return role
def get_assumed_role_from_access_key(self, access_key_id):
for assumed_role in self.assumed_roles:
if assumed_role.access_key_id == access_key_id:
return assumed_role
return None
def assume_role_with_web_identity(self, **kwargs):
return self.assume_role(**kwargs)
sts_backend = STSBackend()

View File

@ -1,8 +1,13 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.iam.models import ACCOUNT_ID
from moto.iam import iam_backend
from .exceptions import STSValidationError
from .models import sts_backend
MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048
class TokenResponse(BaseResponse):
@ -15,11 +20,20 @@ class TokenResponse(BaseResponse):
def get_federation_token(self):
duration = int(self.querystring.get('DurationSeconds', [43200])[0])
policy = self.querystring.get('Policy', [None])[0]
if policy is not None and len(policy) > MAX_FEDERATION_TOKEN_POLICY_LENGTH:
raise STSValidationError(
"1 validation error detected: Value "
"'{\"Version\": \"2012-10-17\", \"Statement\": [...]}' "
"at 'policy' failed to satisfy constraint: Member must have length less than or "
" equal to %s" % MAX_FEDERATION_TOKEN_POLICY_LENGTH
)
name = self.querystring.get('Name')[0]
token = sts_backend.get_federation_token(
duration=duration, name=name, policy=policy)
template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE)
return template.render(token=token)
return template.render(token=token, account_id=ACCOUNT_ID)
def assume_role(self):
role_session_name = self.querystring.get('RoleSessionName')[0]
@ -39,9 +53,43 @@ class TokenResponse(BaseResponse):
template = self.response_template(ASSUME_ROLE_RESPONSE)
return template.render(role=role)
def assume_role_with_web_identity(self):
role_session_name = self.querystring.get('RoleSessionName')[0]
role_arn = self.querystring.get('RoleArn')[0]
policy = self.querystring.get('Policy', [None])[0]
duration = int(self.querystring.get('DurationSeconds', [3600])[0])
external_id = self.querystring.get('ExternalId', [None])[0]
role = sts_backend.assume_role_with_web_identity(
role_session_name=role_session_name,
role_arn=role_arn,
policy=policy,
duration=duration,
external_id=external_id,
)
template = self.response_template(ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE)
return template.render(role=role)
def get_caller_identity(self):
template = self.response_template(GET_CALLER_IDENTITY_RESPONSE)
return template.render()
# Default values in case the request does not use valid credentials generated by moto
user_id = "AKIAIOSFODNN7EXAMPLE"
arn = "arn:aws:sts::{account_id}:user/moto".format(account_id=ACCOUNT_ID)
access_key_id = self.get_current_user()
assumed_role = sts_backend.get_assumed_role_from_access_key(access_key_id)
if assumed_role:
user_id = assumed_role.user_id
arn = assumed_role.arn
user = iam_backend.get_user_from_access_key_id(access_key_id)
if user:
user_id = user.id
arn = user.arn
return template.render(account_id=ACCOUNT_ID, user_id=user_id, arn=arn)
GET_SESSION_TOKEN_RESPONSE = """<GetSessionTokenResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
@ -69,8 +117,8 @@ GET_FEDERATION_TOKEN_RESPONSE = """<GetFederationTokenResponse xmlns="https://st
<AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId>
</Credentials>
<FederatedUser>
<Arn>arn:aws:sts::123456789012:federated-user/{{ token.name }}</Arn>
<FederatedUserId>123456789012:{{ token.name }}</FederatedUserId>
<Arn>arn:aws:sts::{{ account_id }}:federated-user/{{ token.name }}</Arn>
<FederatedUserId>{{ account_id }}:{{ token.name }}</FederatedUserId>
</FederatedUser>
<PackedPolicySize>6</PackedPolicySize>
</GetFederationTokenResult>
@ -84,14 +132,14 @@ ASSUME_ROLE_RESPONSE = """<AssumeRoleResponse xmlns="https://sts.amazonaws.com/d
2011-06-15/">
<AssumeRoleResult>
<Credentials>
<SessionToken>BQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/LTo6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3zrkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtpZ3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE</SessionToken>
<SecretAccessKey>aJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY</SecretAccessKey>
<SessionToken>{{ role.session_token }}</SessionToken>
<SecretAccessKey>{{ role.secret_access_key }}</SecretAccessKey>
<Expiration>{{ role.expiration_ISO8601 }}</Expiration>
<AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId>
<AccessKeyId>{{ role.access_key_id }}</AccessKeyId>
</Credentials>
<AssumedRoleUser>
<Arn>{{ role.arn }}</Arn>
<AssumedRoleId>ARO123EXAMPLE123:{{ role.session_name }}</AssumedRoleId>
<AssumedRoleId>{{ role.user_id }}</AssumedRoleId>
</AssumedRoleUser>
<PackedPolicySize>6</PackedPolicySize>
</AssumeRoleResult>
@ -100,11 +148,32 @@ ASSUME_ROLE_RESPONSE = """<AssumeRoleResponse xmlns="https://sts.amazonaws.com/d
</ResponseMetadata>
</AssumeRoleResponse>"""
ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE = """<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<Credentials>
<SessionToken>{{ role.session_token }}</SessionToken>
<SecretAccessKey>{{ role.secret_access_key }}</SecretAccessKey>
<Expiration>{{ role.expiration_ISO8601 }}</Expiration>
<AccessKeyId>{{ role.access_key_id }}</AccessKeyId>
</Credentials>
<AssumedRoleUser>
<Arn>{{ role.arn }}</Arn>
<AssumedRoleId>ARO123EXAMPLE123:{{ role.session_name }}</AssumedRoleId>
</AssumedRoleUser>
<PackedPolicySize>6</PackedPolicySize>
</AssumeRoleWithWebIdentityResult>
<ResponseMetadata>
<RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId>
</ResponseMetadata>
</AssumeRoleWithWebIdentityResponse>"""
GET_CALLER_IDENTITY_RESPONSE = """<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult>
<Arn>arn:aws:sts::123456789012:user/moto</Arn>
<UserId>AKIAIOSFODNN7EXAMPLE</UserId>
<Account>123456789012</Account>
<Arn>{{ arn }}</Arn>
<UserId>{{ user_id }}</UserId>
<Account>{{ account_id }}</Account>
</GetCallerIdentityResult>
<ResponseMetadata>
<RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId>

View File

@ -61,7 +61,8 @@ def print_implementation_coverage(coverage):
percentage_implemented = 0
print("")
print("## {} - {}% implemented".format(service_name, percentage_implemented))
print("## {}\n".format(service_name))
print("{}% implemented\n".format(percentage_implemented))
for op in operations:
if op in implemented:
print("- [X] {}".format(op))
@ -93,7 +94,8 @@ def write_implementation_coverage_to_file(coverage):
percentage_implemented = 0
file.write("\n")
file.write("## {} - {}% implemented\n".format(service_name, percentage_implemented))
file.write("## {}\n".format(service_name))
file.write("{}% implemented\n".format(percentage_implemented))
for op in operations:
if op in implemented:
file.write("- [X] {}\n".format(op))

View File

@ -48,7 +48,8 @@ for policy_name in policies:
PolicyArn=policies[policy_name]['Arn'],
VersionId=policies[policy_name]['DefaultVersionId'])
for key in response['PolicyVersion']:
policies[policy_name][key] = response['PolicyVersion'][key]
if key != "CreateDate": # the policy's CreateDate should not be overwritten by its version's CreateDate
policies[policy_name][key] = response['PolicyVersion'][key]
with open(output_file, 'w') as f:
triple_quote = '\"\"\"'

View File

@ -18,17 +18,26 @@ def read(*parts):
return fp.read()
def get_version():
version_file = read('moto', '__init__.py')
version_match = re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]',
version_file, re.MULTILINE)
if version_match:
return version_match.group(1)
raise RuntimeError('Unable to find version string.')
install_requires = [
"Jinja2>=2.10.1",
"boto>=2.36.0",
"boto3>=1.9.86",
"botocore>=1.12.86",
"boto3>=1.9.201",
"botocore>=1.12.201",
"cryptography>=2.3.0",
"requests>=2.5",
"xmltodict",
"six>1.9",
"werkzeug",
"PyYAML==3.13",
"PyYAML>=5.1",
"pytz",
"python-dateutil<3.0.0,>=2.1",
"python-jose<4.0.0",
@ -38,7 +47,7 @@ install_requires = [
"aws-xray-sdk!=0.96,>=0.93",
"responses>=0.9.0",
"idna<2.9,>=2.5",
"cfn-lint",
"cfn-lint>=0.4.0",
"sshpubkeys>=3.1.0,<4.0"
]
@ -56,7 +65,7 @@ else:
setup(
name='moto',
version='1.3.8',
version=get_version(),
description='A library that allows your python tests to easily'
' mock out the boto library',
long_description=read('README.md'),
@ -79,10 +88,9 @@ setup(
"Programming Language :: Python :: 2",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.3",
"Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"License :: OSI Approved :: Apache Software License",
"Topic :: Software Development :: Testing",
],

View File

@ -74,6 +74,31 @@ def test_list_certificates():
resp['CertificateSummaryList'][0]['DomainName'].should.equal(SERVER_COMMON_NAME)
@mock_acm
def test_list_certificates_by_status():
client = boto3.client('acm', region_name='eu-central-1')
issued_arn = _import_cert(client)
pending_arn = client.request_certificate(DomainName='google.com')['CertificateArn']
resp = client.list_certificates()
len(resp['CertificateSummaryList']).should.equal(2)
resp = client.list_certificates(CertificateStatuses=['EXPIRED', 'INACTIVE'])
len(resp['CertificateSummaryList']).should.equal(0)
resp = client.list_certificates(CertificateStatuses=['PENDING_VALIDATION'])
len(resp['CertificateSummaryList']).should.equal(1)
resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(pending_arn)
resp = client.list_certificates(CertificateStatuses=['ISSUED'])
len(resp['CertificateSummaryList']).should.equal(1)
resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(issued_arn)
resp = client.list_certificates(CertificateStatuses=['ISSUED', 'PENDING_VALIDATION'])
len(resp['CertificateSummaryList']).should.equal(2)
arns = {cert['CertificateArn'] for cert in resp['CertificateSummaryList']}
arns.should.contain(issued_arn)
arns.should.contain(pending_arn)
@mock_acm
def test_get_invalid_certificate():
client = boto3.client('acm', region_name='eu-central-1')
@ -291,6 +316,7 @@ def test_request_certificate():
)
resp.should.contain('CertificateArn')
arn = resp['CertificateArn']
arn.should.match(r"arn:aws:acm:eu-central-1:\d{12}:certificate/")
resp = client.request_certificate(
DomainName='google.com',

View File

@ -988,13 +988,30 @@ def test_api_keys():
apikey['name'].should.equal(apikey_name)
len(apikey['value']).should.equal(40)
apikey_name = 'TESTKEY3'
payload = {'name': apikey_name }
response = client.create_api_key(**payload)
apikey_id = response['id']
patch_operations = [
{'op': 'replace', 'path': '/name', 'value': 'TESTKEY3_CHANGE'},
{'op': 'replace', 'path': '/customerId', 'value': '12345'},
{'op': 'replace', 'path': '/description', 'value': 'APIKEY UPDATE TEST'},
{'op': 'replace', 'path': '/enabled', 'value': 'false'},
]
response = client.update_api_key(apiKey=apikey_id, patchOperations=patch_operations)
response['name'].should.equal('TESTKEY3_CHANGE')
response['customerId'].should.equal('12345')
response['description'].should.equal('APIKEY UPDATE TEST')
response['enabled'].should.equal(False)
response = client.get_api_keys()
len(response['items']).should.equal(2)
len(response['items']).should.equal(3)
client.delete_api_key(apiKey=apikey_id)
response = client.get_api_keys()
len(response['items']).should.equal(1)
len(response['items']).should.equal(2)
@mock_apigateway
def test_usage_plans():

View File

@ -7,11 +7,13 @@ from boto.ec2.autoscale.group import AutoScalingGroup
from boto.ec2.autoscale import Tag
import boto.ec2.elb
import sure # noqa
from botocore.exceptions import ClientError
from nose.tools import assert_raises
from moto import mock_autoscaling, mock_ec2_deprecated, mock_elb_deprecated, mock_elb, mock_autoscaling_deprecated, mock_ec2
from tests.helpers import requires_boto_gte
from utils import setup_networking, setup_networking_deprecated
from utils import setup_networking, setup_networking_deprecated, setup_instance_with_networking
@mock_autoscaling_deprecated
@ -724,6 +726,67 @@ def test_create_autoscaling_group_boto3():
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
@mock_autoscaling
def test_create_autoscaling_group_from_instance():
autoscaling_group_name = 'test_asg'
image_id = 'ami-0cc293023f983ed53'
instance_type = 't2.micro'
mocked_instance_with_networking = setup_instance_with_networking(image_id, instance_type)
client = boto3.client('autoscaling', region_name='us-east-1')
response = client.create_auto_scaling_group(
AutoScalingGroupName=autoscaling_group_name,
InstanceId=mocked_instance_with_networking['instance'],
MinSize=1,
MaxSize=3,
DesiredCapacity=2,
Tags=[
{'ResourceId': 'test_asg',
'ResourceType': 'auto-scaling-group',
'Key': 'propogated-tag-key',
'Value': 'propogate-tag-value',
'PropagateAtLaunch': True
},
{'ResourceId': 'test_asg',
'ResourceType': 'auto-scaling-group',
'Key': 'not-propogated-tag-key',
'Value': 'not-propogate-tag-value',
'PropagateAtLaunch': False
}],
VPCZoneIdentifier=mocked_instance_with_networking['subnet1'],
NewInstancesProtectedFromScaleIn=False,
)
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
describe_launch_configurations_response = client.describe_launch_configurations()
describe_launch_configurations_response['LaunchConfigurations'].should.have.length_of(1)
launch_configuration_from_instance = describe_launch_configurations_response['LaunchConfigurations'][0]
launch_configuration_from_instance['LaunchConfigurationName'].should.equal('test_asg')
launch_configuration_from_instance['ImageId'].should.equal(image_id)
launch_configuration_from_instance['InstanceType'].should.equal(instance_type)
@mock_autoscaling
def test_create_autoscaling_group_from_invalid_instance_id():
invalid_instance_id = 'invalid_instance'
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
with assert_raises(ClientError) as ex:
client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
InstanceId=invalid_instance_id,
MinSize=9,
MaxSize=15,
DesiredCapacity=12,
VPCZoneIdentifier=mocked_networking['subnet1'],
NewInstancesProtectedFromScaleIn=False,
)
ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400)
ex.exception.response['Error']['Code'].should.equal('ValidationError')
ex.exception.response['Error']['Message'].should.equal('Instance [{0}] is invalid.'.format(invalid_instance_id))
@mock_autoscaling
def test_describe_autoscaling_groups_boto3():
mocked_networking = setup_networking()
@ -823,6 +886,62 @@ def test_update_autoscaling_group_boto3():
group['NewInstancesProtectedFromScaleIn'].should.equal(False)
@mock_autoscaling
def test_update_autoscaling_group_min_size_desired_capacity_change():
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
client.create_launch_configuration(
LaunchConfigurationName='test_launch_configuration'
)
client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
LaunchConfigurationName='test_launch_configuration',
MinSize=2,
MaxSize=20,
DesiredCapacity=3,
VPCZoneIdentifier=mocked_networking['subnet1'],
)
client.update_auto_scaling_group(
AutoScalingGroupName='test_asg',
MinSize=5,
)
response = client.describe_auto_scaling_groups(
AutoScalingGroupNames=['test_asg'])
group = response['AutoScalingGroups'][0]
group['DesiredCapacity'].should.equal(5)
group['MinSize'].should.equal(5)
group['Instances'].should.have.length_of(5)
@mock_autoscaling
def test_update_autoscaling_group_max_size_desired_capacity_change():
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
client.create_launch_configuration(
LaunchConfigurationName='test_launch_configuration'
)
client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
LaunchConfigurationName='test_launch_configuration',
MinSize=2,
MaxSize=20,
DesiredCapacity=10,
VPCZoneIdentifier=mocked_networking['subnet1'],
)
client.update_auto_scaling_group(
AutoScalingGroupName='test_asg',
MaxSize=5,
)
response = client.describe_auto_scaling_groups(
AutoScalingGroupNames=['test_asg'])
group = response['AutoScalingGroups'][0]
group['DesiredCapacity'].should.equal(5)
group['MaxSize'].should.equal(5)
group['Instances'].should.have.length_of(5)
@mock_autoscaling
def test_autoscaling_taqs_update_boto3():
mocked_networking = setup_networking()
@ -1269,3 +1388,36 @@ def test_set_desired_capacity_down_boto3():
instance_ids = {instance['InstanceId'] for instance in group['Instances']}
set(protected).should.equal(instance_ids)
set(unprotected).should_not.be.within(instance_ids) # only unprotected killed
@mock_autoscaling
@mock_ec2
def test_terminate_instance_in_autoscaling_group():
mocked_networking = setup_networking()
client = boto3.client('autoscaling', region_name='us-east-1')
_ = client.create_launch_configuration(
LaunchConfigurationName='test_launch_configuration'
)
_ = client.create_auto_scaling_group(
AutoScalingGroupName='test_asg',
LaunchConfigurationName='test_launch_configuration',
MinSize=1,
MaxSize=20,
VPCZoneIdentifier=mocked_networking['subnet1'],
NewInstancesProtectedFromScaleIn=False
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg'])
original_instance_id = next(
instance['InstanceId']
for instance in response['AutoScalingGroups'][0]['Instances']
)
ec2_client = boto3.client('ec2', region_name='us-east-1')
ec2_client.terminate_instances(InstanceIds=[original_instance_id])
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg'])
replaced_instance_id = next(
instance['InstanceId']
for instance in response['AutoScalingGroups'][0]['Instances']
)
replaced_instance_id.should_not.equal(original_instance_id)

View File

@ -31,3 +31,18 @@ def setup_networking_deprecated():
"10.11.2.0/24",
availability_zone='us-east-1b')
return {'vpc': vpc.id, 'subnet1': subnet1.id, 'subnet2': subnet2.id}
@mock_ec2
def setup_instance_with_networking(image_id, instance_type):
mock_data = setup_networking()
ec2 = boto3.resource('ec2', region_name='us-east-1')
instances = ec2.create_instances(
ImageId=image_id,
InstanceType=instance_type,
MaxCount=1,
MinCount=1,
SubnetId=mock_data['subnet1']
)
mock_data['instance'] = instances[0].id
return mock_data

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
import base64
import uuid
import botocore.client
import boto3
import hashlib
@ -11,11 +12,12 @@ import zipfile
import sure # noqa
from freezegun import freeze_time
from moto import mock_lambda, mock_s3, mock_ec2, mock_sns, mock_logs, settings
from moto import mock_lambda, mock_s3, mock_ec2, mock_sns, mock_logs, settings, mock_sqs
from nose.tools import assert_raises
from botocore.exceptions import ClientError
_lambda_region = 'us-west-2'
boto3.setup_default_session(region_name=_lambda_region)
def _process_lambda(func_str):
@ -59,6 +61,13 @@ def lambda_handler(event, context):
"""
return _process_lambda(pfunc)
def get_test_zip_file4():
pfunc = """
def lambda_handler(event, context):
raise Exception('I failed!')
"""
return _process_lambda(pfunc)
@mock_lambda
def test_list_functions():
@ -933,3 +942,306 @@ def test_list_versions_by_function_for_nonexistent_function():
versions = conn.list_versions_by_function(FunctionName='testFunction')
assert len(versions['Versions']) == 0
@mock_logs
@mock_lambda
@mock_sqs
def test_create_event_source_mapping():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
assert response['EventSourceArn'] == queue.attributes['QueueArn']
assert response['FunctionArn'] == func['FunctionArn']
assert response['State'] == 'Enabled'
@mock_logs
@mock_lambda
@mock_sqs
def test_invoke_function_from_sqs():
logs_conn = boto3.client("logs")
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
assert response['EventSourceArn'] == queue.attributes['QueueArn']
assert response['State'] == 'Enabled'
sqs_client = boto3.client('sqs')
sqs_client.send_message(QueueUrl=queue.url, MessageBody='test')
start = time.time()
while (time.time() - start) < 30:
result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction')
log_streams = result.get('logStreams')
if not log_streams:
time.sleep(1)
continue
assert len(log_streams) == 1
result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName'])
for event in result.get('events'):
if event['message'] == 'get_test_zip_file3 success':
return
time.sleep(1)
assert False, "Test Failed"
@mock_logs
@mock_lambda
@mock_sqs
def test_invoke_function_from_sqs_exception():
logs_conn = boto3.client("logs")
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file4(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
assert response['EventSourceArn'] == queue.attributes['QueueArn']
assert response['State'] == 'Enabled'
entries = []
for i in range(3):
body = {
"uuid": str(uuid.uuid4()),
"test": "test_{}".format(i),
}
entry = {
'Id': str(i),
'MessageBody': json.dumps(body)
}
entries.append(entry)
queue.send_messages(Entries=entries)
start = time.time()
while (time.time() - start) < 30:
result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction')
log_streams = result.get('logStreams')
if not log_streams:
time.sleep(1)
continue
assert len(log_streams) >= 1
result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName'])
for event in result.get('events'):
if 'I failed!' in event['message']:
messages = queue.receive_messages(MaxNumberOfMessages=10)
# Verify messages are still visible and unprocessed
assert len(messages) is 3
return
time.sleep(1)
assert False, "Test Failed"
@mock_logs
@mock_lambda
@mock_sqs
def test_list_event_source_mappings():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
mappings = conn.list_event_source_mappings(EventSourceArn='123')
assert len(mappings['EventSourceMappings']) == 0
mappings = conn.list_event_source_mappings(EventSourceArn=queue.attributes['QueueArn'])
assert len(mappings['EventSourceMappings']) == 1
assert mappings['EventSourceMappings'][0]['UUID'] == response['UUID']
assert mappings['EventSourceMappings'][0]['FunctionArn'] == func['FunctionArn']
@mock_lambda
@mock_sqs
def test_get_event_source_mapping():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func['FunctionArn'],
)
mapping = conn.get_event_source_mapping(UUID=response['UUID'])
assert mapping['UUID'] == response['UUID']
assert mapping['FunctionArn'] == func['FunctionArn']
conn.get_event_source_mapping.when.called_with(UUID='1')\
.should.throw(botocore.client.ClientError)
@mock_lambda
@mock_sqs
def test_update_event_source_mapping():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func1 = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
func2 = conn.create_function(
FunctionName='testFunction2',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func1['FunctionArn'],
)
assert response['FunctionArn'] == func1['FunctionArn']
assert response['BatchSize'] == 10
assert response['State'] == 'Enabled'
mapping = conn.update_event_source_mapping(
UUID=response['UUID'],
Enabled=False,
BatchSize=15,
FunctionName='testFunction2'
)
assert mapping['UUID'] == response['UUID']
assert mapping['FunctionArn'] == func2['FunctionArn']
assert mapping['State'] == 'Disabled'
@mock_lambda
@mock_sqs
def test_delete_event_source_mapping():
sqs = boto3.resource('sqs')
queue = sqs.create_queue(QueueName="test-sqs-queue1")
conn = boto3.client('lambda')
func1 = conn.create_function(
FunctionName='testFunction',
Runtime='python2.7',
Role='test-iam-role',
Handler='lambda_function.lambda_handler',
Code={
'ZipFile': get_test_zip_file3(),
},
Description='test lambda function',
Timeout=3,
MemorySize=128,
Publish=True,
)
response = conn.create_event_source_mapping(
EventSourceArn=queue.attributes['QueueArn'],
FunctionName=func1['FunctionArn'],
)
assert response['FunctionArn'] == func1['FunctionArn']
assert response['BatchSize'] == 10
assert response['State'] == 'Enabled'
response = conn.delete_event_source_mapping(UUID=response['UUID'])
assert response['State'] == 'Deleting'
conn.get_event_source_mapping.when.called_with(UUID=response['UUID'])\
.should.throw(botocore.client.ClientError)

View File

@ -642,6 +642,87 @@ def test_describe_task_definition():
len(resp['jobDefinitions']).should.equal(3)
@mock_logs
@mock_ec2
@mock_ecs
@mock_iam
@mock_batch
def test_submit_job_by_name():
ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients()
vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client)
compute_name = 'test_compute_env'
resp = batch_client.create_compute_environment(
computeEnvironmentName=compute_name,
type='UNMANAGED',
state='ENABLED',
serviceRole=iam_arn
)
arn = resp['computeEnvironmentArn']
resp = batch_client.create_job_queue(
jobQueueName='test_job_queue',
state='ENABLED',
priority=123,
computeEnvironmentOrder=[
{
'order': 123,
'computeEnvironment': arn
},
]
)
queue_arn = resp['jobQueueArn']
job_definition_name = 'sleep10'
batch_client.register_job_definition(
jobDefinitionName=job_definition_name,
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 128,
'command': ['sleep', '10']
}
)
batch_client.register_job_definition(
jobDefinitionName=job_definition_name,
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 256,
'command': ['sleep', '10']
}
)
resp = batch_client.register_job_definition(
jobDefinitionName=job_definition_name,
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 512,
'command': ['sleep', '10']
}
)
job_definition_arn = resp['jobDefinitionArn']
resp = batch_client.submit_job(
jobName='test1',
jobQueue=queue_arn,
jobDefinition=job_definition_name
)
job_id = resp['jobId']
resp_jobs = batch_client.describe_jobs(jobs=[job_id])
# batch_client.terminate_job(jobId=job_id)
len(resp_jobs['jobs']).should.equal(1)
resp_jobs['jobs'][0]['jobId'].should.equal(job_id)
resp_jobs['jobs'][0]['jobQueue'].should.equal(queue_arn)
resp_jobs['jobs'][0]['jobDefinition'].should.equal(job_definition_arn)
# SLOW TESTS
@expected_failure
@mock_logs

View File

@ -593,9 +593,11 @@ def test_create_stack_lambda_and_dynamodb():
}
},
"func1version": {
"Type": "AWS::Lambda::LambdaVersion",
"Properties" : {
"Version": "v1.2.3"
"Type": "AWS::Lambda::Version",
"Properties": {
"FunctionName": {
"Ref": "func1"
}
}
},
"tab1": {
@ -618,8 +620,10 @@ def test_create_stack_lambda_and_dynamodb():
},
"func1mapping": {
"Type": "AWS::Lambda::EventSourceMapping",
"Properties" : {
"FunctionName": "v1.2.3",
"Properties": {
"FunctionName": {
"Ref": "func1"
},
"EventSourceArn": "arn:aws:dynamodb:region:XXXXXX:table/tab1/stream/2000T00:00:00.000",
"StartingPosition": "0",
"BatchSize": 100,

Some files were not shown because too many files have changed in this diff Show More